mirror of
https://github.com/Fishwaldo/huma.git
synced 2025-03-15 19:31:27 +00:00
feat: new dependency model with params/headers
This commit is contained in:
parent
c98d6145e4
commit
fe8d73f51f
9 changed files with 566 additions and 213 deletions
33
README.md
33
README.md
|
@ -235,14 +235,16 @@ The standard `json` tag is supported and can be used to rename a field and mark
|
|||
|
||||
Huma includes a dependency injection system that can be used to pass additional arguments to operation handler functions. You can register global dependencies (ones that do not change from request to request) or contextual dependencies (ones that change with each request).
|
||||
|
||||
Global dependencies are created by just setting some value, while contextual dependencies are implemented using a function that returns the value of the form `func (dep *DependencyType...) (*YourType, error)` where the value you want injected is of `*YourType` and the function arguments can be any previously registered dependency types or one of the hard-coded types:
|
||||
Global dependencies are created by just setting some value, while contextual dependencies are implemented using a function that returns the value of the form `func (deps..., params...) (headers..., *YourType, error)` where the value you want injected is of `*YourType` and the function arguments can be any previously registered dependency types or one of the hard-coded types:
|
||||
|
||||
- `*gin.Context` the current context
|
||||
- `*huma.Operation` the current operation
|
||||
- `huma.ContextDependency()` the current context (returns `*gin.Context`)
|
||||
- `huma.OperationDependency()` the current operation (returns `*huma.Operation`)
|
||||
|
||||
```go
|
||||
// Register a new database connection dependency
|
||||
r.Dependency(db.NewConnection())
|
||||
db := &huma.Dependency{
|
||||
Value: db.NewConnection(),
|
||||
}
|
||||
|
||||
// Register a new request logger dependency. This is contextual because we
|
||||
// will print out the requester's IP address with each log message.
|
||||
|
@ -250,17 +252,22 @@ type MyLogger struct {
|
|||
Info: func(msg string),
|
||||
}
|
||||
|
||||
r.Dependency(func(c *gin.Context) (*MyLogger, error) {
|
||||
return &MyLogger{
|
||||
Info: func(msg string) {
|
||||
fmt.Printf("%s [ip:%s]\n", msg, c.Request.RemoteAddr)
|
||||
},
|
||||
}, nil
|
||||
})
|
||||
logger := &huma.Dependency{
|
||||
Depends: []*huma.Dependency{huma.ContextDependency()},
|
||||
Value: func(c *gin.Context) (*MyLogger, error) {
|
||||
return &MyLogger{
|
||||
Info: func(msg string) {
|
||||
fmt.Printf("%s [ip:%s]\n", msg, c.Request.RemoteAddr)
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
// Use them in any handler just by adding them as arguments!
|
||||
// Use them in any handler by adding them to both `Depends` and the list of
|
||||
// handler function arguments.
|
||||
r.Register(&huma.Operation{
|
||||
// ...
|
||||
Depends: []*huma.Dependency{db, logger},
|
||||
Handler: func(db *db.Connection, log *MyLogger) string {
|
||||
log.Info("test")
|
||||
item := db.Fetch("query")
|
||||
|
@ -269,7 +276,7 @@ r.Register(&huma.Operation{
|
|||
})
|
||||
```
|
||||
|
||||
Note that dependencies cannot be scalar types. Typically you would use a struct or interface like above. Global dependencies cannot be functions.
|
||||
Note that global dependencies cannot be functions. You can wrap them in a struct as a workaround.
|
||||
|
||||
## How it Works
|
||||
|
||||
|
|
244
dependency.go
244
dependency.go
|
@ -11,124 +11,200 @@ import (
|
|||
// ErrDependencyInvalid is returned when registering a dependency fails.
|
||||
var ErrDependencyInvalid = errors.New("dependency invalid")
|
||||
|
||||
// ErrDependencyNotFound is returned when the given type isn't registered
|
||||
// as a dependency.
|
||||
var ErrDependencyNotFound = errors.New("dependency not found")
|
||||
|
||||
// DependencyRegistry let's you register and resolve dependencies based on
|
||||
// their type.
|
||||
type DependencyRegistry struct {
|
||||
registry map[reflect.Type]interface{}
|
||||
// Dependency represents a handler function dependency and its associated
|
||||
// inputs and outputs. Value can be either a struct pointer (global dependency)
|
||||
// or a `func(dependencies, params) (headers, struct pointer, error)` style
|
||||
// function.
|
||||
type Dependency struct {
|
||||
Depends []*Dependency
|
||||
Params []*Param
|
||||
ResponseHeaders []*Header
|
||||
Value interface{}
|
||||
}
|
||||
|
||||
// NewDependencyRegistry creates a new blank dependency registry.
|
||||
func NewDependencyRegistry() *DependencyRegistry {
|
||||
return &DependencyRegistry{
|
||||
registry: make(map[reflect.Type]interface{}),
|
||||
}
|
||||
var contextDependency Dependency
|
||||
var operationDependency Dependency
|
||||
|
||||
// ContextDependency returns a dependency for the current request's
|
||||
// `*gin.Context`.
|
||||
func ContextDependency() *Dependency {
|
||||
return &contextDependency
|
||||
}
|
||||
|
||||
// Add a new dependency to the registry.
|
||||
func (dr *DependencyRegistry) Add(item interface{}) error {
|
||||
if dr.registry == nil {
|
||||
dr.registry = make(map[reflect.Type]interface{})
|
||||
// OperationDependency returns a dependency for the current `*huma.Operation`.
|
||||
func OperationDependency() *Dependency {
|
||||
return &operationDependency
|
||||
}
|
||||
|
||||
// validate that the dependency deps/params/headers match the function
|
||||
// signature or that the value is not a function.
|
||||
func (d *Dependency) validate(returnType reflect.Type) error {
|
||||
if d == &contextDependency || d == &operationDependency {
|
||||
// Hard-coded known dependencies. These are special and have no value.
|
||||
return nil
|
||||
}
|
||||
|
||||
val := reflect.ValueOf(item)
|
||||
outType := val.Type()
|
||||
if d.Value == nil {
|
||||
return fmt.Errorf("value must be set: %w", ErrDependencyInvalid)
|
||||
}
|
||||
|
||||
valType := val.Type()
|
||||
if val.Kind() == reflect.Func {
|
||||
for i := 0; i < valType.NumIn(); i++ {
|
||||
argType := valType.In(i)
|
||||
if argType.String() == "*gin.Context" || argType.String() == "*huma.Operation" {
|
||||
// Known hard-coded dependencies. Skip them.
|
||||
continue
|
||||
}
|
||||
v := reflect.ValueOf(d.Value)
|
||||
|
||||
if argType.Kind() != reflect.Ptr {
|
||||
return fmt.Errorf("should be pointer *%s: %w", argType, ErrDependencyInvalid)
|
||||
}
|
||||
|
||||
if _, ok := dr.registry[argType]; !ok {
|
||||
return fmt.Errorf("unknown dependency type %s, are dependencies defined in order? %w", argType, ErrDependencyNotFound)
|
||||
}
|
||||
if v.Kind() != reflect.Func {
|
||||
if returnType != nil && returnType != v.Type() {
|
||||
return fmt.Errorf("return type should be %s but got %s: %w", v.Type(), returnType, ErrDependencyInvalid)
|
||||
}
|
||||
|
||||
if val.Type().NumOut() != 2 || val.Type().Out(1).Name() != "error" {
|
||||
return fmt.Errorf("function should return (your-type, error): %w", ErrDependencyInvalid)
|
||||
// This is just a static value. It shouldn't have params/headers/etc.
|
||||
if len(d.Params) > 0 {
|
||||
return fmt.Errorf("global dependency should not have params: %w", ErrDependencyInvalid)
|
||||
}
|
||||
|
||||
outType = val.Type().Out(0)
|
||||
|
||||
if outType.Kind() != reflect.Ptr {
|
||||
return fmt.Errorf("should be pointer *%s: %w", outType, ErrDependencyInvalid)
|
||||
if len(d.ResponseHeaders) > 0 {
|
||||
return fmt.Errorf("global dependency should not set headers: %w", ErrDependencyInvalid)
|
||||
}
|
||||
|
||||
if _, ok := dr.registry[outType]; ok {
|
||||
return fmt.Errorf("duplicate type %s: %w", outType.String(), ErrDependencyInvalid)
|
||||
}
|
||||
} else {
|
||||
if valType.Kind() != reflect.Ptr {
|
||||
return fmt.Errorf("should be pointer *%s: %w", valType, ErrDependencyInvalid)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, ok := dr.registry[valType]; ok {
|
||||
return fmt.Errorf("duplicate type %s: %w", valType.String(), ErrDependencyInvalid)
|
||||
fn := v.Type()
|
||||
lenArgs := len(d.Depends) + len(d.Params)
|
||||
if fn.NumIn() != lenArgs {
|
||||
// TODO: generate suggested func signature
|
||||
return fmt.Errorf("function signature should have %d args but got %s: %w", lenArgs, fn, ErrDependencyInvalid)
|
||||
}
|
||||
|
||||
for _, dep := range d.Depends {
|
||||
if err := dep.validate(nil); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// To prevent mistakes we limit dependencies to non-scalar types, since
|
||||
// scalars like strings/numbers are typically used for params like headers.
|
||||
switch outType.Kind() {
|
||||
case reflect.Bool, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.String:
|
||||
return fmt.Errorf("dependeny cannot be scalar type %s: %w", outType.Kind(), ErrDependencyInvalid)
|
||||
for i, p := range d.Params {
|
||||
if err := validateParam(p, fn.In(len(d.Depends)+i)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
dr.registry[outType] = item
|
||||
lenReturn := len(d.ResponseHeaders) + 2
|
||||
if fn.NumOut() != lenReturn {
|
||||
return fmt.Errorf("function should return %d values but got %d: %w", lenReturn, fn.NumOut(), ErrDependencyInvalid)
|
||||
}
|
||||
|
||||
for i, h := range d.ResponseHeaders {
|
||||
if err := validateHeader(h, fn.Out(i)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get a resolved dependency from the registry.
|
||||
func (dr *DependencyRegistry) Get(op *Operation, c *gin.Context, t reflect.Type) (interface{}, error) {
|
||||
if t.String() == "*gin.Context" {
|
||||
// Special case: current gin context.
|
||||
return c, nil
|
||||
// AllParams returns all parameters for all dependencies in the graph of this
|
||||
// dependency in depth-first order without duplicates.
|
||||
func (d *Dependency) AllParams() []*Param {
|
||||
params := []*Param{}
|
||||
seen := map[*Param]bool{}
|
||||
|
||||
for _, p := range d.Params {
|
||||
seen[p] = true
|
||||
params = append(params, p)
|
||||
}
|
||||
|
||||
if t.String() == "*huma.Operation" {
|
||||
// Special case: current operation.
|
||||
return op, nil
|
||||
for _, d := range d.Depends {
|
||||
for _, p := range d.AllParams() {
|
||||
if _, ok := seen[p]; !ok {
|
||||
seen[p] = true
|
||||
|
||||
params = append(params, p)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if f, ok := dr.registry[t]; ok {
|
||||
// This argument matches a known registered dependency. If it's a
|
||||
// function, then call it, otherwise just return the value.
|
||||
vf := reflect.ValueOf(f)
|
||||
if vf.Kind() == reflect.Func {
|
||||
// Build the input argument list, which can consist of other dependencies.
|
||||
args := make([]reflect.Value, vf.Type().NumIn())
|
||||
return params
|
||||
}
|
||||
|
||||
for i := 0; i < vf.Type().NumIn(); i++ {
|
||||
v, err := dr.Get(op, c, vf.Type().In(i))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
args[i] = reflect.ValueOf(v)
|
||||
// AllResponseHeaders returns all response headers for all dependencies in
|
||||
// the graph of this dependency in depth-first order without duplicates.
|
||||
func (d *Dependency) AllResponseHeaders() []*Header {
|
||||
headers := []*Header{}
|
||||
seen := map[*Header]bool{}
|
||||
|
||||
for _, h := range d.ResponseHeaders {
|
||||
seen[h] = true
|
||||
headers = append(headers, h)
|
||||
}
|
||||
|
||||
for _, d := range d.Depends {
|
||||
for _, h := range d.AllResponseHeaders() {
|
||||
if _, ok := seen[h]; !ok {
|
||||
seen[h] = true
|
||||
|
||||
headers = append(headers, h)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
out := vf.Call(args)
|
||||
return headers
|
||||
}
|
||||
|
||||
if !out[1].IsNil() {
|
||||
return nil, out[1].Interface().(error)
|
||||
}
|
||||
return out[0].Interface(), nil
|
||||
// Resolve the value of the dependency. Returns (response headers, value, error).
|
||||
func (d *Dependency) Resolve(c *gin.Context, op *Operation) (map[string]string, interface{}, error) {
|
||||
// Identity dependencies are first. Just return if it's one of them.
|
||||
if d == &contextDependency {
|
||||
return nil, c, nil
|
||||
}
|
||||
|
||||
if d == &operationDependency {
|
||||
return nil, op, nil
|
||||
}
|
||||
|
||||
v := reflect.ValueOf(d.Value)
|
||||
if v.Kind() != reflect.Func {
|
||||
// Not a function, just return the global value.
|
||||
return nil, d.Value, nil
|
||||
}
|
||||
|
||||
// Generate the input arguments
|
||||
in := make([]reflect.Value, 0, v.Type().NumIn())
|
||||
headers := map[string]string{}
|
||||
|
||||
// Resolve each sub-dependency
|
||||
for _, dep := range d.Depends {
|
||||
dHeaders, dVal, err := dep.Resolve(c, op)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// Not a function, just return the value.
|
||||
return f, nil
|
||||
for h, hv := range dHeaders {
|
||||
headers[h] = hv
|
||||
}
|
||||
|
||||
in = append(in, reflect.ValueOf(dVal))
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("%s: %w", t, ErrDependencyNotFound)
|
||||
// Get each input parameter
|
||||
for _, param := range d.Params {
|
||||
v, err := getParamValue(c, param)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
in = append(in, reflect.ValueOf(v))
|
||||
}
|
||||
|
||||
// Call the function.
|
||||
out := v.Call(in)
|
||||
|
||||
if last := out[len(out)-1]; !last.IsNil() {
|
||||
// There was an error!
|
||||
return nil, nil, last.Interface().(error)
|
||||
}
|
||||
|
||||
// Get the headers & response value.
|
||||
for i, h := range d.ResponseHeaders {
|
||||
headers[h.Name] = out[i].Interface().(string)
|
||||
}
|
||||
|
||||
return headers, out[len(d.ResponseHeaders)].Interface(), nil
|
||||
}
|
||||
|
|
|
@ -1,58 +1,128 @@
|
|||
package huma
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/alecthomas/assert"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestDependencyNested(t *testing.T) {
|
||||
type Dep1 struct{}
|
||||
type Dep2 struct{}
|
||||
type Dep3 struct{}
|
||||
func TestGlobalDepEmpty(t *testing.T) {
|
||||
d := Dependency{}
|
||||
|
||||
registry := NewDependencyRegistry()
|
||||
assert.NoError(t, registry.Add(&Dep1{}))
|
||||
typ := reflect.TypeOf(123)
|
||||
|
||||
assert.NoError(t, registry.Add(func(d1 *Dep1) (*Dep2, error) {
|
||||
return &Dep2{}, nil
|
||||
}))
|
||||
|
||||
assert.NoError(t, registry.Add(func(d1 *Dep1, d2 *Dep2) (*Dep3, error) {
|
||||
return &Dep3{}, nil
|
||||
}))
|
||||
assert.Error(t, d.validate(typ))
|
||||
}
|
||||
|
||||
func TestDependencyOrder(t *testing.T) {
|
||||
type Dep1 struct{}
|
||||
type Dep2 struct{}
|
||||
func TestGlobalDepWrongType(t *testing.T) {
|
||||
d := Dependency{
|
||||
Value: "test",
|
||||
}
|
||||
|
||||
registry := NewDependencyRegistry()
|
||||
typ := reflect.TypeOf(123)
|
||||
|
||||
assert.Error(t, registry.Add(func(d2 *Dep2) (*Dep1, error) {
|
||||
return &Dep1{}, nil
|
||||
}))
|
||||
assert.Error(t, d.validate(typ))
|
||||
}
|
||||
|
||||
func TestDependencyNotPointer(t *testing.T) {
|
||||
type Dep1 struct{}
|
||||
func TestGlobalDepParams(t *testing.T) {
|
||||
d := Dependency{
|
||||
Params: []*Param{
|
||||
HeaderParam("foo", "description", "hello"),
|
||||
},
|
||||
Value: "test",
|
||||
}
|
||||
|
||||
registry := NewDependencyRegistry()
|
||||
typ := reflect.TypeOf("test")
|
||||
|
||||
assert.Error(t, registry.Add(Dep1{}))
|
||||
assert.Error(t, registry.Add(func() (Dep1, error) {
|
||||
return Dep1{}, nil
|
||||
}))
|
||||
assert.Error(t, d.validate(typ))
|
||||
}
|
||||
|
||||
func TestDependencyDupe(t *testing.T) {
|
||||
type Dep1 struct{}
|
||||
func TestGlobalDepHeaders(t *testing.T) {
|
||||
d := Dependency{
|
||||
ResponseHeaders: []*Header{ResponseHeader("foo", "description")},
|
||||
Value: "test",
|
||||
}
|
||||
|
||||
registry := NewDependencyRegistry()
|
||||
typ := reflect.TypeOf("test")
|
||||
|
||||
assert.NoError(t, registry.Add(&Dep1{}))
|
||||
assert.Error(t, registry.Add(&Dep1{}))
|
||||
assert.Error(t, registry.Add(func() (*Dep1, error) {
|
||||
return nil, nil
|
||||
}))
|
||||
assert.Error(t, d.validate(typ))
|
||||
}
|
||||
|
||||
func TestDepContext(t *testing.T) {
|
||||
d := Dependency{
|
||||
Depends: []*Dependency{
|
||||
ContextDependency(),
|
||||
},
|
||||
Value: func(c *gin.Context) (*gin.Context, error) { return c, nil },
|
||||
}
|
||||
|
||||
mock := &gin.Context{}
|
||||
|
||||
typ := reflect.TypeOf(mock)
|
||||
assert.NoError(t, d.validate(typ))
|
||||
|
||||
_, v, err := d.Resolve(mock, &Operation{})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, v, mock)
|
||||
}
|
||||
|
||||
func TestDepOperation(t *testing.T) {
|
||||
d := Dependency{
|
||||
Depends: []*Dependency{
|
||||
OperationDependency(),
|
||||
},
|
||||
Value: func(o *Operation) (*Operation, error) { return o, nil },
|
||||
}
|
||||
|
||||
mock := &Operation{}
|
||||
|
||||
typ := reflect.TypeOf(mock)
|
||||
assert.NoError(t, d.validate(typ))
|
||||
|
||||
_, v, err := d.Resolve(&gin.Context{}, mock)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, v, mock)
|
||||
}
|
||||
func TestDepFuncWrongArgs(t *testing.T) {
|
||||
d := Dependency{
|
||||
Params: []*Param{
|
||||
HeaderParam("foo", "desc", ""),
|
||||
},
|
||||
Value: func() (string, error) {
|
||||
return "", nil
|
||||
},
|
||||
}
|
||||
|
||||
assert.Error(t, d.validate(reflect.TypeOf("")))
|
||||
}
|
||||
|
||||
func TestDepFunc(t *testing.T) {
|
||||
d := Dependency{
|
||||
Params: []*Param{
|
||||
HeaderParam("x-in", "desc", ""),
|
||||
},
|
||||
ResponseHeaders: []*Header{
|
||||
ResponseHeader("x-out", "desc"),
|
||||
},
|
||||
Value: func(xin string) (string, string, error) {
|
||||
return "xout", "value", nil
|
||||
},
|
||||
}
|
||||
|
||||
c := &gin.Context{
|
||||
Request: &http.Request{
|
||||
Header: http.Header{
|
||||
"x-in": []string{"xin"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
assert.NoError(t, d.validate(reflect.TypeOf("")))
|
||||
h, v, err := d.Resolve(c, &Operation{})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "xout", h["x-out"])
|
||||
assert.Equal(t, "value", v)
|
||||
}
|
||||
|
|
4
go.mod
4
go.mod
|
@ -4,13 +4,13 @@ go 1.13
|
|||
|
||||
require (
|
||||
github.com/Jeffail/gabs v1.4.0
|
||||
github.com/alecthomas/assert v0.0.0-20170929043011-405dbfeb8e38
|
||||
github.com/alecthomas/assert v0.0.0-20170929043011-405dbfeb8e38 // indirect
|
||||
github.com/alecthomas/colour v0.1.0 // indirect
|
||||
github.com/alecthomas/repr v0.0.0-20181024024818-d37bc2a10ba1 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1
|
||||
github.com/gin-gonic/gin v1.5.0
|
||||
github.com/gosimple/slug v1.9.0
|
||||
github.com/rs/zerolog v1.18.0
|
||||
github.com/rs/zerolog v1.18.0 // indirect
|
||||
github.com/sergi/go-diff v1.1.0 // indirect
|
||||
github.com/stretchr/testify v1.5.1
|
||||
github.com/xeipuuv/gojsonschema v1.2.0
|
||||
|
|
3
go.sum
3
go.sum
|
@ -25,8 +25,10 @@ github.com/gosimple/slug v1.9.0 h1:r5vDcYrFz9BmfIAMC829un9hq7hKM4cHUrsv36LbEqs=
|
|||
github.com/gosimple/slug v1.9.0/go.mod h1:AMZ+sOVe65uByN3kgEyf9WEBKBCSS+dJjMX9x4vDJbg=
|
||||
github.com/json-iterator/go v1.1.7 h1:KfgG9LzI+pYjr4xvmz/5H4FXjokeP+rlHLhv3iH62Fo=
|
||||
github.com/json-iterator/go v1.1.7/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
|
||||
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
|
||||
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
|
||||
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
||||
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
|
||||
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||
github.com/leodido/go-urn v1.1.0 h1:Sm1gr51B1kKyfD2BlRcLSiEkffoG96g6TPv6eRoEiB8=
|
||||
github.com/leodido/go-urn v1.1.0/go.mod h1:+cyI34gQWZcE1eQU7NVgKkkzdXDQHr1dBMtdAPozLkw=
|
||||
|
@ -75,6 +77,7 @@ golang.org/x/tools v0.0.0-20190828213141-aed303cbaa74/go.mod h1:b+2E5dAYhXwXZwtn
|
|||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=
|
||||
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/go-playground/assert.v1 v1.2.1 h1:xoYuJVE7KT85PYWrN730RguIQO0ePzVRfFMXadIrXTM=
|
||||
gopkg.in/go-playground/assert.v1 v1.2.1/go.mod h1:9RXL0bg/zibRAgZUYszZSwO/z8Y/a8bDuhia5mkpMnE=
|
||||
|
|
69
openapi.go
69
openapi.go
|
@ -153,6 +153,7 @@ type Operation struct {
|
|||
Summary string
|
||||
Description string
|
||||
Tags []string
|
||||
Depends []*Dependency
|
||||
Params []*Param
|
||||
RequestContentType string
|
||||
RequestSchema *Schema
|
||||
|
@ -161,6 +162,54 @@ type Operation struct {
|
|||
Handler interface{}
|
||||
}
|
||||
|
||||
// AllParams returns a list of all the parameters for this operation, including
|
||||
// those for dependencies.
|
||||
func (o *Operation) AllParams() []*Param {
|
||||
params := []*Param{}
|
||||
seen := map[*Param]bool{}
|
||||
|
||||
for _, p := range o.Params {
|
||||
seen[p] = true
|
||||
params = append(params, p)
|
||||
}
|
||||
|
||||
for _, d := range o.Depends {
|
||||
for _, p := range d.AllParams() {
|
||||
if _, ok := seen[p]; !ok {
|
||||
seen[p] = true
|
||||
|
||||
params = append(params, p)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return params
|
||||
}
|
||||
|
||||
// AllResponseHeaders returns a list of all the parameters for this operation,
|
||||
// including those for dependencies.
|
||||
func (o *Operation) AllResponseHeaders() []*Header {
|
||||
headers := []*Header{}
|
||||
seen := map[*Header]bool{}
|
||||
|
||||
for _, h := range o.ResponseHeaders {
|
||||
seen[h] = true
|
||||
headers = append(headers, h)
|
||||
}
|
||||
|
||||
for _, d := range o.Depends {
|
||||
for _, h := range d.AllResponseHeaders() {
|
||||
if _, ok := seen[h]; !ok {
|
||||
seen[h] = true
|
||||
|
||||
headers = append(headers, h)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return headers
|
||||
}
|
||||
|
||||
// Server describes an OpenAPI 3 API server location
|
||||
type Server struct {
|
||||
URL string `json:"url"`
|
||||
|
@ -189,6 +238,7 @@ type OpenAPI struct {
|
|||
Version string
|
||||
Servers []*Server
|
||||
Paths map[string][]*Operation
|
||||
// TODO: Depends []*Dependency
|
||||
}
|
||||
|
||||
// OpenAPIHandler returns a new handler function to generate an OpenAPI spec.
|
||||
|
@ -224,7 +274,7 @@ func OpenAPIHandler(api *OpenAPI) func(*gin.Context) {
|
|||
openapi.Set(op.Tags, "paths", path, method, "tags")
|
||||
}
|
||||
|
||||
for _, param := range op.Params {
|
||||
for _, param := range op.AllParams() {
|
||||
if param.internal {
|
||||
// Skip internal-only parameters.
|
||||
continue
|
||||
|
@ -260,7 +310,7 @@ func OpenAPIHandler(api *OpenAPI) func(*gin.Context) {
|
|||
}
|
||||
|
||||
headerMap := map[string]*Header{}
|
||||
for _, header := range op.ResponseHeaders {
|
||||
for _, header := range op.AllResponseHeaders() {
|
||||
headerMap[header.Name] = header
|
||||
}
|
||||
|
||||
|
@ -268,7 +318,22 @@ func OpenAPIHandler(api *OpenAPI) func(*gin.Context) {
|
|||
status := fmt.Sprintf("%v", resp.StatusCode)
|
||||
openapi.Set(resp.Description, "paths", path, method, "responses", status, "description")
|
||||
|
||||
headers := make([]string, 0, len(resp.Headers))
|
||||
seen := map[string]bool{}
|
||||
for _, name := range resp.Headers {
|
||||
headers = append(headers, name)
|
||||
seen[name] = true
|
||||
}
|
||||
for _, dep := range op.Depends {
|
||||
for _, header := range dep.AllResponseHeaders() {
|
||||
if _, ok := seen[header.Name]; !ok {
|
||||
headers = append(headers, header.Name)
|
||||
seen[header.Name] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, name := range headers {
|
||||
header := headerMap[name]
|
||||
openapi.Set(header, "paths", path, method, "responses", status, "headers", header.Name)
|
||||
}
|
||||
|
|
41
router.go
41
router.go
|
@ -125,7 +125,6 @@ func getRequestBody(c *gin.Context, t reflect.Type, op *Operation) (interface{},
|
|||
type Router struct {
|
||||
api *OpenAPI
|
||||
engine *gin.Engine
|
||||
deps *DependencyRegistry
|
||||
}
|
||||
|
||||
// NewRouter creates a new Huma router for handling API requests with
|
||||
|
@ -140,7 +139,6 @@ func NewRouterWithGin(engine *gin.Engine, api *OpenAPI) *Router {
|
|||
r := &Router{
|
||||
api: api,
|
||||
engine: engine,
|
||||
deps: NewDependencyRegistry(),
|
||||
}
|
||||
|
||||
if r.api.Paths == nil {
|
||||
|
@ -214,17 +212,17 @@ func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
|||
// })
|
||||
//
|
||||
// Panics on invalid input to force stop execution on service startup.
|
||||
func (r *Router) Dependency(f interface{}) {
|
||||
if err := r.deps.Add(f); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
// func (r *Router) Dependency(f interface{}) {
|
||||
// if err := r.deps.Add(f); err != nil {
|
||||
// panic(err)
|
||||
// }
|
||||
// }
|
||||
|
||||
// Register a new operation.
|
||||
func (r *Router) Register(op *Operation) {
|
||||
// First, make sure the operation and handler make sense, as well as pre-
|
||||
// generating any schemas for use later during request handling.
|
||||
if err := op.validate(r.deps.registry); err != nil {
|
||||
if err := op.validate(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
|
@ -263,22 +261,21 @@ func (r *Router) Register(op *Operation) {
|
|||
in := make([]reflect.Value, 0, method.Type().NumIn())
|
||||
|
||||
// Process any dependencies first.
|
||||
for i := 0; i < method.Type().NumIn(); i++ {
|
||||
argType := method.Type().In(i)
|
||||
v, err := r.deps.Get(op, c, argType)
|
||||
for _, dep := range op.Depends {
|
||||
headers, value, err := dep.Resolve(c, op)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrDependencyNotFound) {
|
||||
// No match, so we're done with dependencies. Keep going below
|
||||
// processing params.
|
||||
break
|
||||
}
|
||||
|
||||
// Getting the dependency value failed.
|
||||
// TODO: better error code/messaging?
|
||||
c.AbortWithError(500, err)
|
||||
return
|
||||
// TODO: better error handling
|
||||
c.AbortWithStatusJSON(500, ErrorModel{
|
||||
Message: "Couldn't get dependency",
|
||||
//Errors: []error{err},
|
||||
})
|
||||
}
|
||||
in = append(in, reflect.ValueOf(v))
|
||||
|
||||
for k, v := range headers {
|
||||
c.Header(k, v)
|
||||
}
|
||||
|
||||
in = append(in, reflect.ValueOf(value))
|
||||
}
|
||||
|
||||
for _, param := range op.Params {
|
||||
|
|
148
router_test.go
148
router_test.go
|
@ -67,6 +67,120 @@ func BenchmarkHuma(b *testing.B) {
|
|||
}
|
||||
}
|
||||
|
||||
func BenchmarkGinComplex(b *testing.B) {
|
||||
dep1 := "dep1"
|
||||
dep2 := func(c *gin.Context) string {
|
||||
_ = c.GetHeader("x-foo")
|
||||
return "dep2"
|
||||
}
|
||||
dep3 := func(c *gin.Context) (string, string) {
|
||||
return "xbar", "dep3"
|
||||
}
|
||||
|
||||
g := gin.New()
|
||||
g.GET("/hello", func(c *gin.Context) {
|
||||
_ = dep1
|
||||
_ = dep2(c)
|
||||
h, _ := dep3(c)
|
||||
|
||||
c.Header("x-bar", h)
|
||||
|
||||
name := c.Query("name")
|
||||
if name == "test" {
|
||||
c.JSON(400, &ErrorModel{
|
||||
Message: "Name cannot be test",
|
||||
})
|
||||
}
|
||||
if name == "" {
|
||||
name = "world"
|
||||
}
|
||||
|
||||
c.Header("x-baz", "xbaz")
|
||||
c.JSON(200, &helloResponse{
|
||||
Message: "Hello, " + name,
|
||||
})
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest(http.MethodGet, "/hello", nil)
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
g.ServeHTTP(w, req)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkHumaComplex(b *testing.B) {
|
||||
r := NewRouterWithGin(gin.New(), &OpenAPI{
|
||||
Title: "Benchmark test",
|
||||
Version: "1.0.0",
|
||||
})
|
||||
|
||||
dep1 := &Dependency{
|
||||
Value: "dep1",
|
||||
}
|
||||
|
||||
dep2 := &Dependency{
|
||||
Depends: []*Dependency{ContextDependency(), dep1},
|
||||
Params: []*Param{
|
||||
HeaderParam("x-foo", "desc", ""),
|
||||
},
|
||||
Value: func(c *gin.Context, d1 string, xfoo string) (string, error) {
|
||||
return "dep2", nil
|
||||
},
|
||||
}
|
||||
|
||||
dep3 := &Dependency{
|
||||
Depends: []*Dependency{dep1},
|
||||
ResponseHeaders: []*Header{
|
||||
ResponseHeader("x-bar", "desc"),
|
||||
},
|
||||
Value: func(d1 string) (string, string, error) {
|
||||
return "xbar", "dep3", nil
|
||||
},
|
||||
}
|
||||
|
||||
r.Register(&Operation{
|
||||
Method: http.MethodGet,
|
||||
Path: "/hello",
|
||||
Description: "Greet the world",
|
||||
Depends: []*Dependency{
|
||||
ContextDependency(), dep2, dep3,
|
||||
},
|
||||
Params: []*Param{
|
||||
QueryParam("name", "desc", "world"),
|
||||
},
|
||||
ResponseHeaders: []*Header{
|
||||
ResponseHeader("x-baz", "desc"),
|
||||
},
|
||||
Responses: []*Response{
|
||||
ResponseJSON(200, "Return a greeting", "x-baz"),
|
||||
ResponseError(500, "desc"),
|
||||
},
|
||||
Handler: func(c *gin.Context, d2, d3, name string) (string, *helloResponse, *ErrorModel) {
|
||||
if name == "test" {
|
||||
return "", nil, &ErrorModel{
|
||||
Message: "Name cannot be test",
|
||||
}
|
||||
}
|
||||
|
||||
return "xbaz", &helloResponse{
|
||||
Message: "Hello, " + name,
|
||||
}, nil
|
||||
},
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest(http.MethodGet, "/hello?name=Daniel", nil)
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
r.ServeHTTP(w, req)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter(t *testing.T) {
|
||||
type EchoResponse struct {
|
||||
Value string `json:"value" description:"The echoed back word"`
|
||||
|
@ -253,30 +367,38 @@ func TestRouterDependencies(t *testing.T) {
|
|||
Get func() string
|
||||
}
|
||||
|
||||
// Inject datastore as a global instance.
|
||||
r.Dependency(&DB{
|
||||
Get: func() string {
|
||||
return "Hello, "
|
||||
// Datastore is a global dependency, set by value.
|
||||
db := &Dependency{
|
||||
Value: &DB{
|
||||
Get: func() string {
|
||||
return "Hello, "
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
type Logger struct {
|
||||
Log func(msg string)
|
||||
}
|
||||
|
||||
// Inject logger as a contextual instance from the Gin context.
|
||||
r.Dependency(func(c *gin.Context) (*Logger, error) {
|
||||
return &Logger{
|
||||
Log: func(msg string) {
|
||||
fmt.Println(fmt.Sprintf("%s [uri:%s]", msg, c.FullPath()))
|
||||
},
|
||||
}, nil
|
||||
})
|
||||
// Logger is a contextual instance from the gin request context.
|
||||
log := &Dependency{
|
||||
Depends: []*Dependency{
|
||||
ContextDependency(),
|
||||
},
|
||||
Value: func(c *gin.Context) (*Logger, error) {
|
||||
return &Logger{
|
||||
Log: func(msg string) {
|
||||
fmt.Println(fmt.Sprintf("%s [uri:%s]", msg, c.FullPath()))
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
r.Register(&Operation{
|
||||
Method: http.MethodGet,
|
||||
Path: "/hello",
|
||||
Description: "Basic hello world",
|
||||
Depends: []*Dependency{ContextDependency(), db, log},
|
||||
Params: []*Param{
|
||||
QueryParam("name", "Your name", ""),
|
||||
},
|
||||
|
|
95
validate.go
95
validate.go
|
@ -13,10 +13,6 @@ import (
|
|||
// ErrFieldRequired is returned when a field is blank but has been required.
|
||||
var ErrFieldRequired = errors.New("field is required")
|
||||
|
||||
// ErrContextNotFirst is returned when a registered operation has a handler
|
||||
// that takes a context but it is not the first parameter of the function.
|
||||
var ErrContextNotFirst = errors.New("context should be first parameter")
|
||||
|
||||
// ErrParamsMustMatch is returned when a registered operation has a handler
|
||||
// function that takes the wrong number of arguments.
|
||||
var ErrParamsMustMatch = errors.New("handler function args must match registered params")
|
||||
|
@ -31,9 +27,40 @@ var ErrResponsesMustMatch = errors.New("handler function return values must matc
|
|||
|
||||
var paramRe = regexp.MustCompile(`:([^/]+)|{([^}]+)}`)
|
||||
|
||||
func validateParam(p *Param, t reflect.Type) error {
|
||||
p.typ = t
|
||||
|
||||
if p.Schema == nil {
|
||||
s, err := GenerateSchema(p.typ)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
p.Schema = s
|
||||
|
||||
if p.def != nil {
|
||||
p.Schema.Default = p.def
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateHeader(h *Header, t reflect.Type) error {
|
||||
if h.Schema == nil {
|
||||
// Generate the schema from the handler function types.
|
||||
s, err := GenerateSchema(t)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
h.Schema = s
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validate checks that the operation is well-formed (e.g. handler signature
|
||||
// matches the given params) and generates schemas if needed.
|
||||
func (o *Operation) validate(deps map[reflect.Type]interface{}) error {
|
||||
func (o *Operation) validate() error {
|
||||
if o.Method == "" {
|
||||
return fmt.Errorf("Method: %w", ErrFieldRequired)
|
||||
}
|
||||
|
@ -70,27 +97,32 @@ func (o *Operation) validate(deps map[reflect.Type]interface{}) error {
|
|||
o.Path = paramRe.ReplaceAllString(o.Path, ":$1$2")
|
||||
}
|
||||
|
||||
types := []reflect.Type{}
|
||||
for i := 0; i < method.Type().NumIn(); i++ {
|
||||
for i, dep := range o.Depends {
|
||||
paramType := method.Type().In(i)
|
||||
|
||||
if paramType.String() == "*gin.Context" || paramType.String() == "*huma.Operation" {
|
||||
// Known hard-coded dependencies. Skip them.
|
||||
continue
|
||||
}
|
||||
|
||||
if paramType.String() == "gin.Context" {
|
||||
return fmt.Errorf("gin context should be pointer *gin.Context: %w", ErrDependencyInvalid)
|
||||
return fmt.Errorf("gin.Context should be pointer *gin.Context: %w", ErrDependencyInvalid)
|
||||
}
|
||||
|
||||
if paramType.String() == "huma.Operation" {
|
||||
return fmt.Errorf("operation should be pointer *huma.Operation: %w", ErrDependencyInvalid)
|
||||
return fmt.Errorf("huma.Operation should be pointer *huma.Operation: %w", ErrDependencyInvalid)
|
||||
}
|
||||
|
||||
if _, ok := deps[paramType]; ok {
|
||||
// This matches a registered dependency type, so it's not a normal
|
||||
// param. Skip it.
|
||||
continue
|
||||
if err := dep.validate(paramType); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
types := []reflect.Type{}
|
||||
for i := len(o.Depends); i < method.Type().NumIn(); i++ {
|
||||
paramType := method.Type().In(i)
|
||||
|
||||
if paramType.String() == "gin.Context" {
|
||||
return fmt.Errorf("gin.Context should be pointer *gin.Context: %w", ErrDependencyInvalid)
|
||||
}
|
||||
|
||||
if paramType.String() == "huma.Operation" {
|
||||
return fmt.Errorf("huma.Operation should be pointer *huma.Operation: %w", ErrDependencyInvalid)
|
||||
}
|
||||
|
||||
types = append(types, paramType)
|
||||
|
@ -123,21 +155,8 @@ func (o *Operation) validate(deps map[reflect.Type]interface{}) error {
|
|||
}
|
||||
|
||||
p := o.Params[i]
|
||||
p.typ = paramType
|
||||
if p.Schema == nil {
|
||||
// Auto-generate a schema for this parameter
|
||||
s, err := GenerateSchema(paramType)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
p.Schema = s
|
||||
|
||||
if p.def != nil {
|
||||
if reflect.ValueOf(p.def).Type() != paramType {
|
||||
|
||||
}
|
||||
p.Schema.Default = p.def
|
||||
}
|
||||
if err := validateParam(p, paramType); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -148,14 +167,8 @@ func (o *Operation) validate(deps map[reflect.Type]interface{}) error {
|
|||
}
|
||||
|
||||
for i, header := range o.ResponseHeaders {
|
||||
if header.Schema == nil {
|
||||
// Generate the schema from the handler function types.
|
||||
headerType := method.Type().Out(i)
|
||||
s, err := GenerateSchema(headerType)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
header.Schema = s
|
||||
if err := validateHeader(header, method.Type().Out(i)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue