refactor: dependency injection code

This commit is contained in:
Daniel G. Taylor 2020-03-12 08:59:01 -07:00
parent 417aa498df
commit c98d6145e4
No known key found for this signature in database
GPG key ID: 7BD6DC99C9A87E22
9 changed files with 280 additions and 76 deletions

View file

@ -2,7 +2,7 @@
[![CI](https://github.com/danielgtaylor/huma/workflows/CI/badge.svg?branch=master)](https://github.com/danielgtaylor/huma/actions?query=workflow%3ACI+branch%3Amaster++) [![codecov](https://codecov.io/gh/danielgtaylor/huma/branch/master/graph/badge.svg)](https://codecov.io/gh/danielgtaylor/huma) [![Docs](https://godoc.org/github.com/danielgtaylor/huma?status.svg)](https://pkg.go.dev/github.com/danielgtaylor/huma?tab=doc) [![Go Report Card](https://goreportcard.com/badge/github.com/danielgtaylor/huma)](https://goreportcard.com/report/github.com/danielgtaylor/huma)
A modern, simple, fast & opinionated REST API framework for Go. The goals of this project are to provide:
A modern, simple, fast & opinionated REST API framework for Go. Pronounced IPA: [/'hjuːmɑ/](https://en.wiktionary.org/wiki/Wiktionary:International_Phonetic_Alphabet). The goals of this project are to provide:
- A modern REST API backend framework for Go developers
- Described by [OpenAPI 3](https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.3.md) & [JSON Schema](https://json-schema.org/)
@ -235,7 +235,10 @@ The standard `json` tag is supported and can be used to rename a field and mark
Huma includes a dependency injection system that can be used to pass additional arguments to operation handler functions. You can register global dependencies (ones that do not change from request to request) or contextual dependencies (ones that change with each request).
Global dependencies are created by just setting some value, while contextual dependencies are implemented using a function that returns the value of the form `func (c *gin.Context, o *huma.Operation) (*YourType, error)` where the value you want injected is of `*YourType`.
Global dependencies are created by just setting some value, while contextual dependencies are implemented using a function that returns the value of the form `func (dep *DependencyType...) (*YourType, error)` where the value you want injected is of `*YourType` and the function arguments can be any previously registered dependency types or one of the hard-coded types:
- `*gin.Context` the current context
- `*huma.Operation` the current operation
```go
// Register a new database connection dependency
@ -247,7 +250,7 @@ type MyLogger struct {
Info: func(msg string),
}
r.Dependency(func(c *gin.Context, o *huma.Operation) (*MyLogger, error) {
r.Dependency(func(c *gin.Context) (*MyLogger, error) {
return &MyLogger{
Info: func(msg string) {
fmt.Printf("%s [ip:%s]\n", msg, c.Request.RemoteAddr)

134
dependency.go Normal file
View file

@ -0,0 +1,134 @@
package huma
import (
"errors"
"fmt"
"reflect"
"github.com/gin-gonic/gin"
)
// ErrDependencyInvalid is returned when registering a dependency fails.
var ErrDependencyInvalid = errors.New("dependency invalid")
// ErrDependencyNotFound is returned when the given type isn't registered
// as a dependency.
var ErrDependencyNotFound = errors.New("dependency not found")
// DependencyRegistry let's you register and resolve dependencies based on
// their type.
type DependencyRegistry struct {
registry map[reflect.Type]interface{}
}
// NewDependencyRegistry creates a new blank dependency registry.
func NewDependencyRegistry() *DependencyRegistry {
return &DependencyRegistry{
registry: make(map[reflect.Type]interface{}),
}
}
// Add a new dependency to the registry.
func (dr *DependencyRegistry) Add(item interface{}) error {
if dr.registry == nil {
dr.registry = make(map[reflect.Type]interface{})
}
val := reflect.ValueOf(item)
outType := val.Type()
valType := val.Type()
if val.Kind() == reflect.Func {
for i := 0; i < valType.NumIn(); i++ {
argType := valType.In(i)
if argType.String() == "*gin.Context" || argType.String() == "*huma.Operation" {
// Known hard-coded dependencies. Skip them.
continue
}
if argType.Kind() != reflect.Ptr {
return fmt.Errorf("should be pointer *%s: %w", argType, ErrDependencyInvalid)
}
if _, ok := dr.registry[argType]; !ok {
return fmt.Errorf("unknown dependency type %s, are dependencies defined in order? %w", argType, ErrDependencyNotFound)
}
}
if val.Type().NumOut() != 2 || val.Type().Out(1).Name() != "error" {
return fmt.Errorf("function should return (your-type, error): %w", ErrDependencyInvalid)
}
outType = val.Type().Out(0)
if outType.Kind() != reflect.Ptr {
return fmt.Errorf("should be pointer *%s: %w", outType, ErrDependencyInvalid)
}
if _, ok := dr.registry[outType]; ok {
return fmt.Errorf("duplicate type %s: %w", outType.String(), ErrDependencyInvalid)
}
} else {
if valType.Kind() != reflect.Ptr {
return fmt.Errorf("should be pointer *%s: %w", valType, ErrDependencyInvalid)
}
if _, ok := dr.registry[valType]; ok {
return fmt.Errorf("duplicate type %s: %w", valType.String(), ErrDependencyInvalid)
}
}
// To prevent mistakes we limit dependencies to non-scalar types, since
// scalars like strings/numbers are typically used for params like headers.
switch outType.Kind() {
case reflect.Bool, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.String:
return fmt.Errorf("dependeny cannot be scalar type %s: %w", outType.Kind(), ErrDependencyInvalid)
}
dr.registry[outType] = item
return nil
}
// Get a resolved dependency from the registry.
func (dr *DependencyRegistry) Get(op *Operation, c *gin.Context, t reflect.Type) (interface{}, error) {
if t.String() == "*gin.Context" {
// Special case: current gin context.
return c, nil
}
if t.String() == "*huma.Operation" {
// Special case: current operation.
return op, nil
}
if f, ok := dr.registry[t]; ok {
// This argument matches a known registered dependency. If it's a
// function, then call it, otherwise just return the value.
vf := reflect.ValueOf(f)
if vf.Kind() == reflect.Func {
// Build the input argument list, which can consist of other dependencies.
args := make([]reflect.Value, vf.Type().NumIn())
for i := 0; i < vf.Type().NumIn(); i++ {
v, err := dr.Get(op, c, vf.Type().In(i))
if err != nil {
return nil, err
}
args[i] = reflect.ValueOf(v)
}
out := vf.Call(args)
if !out[1].IsNil() {
return nil, out[1].Interface().(error)
}
return out[0].Interface(), nil
}
// Not a function, just return the value.
return f, nil
}
return nil, fmt.Errorf("%s: %w", t, ErrDependencyNotFound)
}

58
dependency_test.go Normal file
View file

@ -0,0 +1,58 @@
package huma
import (
"testing"
"github.com/alecthomas/assert"
)
func TestDependencyNested(t *testing.T) {
type Dep1 struct{}
type Dep2 struct{}
type Dep3 struct{}
registry := NewDependencyRegistry()
assert.NoError(t, registry.Add(&Dep1{}))
assert.NoError(t, registry.Add(func(d1 *Dep1) (*Dep2, error) {
return &Dep2{}, nil
}))
assert.NoError(t, registry.Add(func(d1 *Dep1, d2 *Dep2) (*Dep3, error) {
return &Dep3{}, nil
}))
}
func TestDependencyOrder(t *testing.T) {
type Dep1 struct{}
type Dep2 struct{}
registry := NewDependencyRegistry()
assert.Error(t, registry.Add(func(d2 *Dep2) (*Dep1, error) {
return &Dep1{}, nil
}))
}
func TestDependencyNotPointer(t *testing.T) {
type Dep1 struct{}
registry := NewDependencyRegistry()
assert.Error(t, registry.Add(Dep1{}))
assert.Error(t, registry.Add(func() (Dep1, error) {
return Dep1{}, nil
}))
}
func TestDependencyDupe(t *testing.T) {
type Dep1 struct{}
registry := NewDependencyRegistry()
assert.NoError(t, registry.Add(&Dep1{}))
assert.Error(t, registry.Add(&Dep1{}))
assert.Error(t, registry.Add(func() (*Dep1, error) {
return nil, nil
}))
}

5
go.mod
View file

@ -4,9 +4,14 @@ go 1.13
require (
github.com/Jeffail/gabs v1.4.0
github.com/alecthomas/assert v0.0.0-20170929043011-405dbfeb8e38
github.com/alecthomas/colour v0.1.0 // indirect
github.com/alecthomas/repr v0.0.0-20181024024818-d37bc2a10ba1 // indirect
github.com/davecgh/go-spew v1.1.1
github.com/gin-gonic/gin v1.5.0
github.com/gosimple/slug v1.9.0
github.com/rs/zerolog v1.18.0
github.com/sergi/go-diff v1.1.0 // indirect
github.com/stretchr/testify v1.5.1
github.com/xeipuuv/gojsonschema v1.2.0
)

27
go.sum
View file

@ -1,5 +1,12 @@
github.com/Jeffail/gabs v1.4.0 h1://5fYRRTq1edjfIrQGvdkcd22pkYUrHZ5YC/H2GJVAo=
github.com/Jeffail/gabs v1.4.0/go.mod h1:6xMvQMK4k33lb7GUUpaAPh6nKMmemQeg5d4gn7/bOXc=
github.com/alecthomas/assert v0.0.0-20170929043011-405dbfeb8e38 h1:smF2tmSOzy2Mm+0dGI2AIUHY+w0BUc+4tn40djz7+6U=
github.com/alecthomas/assert v0.0.0-20170929043011-405dbfeb8e38/go.mod h1:r7bzyVFMNntcxPZXK3/+KdruV1H5KSlyVY0gc+NgInI=
github.com/alecthomas/colour v0.1.0 h1:nOE9rJm6dsZ66RGWYSFrXw461ZIt9A6+nHgL7FRrDUk=
github.com/alecthomas/colour v0.1.0/go.mod h1:QO9JBoKquHd+jz9nshCh40fOfO+JzsoXy8qTHF68zU0=
github.com/alecthomas/repr v0.0.0-20181024024818-d37bc2a10ba1 h1:GDQdwm/gAcJcLAKQQZGOJ4knlw+7rfEQQcmwTbt4p5E=
github.com/alecthomas/repr v0.0.0-20181024024818-d37bc2a10ba1/go.mod h1:xTS7Pm1pD1mvyM075QCDSRqH6qRLXylzS24ZTpRiSzQ=
github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
@ -18,6 +25,9 @@ github.com/gosimple/slug v1.9.0 h1:r5vDcYrFz9BmfIAMC829un9hq7hKM4cHUrsv36LbEqs=
github.com/gosimple/slug v1.9.0/go.mod h1:AMZ+sOVe65uByN3kgEyf9WEBKBCSS+dJjMX9x4vDJbg=
github.com/json-iterator/go v1.1.7 h1:KfgG9LzI+pYjr4xvmz/5H4FXjokeP+rlHLhv3iH62Fo=
github.com/json-iterator/go v1.1.7/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/leodido/go-urn v1.1.0 h1:Sm1gr51B1kKyfD2BlRcLSiEkffoG96g6TPv6eRoEiB8=
github.com/leodido/go-urn v1.1.0/go.mod h1:+cyI34gQWZcE1eQU7NVgKkkzdXDQHr1dBMtdAPozLkw=
github.com/mattn/go-isatty v0.0.9 h1:d5US/mDsogSGW37IV293h//ZFaeajb69h+EHFsv2xGg=
@ -26,10 +36,16 @@ github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 h1:ZqeYNhU3OH
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742 h1:Esafd1046DLDQ0W1YjYsBW+p8U2u7vzgW2SQVmlNazg=
github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rainycape/unidecode v0.0.0-20150907023854-cb7f23ec59be h1:ta7tUOvsPHVHGom5hKW5VXNc2xZIkfCKP8iaqOyYtUQ=
github.com/rainycape/unidecode v0.0.0-20150907023854-cb7f23ec59be/go.mod h1:MIDFMn7db1kT65GmV94GzpX9Qdi7N/pQlwb+AN8wh+Q=
github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ=
github.com/rs/zerolog v1.18.0 h1:CbAm3kP2Tptby1i9sYy2MGRg0uxIN9cyDb59Ys7W8z8=
github.com/rs/zerolog v1.18.0/go.mod h1:9nvC1axdVrAHcu/s9taAVfBuIdTZLVQmKQyvrUjF5+I=
github.com/sergi/go-diff v1.1.0 h1:we8PVUC3FE2uYfodKH/nBHMSetSfHDR6scGdBi+erh0=
github.com/sergi/go-diff v1.1.0/go.mod h1:STckp+ISIX8hZLjrqAeVduY0gWCT9IjLuqbuNXdaHfM=
github.com/stretchr/objx v0.1.0 h1:4G4v2dO3VZwixGIRoQ5Lfboy6nUhCyYzaqnIAPPhYs4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
@ -47,13 +63,24 @@ github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 h1:EzJWgHo
github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415/go.mod h1:GwrjFmJcFw6At/Gs6z4yjiIwzuJ1/+UwLxMQDVQXShQ=
github.com/xeipuuv/gojsonschema v1.2.0 h1:LhYJRs+L4fBtjZUfuSZIKGeVu0QRy8e5Xi7D17UxZ74=
github.com/xeipuuv/gojsonschema v1.2.0/go.mod h1:anYRn/JVcOK2ZgGU+IjEV4nwlhoK5sQluxsYJ78Id3Y=
github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a h1:aYOabOQFp6Vj6W1F80affTUvO9UxmJRx8K0gsfABByQ=
golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/tools v0.0.0-20190828213141-aed303cbaa74/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/go-playground/assert.v1 v1.2.1 h1:xoYuJVE7KT85PYWrN730RguIQO0ePzVRfFMXadIrXTM=
gopkg.in/go-playground/assert.v1 v1.2.1/go.mod h1:9RXL0bg/zibRAgZUYszZSwO/z8Y/a8bDuhia5mkpMnE=
gopkg.in/go-playground/validator.v9 v9.29.1 h1:SvGtYmN60a5CVKTOzMSyfzWDeZRxRuGvRQyEAKbw1xc=
gopkg.in/go-playground/validator.v9 v9.29.1/go.mod h1:+c9/zcJMFNgbLvly1L1V+PpxWdVbfP1avr/N00E2vyQ=
gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.4 h1:/eiJrUcujPVeJ3xlSWaiNi3uSVmDGBK1pDHUHAnao1I=
gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=

View file

@ -18,9 +18,6 @@ import (
// is not a valid value.
var ErrInvalidParamLocation = errors.New("invalid parameter location")
// ErrDependencyInvalid is returned when registering a dependency fails.
var ErrDependencyInvalid = errors.New("dependency invalid")
func getParamValue(c *gin.Context, param *Param) (interface{}, error) {
var pstr string
switch param.In {
@ -81,7 +78,6 @@ func getParamValue(c *gin.Context, param *Param) (interface{}, error) {
func getRequestBody(c *gin.Context, t reflect.Type, op *Operation) (interface{}, bool) {
val := reflect.New(t).Interface()
if op.RequestSchema != nil {
body, err := ioutil.ReadAll(c.Request.Body)
if err != nil {
@ -129,7 +125,7 @@ func getRequestBody(c *gin.Context, t reflect.Type, op *Operation) (interface{},
type Router struct {
api *OpenAPI
engine *gin.Engine
deps map[reflect.Type]interface{}
deps *DependencyRegistry
}
// NewRouter creates a new Huma router for handling API requests with
@ -144,18 +140,13 @@ func NewRouterWithGin(engine *gin.Engine, api *OpenAPI) *Router {
r := &Router{
api: api,
engine: engine,
deps: NewDependencyRegistry(),
}
if r.api.Paths == nil {
r.api.Paths = make(map[string][]*Operation)
}
// Add the default context dependency.
r.deps = make(map[reflect.Type]interface{})
r.Dependency(func(c *gin.Context, o *Operation) (*gin.Context, error) {
return c, nil
})
// Set up handlers for the auto-generated spec and docs.
r.engine.GET("/openapi.json", OpenAPIHandler(r.api))
r.engine.GET("/docs", func(c *gin.Context) {
@ -195,13 +186,19 @@ func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
// Dependency registers a new dependency type to be injected into handler
// functions, e.g. for loggers, metrics, datastores, etc. Provide a value
// or a function to return a contextual value/error.
// or a function to return a contextual value/error. Dependency functions
// can take their own dependencies. To prevent circular dependency loops,
// a function can only depend on previously defined dependencies.
//
// Some dependency types are built in:
// - `*gin.Context` the current Gin request execution context
// - `*huma.Operation` the current Huma operation
//
// // Register a global dependency like a datastore
// router.Dependency(&MyDB{...})
//
// // Register a contextual dependency like a logger
// router.Dependency(func (c *gin.Context, o *huma.Operation) (*MyLogger, error) {
// router.Dependency(func (c *gin.Context) (*MyLogger, error) {
// return &MyLogger{Tags: []string{c.Request.RemoteAddr}}, nil
// })
//
@ -215,46 +212,19 @@ func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
// return item
// }
// })
//
// Panics on invalid input to force stop execution on service startup.
func (r *Router) Dependency(f interface{}) {
fVal := reflect.ValueOf(f)
outType := fVal.Type()
if fVal.Kind() == reflect.Func {
fType := fVal.Type()
if fType.NumIn() != 2 {
panic(fmt.Errorf("function should take 2 arguments (*gin.Context, *huma.Operation) but got %s: %w", fType.String(), ErrDependencyInvalid))
}
if fType.In(0).String() != "*gin.Context" || fType.In(1).String() != "*huma.Operation" {
panic(fmt.Errorf("function should take (*gin.Context, *huma.Operation) but got (%s, %s): %w", fType.In(0).String(), fType.In(1).String(), ErrDependencyInvalid))
}
if fVal.Type().NumOut() != 2 || fVal.Type().Out(1).Name() != "error" {
panic(fmt.Errorf("function should return (your-type, error): %w", ErrDependencyInvalid))
}
outType = fVal.Type().Out(0)
if _, ok := r.deps[outType]; ok {
panic(fmt.Errorf("duplicate type %s: %w", outType.String(), ErrDependencyInvalid))
}
if err := r.deps.Add(f); err != nil {
panic(err)
}
// To prevent mistakes we limit dependencies to non-scalar types, since
// scalars like strings/numbers are typically used for params like headers.
switch outType.Kind() {
case reflect.Bool, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.String:
panic(fmt.Errorf("dependeny cannot be scalar type %s: %w", outType.Kind(), ErrDependencyInvalid))
}
r.deps[outType] = f
}
// Register a new operation.
func (r *Router) Register(op *Operation) {
// First, make sure the operation and handler make sense, as well as pre-
// generating any schemas for use later during request handling.
if err := op.validate(r.deps); err != nil {
if err := op.validate(r.deps.registry); err != nil {
panic(err)
}
@ -295,28 +265,20 @@ func (r *Router) Register(op *Operation) {
// Process any dependencies first.
for i := 0; i < method.Type().NumIn(); i++ {
argType := method.Type().In(i)
if f, ok := r.deps[argType]; ok {
// This handler argument matches a known registered dependency. If it's
// a function, then call it, otherwise just use the value.
var v reflect.Value
vf := reflect.ValueOf(f)
if vf.Kind() == reflect.Func {
args := []reflect.Value{reflect.ValueOf(c), reflect.ValueOf(op)}
out := vf.Call(args)
if !out[1].IsNil() {
c.AbortWithError(500, out[1].Interface().(error))
return
}
v = out[0]
} else {
v = reflect.ValueOf(f)
v, err := r.deps.Get(op, c, argType)
if err != nil {
if errors.Is(err, ErrDependencyNotFound) {
// No match, so we're done with dependencies. Keep going below
// processing params.
break
}
in = append(in, v)
} else {
// No match, so we're done with dependencies. Keep going below
// processing params.
break
// Getting the dependency value failed.
// TODO: better error code/messaging?
c.AbortWithError(500, err)
return
}
in = append(in, reflect.ValueOf(v))
}
for _, param := range op.Params {

View file

@ -265,7 +265,7 @@ func TestRouterDependencies(t *testing.T) {
}
// Inject logger as a contextual instance from the Gin context.
r.Dependency(func(c *gin.Context, o *Operation) (*Logger, error) {
r.Dependency(func(c *gin.Context) (*Logger, error) {
return &Logger{
Log: func(msg string) {
fmt.Println(fmt.Sprintf("%s [uri:%s]", msg, c.FullPath()))

View file

@ -1,7 +1,7 @@
package huma
import (
"errors"
"fmt"
"reflect"
"strconv"
"strings"
@ -126,7 +126,7 @@ func GenerateSchema(t reflect.Type) (*Schema, error) {
case reflect.Ptr:
return GenerateSchema(t.Elem())
default:
return nil, errors.New("unsupported type")
return nil, fmt.Errorf("unsupported type %s from %s", t.Kind(), t)
}
return schema, nil

View file

@ -74,6 +74,19 @@ func (o *Operation) validate(deps map[reflect.Type]interface{}) error {
for i := 0; i < method.Type().NumIn(); i++ {
paramType := method.Type().In(i)
if paramType.String() == "*gin.Context" || paramType.String() == "*huma.Operation" {
// Known hard-coded dependencies. Skip them.
continue
}
if paramType.String() == "gin.Context" {
return fmt.Errorf("gin context should be pointer *gin.Context: %w", ErrDependencyInvalid)
}
if paramType.String() == "huma.Operation" {
return fmt.Errorf("operation should be pointer *huma.Operation: %w", ErrDependencyInvalid)
}
if _, ok := deps[paramType]; ok {
// This matches a registered dependency type, so it's not a normal
// param. Skip it.
@ -98,12 +111,14 @@ func (o *Operation) validate(deps map[reflect.Type]interface{}) error {
for i, paramType := range types {
if i == len(types)-1 && requestBody {
// The last item has no associated param.
s, err := GenerateSchema(paramType)
if err != nil {
return err
// The last item has no associated param. It is a request body.
if o.RequestSchema == nil {
s, err := GenerateSchema(paramType)
if err != nil {
return err
}
o.RequestSchema = s
}
o.RequestSchema = s
continue
}