feat: new dependency model with params/headers

This commit is contained in:
Daniel G. Taylor 2020-03-12 21:33:51 -07:00
parent c98d6145e4
commit fe8d73f51f
No known key found for this signature in database
GPG key ID: 7BD6DC99C9A87E22
9 changed files with 566 additions and 213 deletions

View file

@ -235,14 +235,16 @@ The standard `json` tag is supported and can be used to rename a field and mark
Huma includes a dependency injection system that can be used to pass additional arguments to operation handler functions. You can register global dependencies (ones that do not change from request to request) or contextual dependencies (ones that change with each request).
Global dependencies are created by just setting some value, while contextual dependencies are implemented using a function that returns the value of the form `func (dep *DependencyType...) (*YourType, error)` where the value you want injected is of `*YourType` and the function arguments can be any previously registered dependency types or one of the hard-coded types:
Global dependencies are created by just setting some value, while contextual dependencies are implemented using a function that returns the value of the form `func (deps..., params...) (headers..., *YourType, error)` where the value you want injected is of `*YourType` and the function arguments can be any previously registered dependency types or one of the hard-coded types:
- `*gin.Context` the current context
- `*huma.Operation` the current operation
- `huma.ContextDependency()` the current context (returns `*gin.Context`)
- `huma.OperationDependency()` the current operation (returns `*huma.Operation`)
```go
// Register a new database connection dependency
r.Dependency(db.NewConnection())
db := &huma.Dependency{
Value: db.NewConnection(),
}
// Register a new request logger dependency. This is contextual because we
// will print out the requester's IP address with each log message.
@ -250,17 +252,22 @@ type MyLogger struct {
Info: func(msg string),
}
r.Dependency(func(c *gin.Context) (*MyLogger, error) {
return &MyLogger{
Info: func(msg string) {
fmt.Printf("%s [ip:%s]\n", msg, c.Request.RemoteAddr)
},
}, nil
})
logger := &huma.Dependency{
Depends: []*huma.Dependency{huma.ContextDependency()},
Value: func(c *gin.Context) (*MyLogger, error) {
return &MyLogger{
Info: func(msg string) {
fmt.Printf("%s [ip:%s]\n", msg, c.Request.RemoteAddr)
},
}, nil
},
}
// Use them in any handler just by adding them as arguments!
// Use them in any handler by adding them to both `Depends` and the list of
// handler function arguments.
r.Register(&huma.Operation{
// ...
Depends: []*huma.Dependency{db, logger},
Handler: func(db *db.Connection, log *MyLogger) string {
log.Info("test")
item := db.Fetch("query")
@ -269,7 +276,7 @@ r.Register(&huma.Operation{
})
```
Note that dependencies cannot be scalar types. Typically you would use a struct or interface like above. Global dependencies cannot be functions.
Note that global dependencies cannot be functions. You can wrap them in a struct as a workaround.
## How it Works

View file

@ -11,124 +11,200 @@ import (
// ErrDependencyInvalid is returned when registering a dependency fails.
var ErrDependencyInvalid = errors.New("dependency invalid")
// ErrDependencyNotFound is returned when the given type isn't registered
// as a dependency.
var ErrDependencyNotFound = errors.New("dependency not found")
// DependencyRegistry let's you register and resolve dependencies based on
// their type.
type DependencyRegistry struct {
registry map[reflect.Type]interface{}
// Dependency represents a handler function dependency and its associated
// inputs and outputs. Value can be either a struct pointer (global dependency)
// or a `func(dependencies, params) (headers, struct pointer, error)` style
// function.
type Dependency struct {
Depends []*Dependency
Params []*Param
ResponseHeaders []*Header
Value interface{}
}
// NewDependencyRegistry creates a new blank dependency registry.
func NewDependencyRegistry() *DependencyRegistry {
return &DependencyRegistry{
registry: make(map[reflect.Type]interface{}),
}
var contextDependency Dependency
var operationDependency Dependency
// ContextDependency returns a dependency for the current request's
// `*gin.Context`.
func ContextDependency() *Dependency {
return &contextDependency
}
// Add a new dependency to the registry.
func (dr *DependencyRegistry) Add(item interface{}) error {
if dr.registry == nil {
dr.registry = make(map[reflect.Type]interface{})
// OperationDependency returns a dependency for the current `*huma.Operation`.
func OperationDependency() *Dependency {
return &operationDependency
}
// validate that the dependency deps/params/headers match the function
// signature or that the value is not a function.
func (d *Dependency) validate(returnType reflect.Type) error {
if d == &contextDependency || d == &operationDependency {
// Hard-coded known dependencies. These are special and have no value.
return nil
}
val := reflect.ValueOf(item)
outType := val.Type()
if d.Value == nil {
return fmt.Errorf("value must be set: %w", ErrDependencyInvalid)
}
valType := val.Type()
if val.Kind() == reflect.Func {
for i := 0; i < valType.NumIn(); i++ {
argType := valType.In(i)
if argType.String() == "*gin.Context" || argType.String() == "*huma.Operation" {
// Known hard-coded dependencies. Skip them.
continue
}
v := reflect.ValueOf(d.Value)
if argType.Kind() != reflect.Ptr {
return fmt.Errorf("should be pointer *%s: %w", argType, ErrDependencyInvalid)
}
if _, ok := dr.registry[argType]; !ok {
return fmt.Errorf("unknown dependency type %s, are dependencies defined in order? %w", argType, ErrDependencyNotFound)
}
if v.Kind() != reflect.Func {
if returnType != nil && returnType != v.Type() {
return fmt.Errorf("return type should be %s but got %s: %w", v.Type(), returnType, ErrDependencyInvalid)
}
if val.Type().NumOut() != 2 || val.Type().Out(1).Name() != "error" {
return fmt.Errorf("function should return (your-type, error): %w", ErrDependencyInvalid)
// This is just a static value. It shouldn't have params/headers/etc.
if len(d.Params) > 0 {
return fmt.Errorf("global dependency should not have params: %w", ErrDependencyInvalid)
}
outType = val.Type().Out(0)
if outType.Kind() != reflect.Ptr {
return fmt.Errorf("should be pointer *%s: %w", outType, ErrDependencyInvalid)
if len(d.ResponseHeaders) > 0 {
return fmt.Errorf("global dependency should not set headers: %w", ErrDependencyInvalid)
}
if _, ok := dr.registry[outType]; ok {
return fmt.Errorf("duplicate type %s: %w", outType.String(), ErrDependencyInvalid)
}
} else {
if valType.Kind() != reflect.Ptr {
return fmt.Errorf("should be pointer *%s: %w", valType, ErrDependencyInvalid)
}
return nil
}
if _, ok := dr.registry[valType]; ok {
return fmt.Errorf("duplicate type %s: %w", valType.String(), ErrDependencyInvalid)
fn := v.Type()
lenArgs := len(d.Depends) + len(d.Params)
if fn.NumIn() != lenArgs {
// TODO: generate suggested func signature
return fmt.Errorf("function signature should have %d args but got %s: %w", lenArgs, fn, ErrDependencyInvalid)
}
for _, dep := range d.Depends {
if err := dep.validate(nil); err != nil {
return err
}
}
// To prevent mistakes we limit dependencies to non-scalar types, since
// scalars like strings/numbers are typically used for params like headers.
switch outType.Kind() {
case reflect.Bool, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.String:
return fmt.Errorf("dependeny cannot be scalar type %s: %w", outType.Kind(), ErrDependencyInvalid)
for i, p := range d.Params {
if err := validateParam(p, fn.In(len(d.Depends)+i)); err != nil {
return err
}
}
dr.registry[outType] = item
lenReturn := len(d.ResponseHeaders) + 2
if fn.NumOut() != lenReturn {
return fmt.Errorf("function should return %d values but got %d: %w", lenReturn, fn.NumOut(), ErrDependencyInvalid)
}
for i, h := range d.ResponseHeaders {
if err := validateHeader(h, fn.Out(i)); err != nil {
return err
}
}
return nil
}
// Get a resolved dependency from the registry.
func (dr *DependencyRegistry) Get(op *Operation, c *gin.Context, t reflect.Type) (interface{}, error) {
if t.String() == "*gin.Context" {
// Special case: current gin context.
return c, nil
// AllParams returns all parameters for all dependencies in the graph of this
// dependency in depth-first order without duplicates.
func (d *Dependency) AllParams() []*Param {
params := []*Param{}
seen := map[*Param]bool{}
for _, p := range d.Params {
seen[p] = true
params = append(params, p)
}
if t.String() == "*huma.Operation" {
// Special case: current operation.
return op, nil
for _, d := range d.Depends {
for _, p := range d.AllParams() {
if _, ok := seen[p]; !ok {
seen[p] = true
params = append(params, p)
}
}
}
if f, ok := dr.registry[t]; ok {
// This argument matches a known registered dependency. If it's a
// function, then call it, otherwise just return the value.
vf := reflect.ValueOf(f)
if vf.Kind() == reflect.Func {
// Build the input argument list, which can consist of other dependencies.
args := make([]reflect.Value, vf.Type().NumIn())
return params
}
for i := 0; i < vf.Type().NumIn(); i++ {
v, err := dr.Get(op, c, vf.Type().In(i))
if err != nil {
return nil, err
}
args[i] = reflect.ValueOf(v)
// AllResponseHeaders returns all response headers for all dependencies in
// the graph of this dependency in depth-first order without duplicates.
func (d *Dependency) AllResponseHeaders() []*Header {
headers := []*Header{}
seen := map[*Header]bool{}
for _, h := range d.ResponseHeaders {
seen[h] = true
headers = append(headers, h)
}
for _, d := range d.Depends {
for _, h := range d.AllResponseHeaders() {
if _, ok := seen[h]; !ok {
seen[h] = true
headers = append(headers, h)
}
}
}
out := vf.Call(args)
return headers
}
if !out[1].IsNil() {
return nil, out[1].Interface().(error)
}
return out[0].Interface(), nil
// Resolve the value of the dependency. Returns (response headers, value, error).
func (d *Dependency) Resolve(c *gin.Context, op *Operation) (map[string]string, interface{}, error) {
// Identity dependencies are first. Just return if it's one of them.
if d == &contextDependency {
return nil, c, nil
}
if d == &operationDependency {
return nil, op, nil
}
v := reflect.ValueOf(d.Value)
if v.Kind() != reflect.Func {
// Not a function, just return the global value.
return nil, d.Value, nil
}
// Generate the input arguments
in := make([]reflect.Value, 0, v.Type().NumIn())
headers := map[string]string{}
// Resolve each sub-dependency
for _, dep := range d.Depends {
dHeaders, dVal, err := dep.Resolve(c, op)
if err != nil {
return nil, nil, err
}
// Not a function, just return the value.
return f, nil
for h, hv := range dHeaders {
headers[h] = hv
}
in = append(in, reflect.ValueOf(dVal))
}
return nil, fmt.Errorf("%s: %w", t, ErrDependencyNotFound)
// Get each input parameter
for _, param := range d.Params {
v, err := getParamValue(c, param)
if err != nil {
return nil, nil, err
}
in = append(in, reflect.ValueOf(v))
}
// Call the function.
out := v.Call(in)
if last := out[len(out)-1]; !last.IsNil() {
// There was an error!
return nil, nil, last.Interface().(error)
}
// Get the headers & response value.
for i, h := range d.ResponseHeaders {
headers[h.Name] = out[i].Interface().(string)
}
return headers, out[len(d.ResponseHeaders)].Interface(), nil
}

View file

@ -1,58 +1,128 @@
package huma
import (
"net/http"
"reflect"
"testing"
"github.com/alecthomas/assert"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
)
func TestDependencyNested(t *testing.T) {
type Dep1 struct{}
type Dep2 struct{}
type Dep3 struct{}
func TestGlobalDepEmpty(t *testing.T) {
d := Dependency{}
registry := NewDependencyRegistry()
assert.NoError(t, registry.Add(&Dep1{}))
typ := reflect.TypeOf(123)
assert.NoError(t, registry.Add(func(d1 *Dep1) (*Dep2, error) {
return &Dep2{}, nil
}))
assert.NoError(t, registry.Add(func(d1 *Dep1, d2 *Dep2) (*Dep3, error) {
return &Dep3{}, nil
}))
assert.Error(t, d.validate(typ))
}
func TestDependencyOrder(t *testing.T) {
type Dep1 struct{}
type Dep2 struct{}
func TestGlobalDepWrongType(t *testing.T) {
d := Dependency{
Value: "test",
}
registry := NewDependencyRegistry()
typ := reflect.TypeOf(123)
assert.Error(t, registry.Add(func(d2 *Dep2) (*Dep1, error) {
return &Dep1{}, nil
}))
assert.Error(t, d.validate(typ))
}
func TestDependencyNotPointer(t *testing.T) {
type Dep1 struct{}
func TestGlobalDepParams(t *testing.T) {
d := Dependency{
Params: []*Param{
HeaderParam("foo", "description", "hello"),
},
Value: "test",
}
registry := NewDependencyRegistry()
typ := reflect.TypeOf("test")
assert.Error(t, registry.Add(Dep1{}))
assert.Error(t, registry.Add(func() (Dep1, error) {
return Dep1{}, nil
}))
assert.Error(t, d.validate(typ))
}
func TestDependencyDupe(t *testing.T) {
type Dep1 struct{}
func TestGlobalDepHeaders(t *testing.T) {
d := Dependency{
ResponseHeaders: []*Header{ResponseHeader("foo", "description")},
Value: "test",
}
registry := NewDependencyRegistry()
typ := reflect.TypeOf("test")
assert.NoError(t, registry.Add(&Dep1{}))
assert.Error(t, registry.Add(&Dep1{}))
assert.Error(t, registry.Add(func() (*Dep1, error) {
return nil, nil
}))
assert.Error(t, d.validate(typ))
}
func TestDepContext(t *testing.T) {
d := Dependency{
Depends: []*Dependency{
ContextDependency(),
},
Value: func(c *gin.Context) (*gin.Context, error) { return c, nil },
}
mock := &gin.Context{}
typ := reflect.TypeOf(mock)
assert.NoError(t, d.validate(typ))
_, v, err := d.Resolve(mock, &Operation{})
assert.NoError(t, err)
assert.Equal(t, v, mock)
}
func TestDepOperation(t *testing.T) {
d := Dependency{
Depends: []*Dependency{
OperationDependency(),
},
Value: func(o *Operation) (*Operation, error) { return o, nil },
}
mock := &Operation{}
typ := reflect.TypeOf(mock)
assert.NoError(t, d.validate(typ))
_, v, err := d.Resolve(&gin.Context{}, mock)
assert.NoError(t, err)
assert.Equal(t, v, mock)
}
func TestDepFuncWrongArgs(t *testing.T) {
d := Dependency{
Params: []*Param{
HeaderParam("foo", "desc", ""),
},
Value: func() (string, error) {
return "", nil
},
}
assert.Error(t, d.validate(reflect.TypeOf("")))
}
func TestDepFunc(t *testing.T) {
d := Dependency{
Params: []*Param{
HeaderParam("x-in", "desc", ""),
},
ResponseHeaders: []*Header{
ResponseHeader("x-out", "desc"),
},
Value: func(xin string) (string, string, error) {
return "xout", "value", nil
},
}
c := &gin.Context{
Request: &http.Request{
Header: http.Header{
"x-in": []string{"xin"},
},
},
}
assert.NoError(t, d.validate(reflect.TypeOf("")))
h, v, err := d.Resolve(c, &Operation{})
assert.NoError(t, err)
assert.Equal(t, "xout", h["x-out"])
assert.Equal(t, "value", v)
}

4
go.mod
View file

@ -4,13 +4,13 @@ go 1.13
require (
github.com/Jeffail/gabs v1.4.0
github.com/alecthomas/assert v0.0.0-20170929043011-405dbfeb8e38
github.com/alecthomas/assert v0.0.0-20170929043011-405dbfeb8e38 // indirect
github.com/alecthomas/colour v0.1.0 // indirect
github.com/alecthomas/repr v0.0.0-20181024024818-d37bc2a10ba1 // indirect
github.com/davecgh/go-spew v1.1.1
github.com/gin-gonic/gin v1.5.0
github.com/gosimple/slug v1.9.0
github.com/rs/zerolog v1.18.0
github.com/rs/zerolog v1.18.0 // indirect
github.com/sergi/go-diff v1.1.0 // indirect
github.com/stretchr/testify v1.5.1
github.com/xeipuuv/gojsonschema v1.2.0

3
go.sum
View file

@ -25,8 +25,10 @@ github.com/gosimple/slug v1.9.0 h1:r5vDcYrFz9BmfIAMC829un9hq7hKM4cHUrsv36LbEqs=
github.com/gosimple/slug v1.9.0/go.mod h1:AMZ+sOVe65uByN3kgEyf9WEBKBCSS+dJjMX9x4vDJbg=
github.com/json-iterator/go v1.1.7 h1:KfgG9LzI+pYjr4xvmz/5H4FXjokeP+rlHLhv3iH62Fo=
github.com/json-iterator/go v1.1.7/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/leodido/go-urn v1.1.0 h1:Sm1gr51B1kKyfD2BlRcLSiEkffoG96g6TPv6eRoEiB8=
github.com/leodido/go-urn v1.1.0/go.mod h1:+cyI34gQWZcE1eQU7NVgKkkzdXDQHr1dBMtdAPozLkw=
@ -75,6 +77,7 @@ golang.org/x/tools v0.0.0-20190828213141-aed303cbaa74/go.mod h1:b+2E5dAYhXwXZwtn
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/go-playground/assert.v1 v1.2.1 h1:xoYuJVE7KT85PYWrN730RguIQO0ePzVRfFMXadIrXTM=
gopkg.in/go-playground/assert.v1 v1.2.1/go.mod h1:9RXL0bg/zibRAgZUYszZSwO/z8Y/a8bDuhia5mkpMnE=

View file

@ -153,6 +153,7 @@ type Operation struct {
Summary string
Description string
Tags []string
Depends []*Dependency
Params []*Param
RequestContentType string
RequestSchema *Schema
@ -161,6 +162,54 @@ type Operation struct {
Handler interface{}
}
// AllParams returns a list of all the parameters for this operation, including
// those for dependencies.
func (o *Operation) AllParams() []*Param {
params := []*Param{}
seen := map[*Param]bool{}
for _, p := range o.Params {
seen[p] = true
params = append(params, p)
}
for _, d := range o.Depends {
for _, p := range d.AllParams() {
if _, ok := seen[p]; !ok {
seen[p] = true
params = append(params, p)
}
}
}
return params
}
// AllResponseHeaders returns a list of all the parameters for this operation,
// including those for dependencies.
func (o *Operation) AllResponseHeaders() []*Header {
headers := []*Header{}
seen := map[*Header]bool{}
for _, h := range o.ResponseHeaders {
seen[h] = true
headers = append(headers, h)
}
for _, d := range o.Depends {
for _, h := range d.AllResponseHeaders() {
if _, ok := seen[h]; !ok {
seen[h] = true
headers = append(headers, h)
}
}
}
return headers
}
// Server describes an OpenAPI 3 API server location
type Server struct {
URL string `json:"url"`
@ -189,6 +238,7 @@ type OpenAPI struct {
Version string
Servers []*Server
Paths map[string][]*Operation
// TODO: Depends []*Dependency
}
// OpenAPIHandler returns a new handler function to generate an OpenAPI spec.
@ -224,7 +274,7 @@ func OpenAPIHandler(api *OpenAPI) func(*gin.Context) {
openapi.Set(op.Tags, "paths", path, method, "tags")
}
for _, param := range op.Params {
for _, param := range op.AllParams() {
if param.internal {
// Skip internal-only parameters.
continue
@ -260,7 +310,7 @@ func OpenAPIHandler(api *OpenAPI) func(*gin.Context) {
}
headerMap := map[string]*Header{}
for _, header := range op.ResponseHeaders {
for _, header := range op.AllResponseHeaders() {
headerMap[header.Name] = header
}
@ -268,7 +318,22 @@ func OpenAPIHandler(api *OpenAPI) func(*gin.Context) {
status := fmt.Sprintf("%v", resp.StatusCode)
openapi.Set(resp.Description, "paths", path, method, "responses", status, "description")
headers := make([]string, 0, len(resp.Headers))
seen := map[string]bool{}
for _, name := range resp.Headers {
headers = append(headers, name)
seen[name] = true
}
for _, dep := range op.Depends {
for _, header := range dep.AllResponseHeaders() {
if _, ok := seen[header.Name]; !ok {
headers = append(headers, header.Name)
seen[header.Name] = true
}
}
}
for _, name := range headers {
header := headerMap[name]
openapi.Set(header, "paths", path, method, "responses", status, "headers", header.Name)
}

View file

@ -125,7 +125,6 @@ func getRequestBody(c *gin.Context, t reflect.Type, op *Operation) (interface{},
type Router struct {
api *OpenAPI
engine *gin.Engine
deps *DependencyRegistry
}
// NewRouter creates a new Huma router for handling API requests with
@ -140,7 +139,6 @@ func NewRouterWithGin(engine *gin.Engine, api *OpenAPI) *Router {
r := &Router{
api: api,
engine: engine,
deps: NewDependencyRegistry(),
}
if r.api.Paths == nil {
@ -214,17 +212,17 @@ func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
// })
//
// Panics on invalid input to force stop execution on service startup.
func (r *Router) Dependency(f interface{}) {
if err := r.deps.Add(f); err != nil {
panic(err)
}
}
// func (r *Router) Dependency(f interface{}) {
// if err := r.deps.Add(f); err != nil {
// panic(err)
// }
// }
// Register a new operation.
func (r *Router) Register(op *Operation) {
// First, make sure the operation and handler make sense, as well as pre-
// generating any schemas for use later during request handling.
if err := op.validate(r.deps.registry); err != nil {
if err := op.validate(); err != nil {
panic(err)
}
@ -263,22 +261,21 @@ func (r *Router) Register(op *Operation) {
in := make([]reflect.Value, 0, method.Type().NumIn())
// Process any dependencies first.
for i := 0; i < method.Type().NumIn(); i++ {
argType := method.Type().In(i)
v, err := r.deps.Get(op, c, argType)
for _, dep := range op.Depends {
headers, value, err := dep.Resolve(c, op)
if err != nil {
if errors.Is(err, ErrDependencyNotFound) {
// No match, so we're done with dependencies. Keep going below
// processing params.
break
}
// Getting the dependency value failed.
// TODO: better error code/messaging?
c.AbortWithError(500, err)
return
// TODO: better error handling
c.AbortWithStatusJSON(500, ErrorModel{
Message: "Couldn't get dependency",
//Errors: []error{err},
})
}
in = append(in, reflect.ValueOf(v))
for k, v := range headers {
c.Header(k, v)
}
in = append(in, reflect.ValueOf(value))
}
for _, param := range op.Params {

View file

@ -67,6 +67,120 @@ func BenchmarkHuma(b *testing.B) {
}
}
func BenchmarkGinComplex(b *testing.B) {
dep1 := "dep1"
dep2 := func(c *gin.Context) string {
_ = c.GetHeader("x-foo")
return "dep2"
}
dep3 := func(c *gin.Context) (string, string) {
return "xbar", "dep3"
}
g := gin.New()
g.GET("/hello", func(c *gin.Context) {
_ = dep1
_ = dep2(c)
h, _ := dep3(c)
c.Header("x-bar", h)
name := c.Query("name")
if name == "test" {
c.JSON(400, &ErrorModel{
Message: "Name cannot be test",
})
}
if name == "" {
name = "world"
}
c.Header("x-baz", "xbaz")
c.JSON(200, &helloResponse{
Message: "Hello, " + name,
})
})
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/hello", nil)
b.ResetTimer()
for i := 0; i < b.N; i++ {
g.ServeHTTP(w, req)
}
}
func BenchmarkHumaComplex(b *testing.B) {
r := NewRouterWithGin(gin.New(), &OpenAPI{
Title: "Benchmark test",
Version: "1.0.0",
})
dep1 := &Dependency{
Value: "dep1",
}
dep2 := &Dependency{
Depends: []*Dependency{ContextDependency(), dep1},
Params: []*Param{
HeaderParam("x-foo", "desc", ""),
},
Value: func(c *gin.Context, d1 string, xfoo string) (string, error) {
return "dep2", nil
},
}
dep3 := &Dependency{
Depends: []*Dependency{dep1},
ResponseHeaders: []*Header{
ResponseHeader("x-bar", "desc"),
},
Value: func(d1 string) (string, string, error) {
return "xbar", "dep3", nil
},
}
r.Register(&Operation{
Method: http.MethodGet,
Path: "/hello",
Description: "Greet the world",
Depends: []*Dependency{
ContextDependency(), dep2, dep3,
},
Params: []*Param{
QueryParam("name", "desc", "world"),
},
ResponseHeaders: []*Header{
ResponseHeader("x-baz", "desc"),
},
Responses: []*Response{
ResponseJSON(200, "Return a greeting", "x-baz"),
ResponseError(500, "desc"),
},
Handler: func(c *gin.Context, d2, d3, name string) (string, *helloResponse, *ErrorModel) {
if name == "test" {
return "", nil, &ErrorModel{
Message: "Name cannot be test",
}
}
return "xbaz", &helloResponse{
Message: "Hello, " + name,
}, nil
},
})
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/hello?name=Daniel", nil)
b.ResetTimer()
for i := 0; i < b.N; i++ {
r.ServeHTTP(w, req)
}
}
func TestRouter(t *testing.T) {
type EchoResponse struct {
Value string `json:"value" description:"The echoed back word"`
@ -253,30 +367,38 @@ func TestRouterDependencies(t *testing.T) {
Get func() string
}
// Inject datastore as a global instance.
r.Dependency(&DB{
Get: func() string {
return "Hello, "
// Datastore is a global dependency, set by value.
db := &Dependency{
Value: &DB{
Get: func() string {
return "Hello, "
},
},
})
}
type Logger struct {
Log func(msg string)
}
// Inject logger as a contextual instance from the Gin context.
r.Dependency(func(c *gin.Context) (*Logger, error) {
return &Logger{
Log: func(msg string) {
fmt.Println(fmt.Sprintf("%s [uri:%s]", msg, c.FullPath()))
},
}, nil
})
// Logger is a contextual instance from the gin request context.
log := &Dependency{
Depends: []*Dependency{
ContextDependency(),
},
Value: func(c *gin.Context) (*Logger, error) {
return &Logger{
Log: func(msg string) {
fmt.Println(fmt.Sprintf("%s [uri:%s]", msg, c.FullPath()))
},
}, nil
},
}
r.Register(&Operation{
Method: http.MethodGet,
Path: "/hello",
Description: "Basic hello world",
Depends: []*Dependency{ContextDependency(), db, log},
Params: []*Param{
QueryParam("name", "Your name", ""),
},

View file

@ -13,10 +13,6 @@ import (
// ErrFieldRequired is returned when a field is blank but has been required.
var ErrFieldRequired = errors.New("field is required")
// ErrContextNotFirst is returned when a registered operation has a handler
// that takes a context but it is not the first parameter of the function.
var ErrContextNotFirst = errors.New("context should be first parameter")
// ErrParamsMustMatch is returned when a registered operation has a handler
// function that takes the wrong number of arguments.
var ErrParamsMustMatch = errors.New("handler function args must match registered params")
@ -31,9 +27,40 @@ var ErrResponsesMustMatch = errors.New("handler function return values must matc
var paramRe = regexp.MustCompile(`:([^/]+)|{([^}]+)}`)
func validateParam(p *Param, t reflect.Type) error {
p.typ = t
if p.Schema == nil {
s, err := GenerateSchema(p.typ)
if err != nil {
return err
}
p.Schema = s
if p.def != nil {
p.Schema.Default = p.def
}
}
return nil
}
func validateHeader(h *Header, t reflect.Type) error {
if h.Schema == nil {
// Generate the schema from the handler function types.
s, err := GenerateSchema(t)
if err != nil {
return err
}
h.Schema = s
}
return nil
}
// validate checks that the operation is well-formed (e.g. handler signature
// matches the given params) and generates schemas if needed.
func (o *Operation) validate(deps map[reflect.Type]interface{}) error {
func (o *Operation) validate() error {
if o.Method == "" {
return fmt.Errorf("Method: %w", ErrFieldRequired)
}
@ -70,27 +97,32 @@ func (o *Operation) validate(deps map[reflect.Type]interface{}) error {
o.Path = paramRe.ReplaceAllString(o.Path, ":$1$2")
}
types := []reflect.Type{}
for i := 0; i < method.Type().NumIn(); i++ {
for i, dep := range o.Depends {
paramType := method.Type().In(i)
if paramType.String() == "*gin.Context" || paramType.String() == "*huma.Operation" {
// Known hard-coded dependencies. Skip them.
continue
}
if paramType.String() == "gin.Context" {
return fmt.Errorf("gin context should be pointer *gin.Context: %w", ErrDependencyInvalid)
return fmt.Errorf("gin.Context should be pointer *gin.Context: %w", ErrDependencyInvalid)
}
if paramType.String() == "huma.Operation" {
return fmt.Errorf("operation should be pointer *huma.Operation: %w", ErrDependencyInvalid)
return fmt.Errorf("huma.Operation should be pointer *huma.Operation: %w", ErrDependencyInvalid)
}
if _, ok := deps[paramType]; ok {
// This matches a registered dependency type, so it's not a normal
// param. Skip it.
continue
if err := dep.validate(paramType); err != nil {
return err
}
}
types := []reflect.Type{}
for i := len(o.Depends); i < method.Type().NumIn(); i++ {
paramType := method.Type().In(i)
if paramType.String() == "gin.Context" {
return fmt.Errorf("gin.Context should be pointer *gin.Context: %w", ErrDependencyInvalid)
}
if paramType.String() == "huma.Operation" {
return fmt.Errorf("huma.Operation should be pointer *huma.Operation: %w", ErrDependencyInvalid)
}
types = append(types, paramType)
@ -123,21 +155,8 @@ func (o *Operation) validate(deps map[reflect.Type]interface{}) error {
}
p := o.Params[i]
p.typ = paramType
if p.Schema == nil {
// Auto-generate a schema for this parameter
s, err := GenerateSchema(paramType)
if err != nil {
return err
}
p.Schema = s
if p.def != nil {
if reflect.ValueOf(p.def).Type() != paramType {
}
p.Schema.Default = p.def
}
if err := validateParam(p, paramType); err != nil {
return err
}
}
@ -148,14 +167,8 @@ func (o *Operation) validate(deps map[reflect.Type]interface{}) error {
}
for i, header := range o.ResponseHeaders {
if header.Schema == nil {
// Generate the schema from the handler function types.
headerType := method.Type().Out(i)
s, err := GenerateSchema(headerType)
if err != nil {
return err
}
header.Schema = s
if err := validateHeader(header, method.Type().Out(i)); err != nil {
return err
}
}