feat: functional options for resources/operations

This commit is contained in:
Daniel G. Taylor 2020-04-13 22:47:40 -07:00
parent 74d7661b0d
commit 9a5d037f99
No known key found for this signature in database
GPG key ID: 7BD6DC99C9A87E22
18 changed files with 1229 additions and 1200 deletions

101
README.md
View file

@ -294,29 +294,6 @@ notes.With(
).Get("Get a list of all notes", func () []*NoteSummary {
// Implementation goes here
})
// The above idiom is common enough when needing to change response codes
// or allow certain response headers that there is a shortcut:
notes.
JSON(http.StatusCreated, "Success", "expires").
Get("Get a list of all notes", func () []*NoteSummary {
// Implementation goes here
})
```
Alternatively you can provide a `*huma.Operation` instance to the resource if you want more flexibility or prefer this style over chaining:
```go
// Create the operation
notes.Operation(http.MethodGet, &huma.Operation{
Description: "Get a list of all notes"
Responses: []*huma.Response{
huma.ResponseJSON(http.StatusOK, "List of notes", "expires"),
},
Handler: func () []*NoteSummary {
// Implementation goes here
}
})
```
> :whale: Operations map an HTTP action verb to a resource. You might `POST` a new note or `GET` a user. Sometimes the mapping is less obvious and you can consider using a sub-resource. For example, rather than unliking a post, maybe you `DELETE` the `/posts/{id}/likes` resource.
@ -371,15 +348,17 @@ Get("Get a note by its ID", func(id string) (*huma.ErrorModel, *Note) {
You can also declare parameters with additional validation logic:
```go
huma.PathParam("id", "Note ID", &huma.Schema{
s := &schema.Schema{
MinLength: 1,
MaxLength: 32,
})
}
huma.PathParam("id", "Note ID", huma.Schema(s))
```
Once a parameter is declared it will get parsed, validated, and then sent to your handler function. If parsing or validation fails, the client gets a 400-level HTTP error.
> :whale: If a proxy is providing e.g. authentication or rate-limiting and exposes additional internal-only information then use the internal parameters like `huma.HeaderParamInternal("UserID", "Parsed user from the auth system", "nobody")`. Internal parameters are never included in the generated OpenAPI 3 spec or documentation.
> :whale: If a proxy is providing e.g. authentication or rate-limiting and exposes additional internal-only information then use the internal parameters like `huma.HeaderParam("UserID", "Parsed user from the auth system", "nobody", huma.Internal())`. Internal parameters are never included in the generated OpenAPI 3 spec or documentation.
## Request & Response Models
@ -433,11 +412,11 @@ Get("Description", func() (*huma.ErrorModel, *Note) {
Whichever model is not `nil` will get sent back to the client.
Empty responses, e.g. a `204 No Content` or `304 Not Modified` are also supported. Use `huma.ResponseEmpty` paired with a simple boolean to return a response without a body. Passing `false` acts like `nil` for models and prevents that response from being sent.
Empty responses, e.g. a `204 No Content` or `304 Not Modified` are also supported by setting a `ContentType` of `""`. Use `huma.Response` paired with a simple boolean to return a response without a body. Passing `false` acts like `nil` for models and prevents that response from being sent.
```go
r.Resource("/notes",
huma.ResponseEmpty(http.StatusNoContent, "This should have no body")).
huma.Response(http.StatusNoContent, "This should have no body")).
Get("description", func() bool {
return true
})
@ -490,8 +469,8 @@ For example:
```go
r.Resource("/notes",
huma.Header("expires", "Expiration date for this content"),
huma.ResponseText(http.StatusOK, "Success", "expires")).
Get("description", func() (string, string) {
huma.ResponseText(http.StatusOK, "Success", huma.Headers("expires"))
).Get("description", func() (string, string) {
expires := time.Now().Add(7 * 24 * time.Hour).MarshalText()
return expires, "Hello!"
})
@ -521,9 +500,7 @@ Global dependencies are created by just setting some value, while contextual dep
```go
// Register a new database connection dependency
db := &huma.Dependency{
Value: db.NewConnection(),
}
db := huma.SimpleDependency(db.NewConnection())
// Register a new request logger dependency. This is contextual because we
// will print out the requester's IP address with each log message.
@ -531,27 +508,25 @@ type MyLogger struct {
Info: func(msg string),
}
logger := &huma.Dependency{
Dependencies: []*huma.Dependency{huma.ContextDependency()},
Value: func(c *gin.Context) (*MyLogger, error) {
logger := huma.Dependency(
huma.GinContextDependency(),
func(c *gin.Context) (*MyLogger, error) {
return &MyLogger{
Info: func(msg string) {
fmt.Printf("%s [ip:%s]\n", msg, c.Request.RemoteAddr)
},
}, nil
},
}
)
// Use them in any handler by adding them to both `Depends` and the list of
// handler function arguments.
r.Resource("/foo").Operation(http.MethodGet, &huma.Operation{
// ...
Dependencies: []*huma.Dependency{db, logger},
Handler: func(db *db.Connection, log *MyLogger) string {
log.Info("test")
item := db.Fetch("query")
return item.ID
}
r.Resource("/foo").With(
db, logger
).Get("doc", func(db *db.Connection, log *MyLogger) string {
log.Info("test")
item := db.Fetch("query")
return item.ID
})
```
@ -624,7 +599,7 @@ r.Resource("/timeout",
### Request Body Timeouts
By default any handler which takes in a request body parameter will have a read timeout of 15 seconds set on it. If set to nonzero for a handler which does **not** take a body, then the timeout will be set on the underlying connection before calling your handler. The timeout value is configurable at the resource and operation level.
By default any handler which takes in a request body parameter will have a read timeout of 15 seconds set on it. If set to nonzero for a handler which does **not** take a body, then the timeout will be set on the underlying connection before calling your handler.
When triggered, the server sends a 408 Request Timeout as JSON with a message containing the time waited.
@ -635,21 +610,11 @@ type Input struct {
r := huma.NewRouter("My API", "1.0.0")
// Resource-level limit to 5 seconds
r.Resource("/foo").BodyReadTimeout(5 * time.Second).Post(
// Limit to 5 seconds
r.Resource("/foo", huma.BodyReadTimeout(5 * time.Second)).Post(
"Create item", func(input *Input) string {
return "Hello, " + input.ID
})
// Operation-level limit
r.Resource("/foo").Operation(http.MethodPost, &huma.Operation{
// ...
BodyReadTimeout: 5 * time.Second,
Handler: func(input *Input) string {
return "Hello, " + input.ID
},
// ...
})
```
You can also access the underlying TCP connection and set deadlines manually:
@ -675,22 +640,15 @@ r.Resource("/foo", huma.GinContextDependency()).Get(func (c *gin.Context) string
### Request Body Size Limits
By default each operation has a 1 MiB reqeuest body size limit. This value is configurable at the resource and operation level.
By default each operation has a 1 MiB reqeuest body size limit.
When triggered, the server sends a 413 Request Entity Too Large as JSON with a message containing the maximum body size for this operation.
```go
r := huma.NewRouter("My API", "1.0.0")
// Resource-level limit set to 10 MiB
r.Resource("/foo").MaxBodyBytes(10 * 1024 * 1024).Get(...)
// Operation-level limit
r.Resource("/foo").Operation(http.MethodGet, &huma.Operation{
// ...
MaxBodyBytes: 10 * 1024 * 1024,
// ...
})
// Limit set to 10 MiB
r.Resource("/foo", MaxBodyBytes(10 * 1024 * 1024)).Get(...)
```
> :whale: Set to `-1` in order to disable the check, allowing for unlimited request body size for e.g. large streaming file uploads.
@ -702,8 +660,7 @@ Huma provides a Zap-based contextual structured logger built-in. You can access
```go
r.Resource("/test",
huma.LogDependency(),
huma.ResponseText(http.StatusOK, "Successful")).
Get("Logger test", func(log *zap.SugaredLogger) string {
).Get("Logger test", func(log *zap.SugaredLogger) string {
log.Info("I'm using the logger!")
return "Hello, world"
})
@ -818,10 +775,10 @@ You can access the root `cobra.Command` via `r.Root()` and add new custom comman
## Middleware
You can make use of any Gin-compatible middleware via the `Middleware()` router option.
You can make use of any Gin-compatible middleware via the `GinMiddleware()` router option.
```go
r := huma.NewRouter("My API", "1.0.0", huma.Middleware(gin.Logger()))
r := huma.NewRouter("My API", "1.0.0", huma.GinMiddleware(gin.Logger()))
```
## HTTP/2 Setup

View file

@ -26,19 +26,17 @@ func main() {
r := huma.NewRouter("Benchmark", "1.0.0", huma.WithGin(g))
d := &huma.Dependency{
Params: []*huma.Param{
huma.HeaderParam("authorization", "Auth header", ""),
},
Value: func(auth string) (string, error) {
d := huma.Dependency(
huma.HeaderParam("authorization", "Auth header", ""),
func(auth string) (string, error) {
return strings.Split(auth, " ")[0], nil
},
}
)
r.Resource("/items", d,
huma.PathParam("id", "The item's unique ID"),
huma.Header("x-authinfo", "..."),
huma.ResponseJSON(http.StatusOK, "Successful hello response", "x-authinfo"),
huma.ResponseHeader("x-authinfo", "..."),
huma.ResponseJSON(http.StatusOK, "Successful hello response", huma.Headers("x-authinfo")),
).Get("Huma benchmark test", func(authInfo string, id int) (string, *Item) {
return authInfo, &Item{
ID: id,

View file

@ -11,51 +11,94 @@ import (
// ErrDependencyInvalid is returned when registering a dependency fails.
var ErrDependencyInvalid = errors.New("dependency invalid")
// Dependency represents a handler function dependency and its associated
// OpenAPIDependency represents a handler function dependency and its associated
// inputs and outputs. Value can be either a struct pointer (global dependency)
// or a `func(dependencies, params) (headers, struct pointer, error)` style
// function.
type Dependency struct {
Dependencies []*Dependency
Params []*Param
ResponseHeaders []*ResponseHeader
Value interface{}
type OpenAPIDependency struct {
dependencies []*OpenAPIDependency
params []*OpenAPIParam
responseHeaders []*OpenAPIResponseHeader
handler interface{}
}
var contextDependency Dependency
var ginContextDependency Dependency
var operationDependency Dependency
// Dependencies returns the dependencies associated with this dependency.
func (d *OpenAPIDependency) Dependencies() []*OpenAPIDependency {
return d.dependencies
}
// Params returns the params associated with this dependency.
func (d *OpenAPIDependency) Params() []*OpenAPIParam {
return d.params
}
// ResponseHeaders returns the params associated with this dependency.
func (d *OpenAPIDependency) ResponseHeaders() []*OpenAPIResponseHeader {
return d.responseHeaders
}
// NewSimpleDependency returns a dependency with a function or value.
func NewSimpleDependency(value interface{}) *OpenAPIDependency {
return NewDependency(nil, value)
}
// NewDependency returns a dependency with the given option and a handler
// function.
func NewDependency(option DependencyOption, handler interface{}) *OpenAPIDependency {
d := &OpenAPIDependency{
dependencies: make([]*OpenAPIDependency, 0),
params: make([]*OpenAPIParam, 0),
responseHeaders: make([]*OpenAPIResponseHeader, 0),
handler: handler,
}
if option != nil {
option.ApplyDependency(d)
}
return d
}
var contextDependency OpenAPIDependency
var ginContextDependency OpenAPIDependency
var operationDependency OpenAPIDependency
// ContextDependency returns a dependency for the current request's
// `context.Context`. This is useful for timeouts & cancellation.
func ContextDependency() *Dependency {
return &contextDependency
func ContextDependency() DependencyOption {
return &dependencyOption{func(d *OpenAPIDependency) {
d.dependencies = append(d.dependencies, &contextDependency)
}}
}
// GinContextDependency returns a dependency for the current request's
// `*gin.Context`.
func GinContextDependency() *Dependency {
return &ginContextDependency
func GinContextDependency() DependencyOption {
return &dependencyOption{func(d *OpenAPIDependency) {
d.dependencies = append(d.dependencies, &ginContextDependency)
}}
}
// OperationDependency returns a dependency for the current `*huma.Operation`.
func OperationDependency() *Dependency {
return &operationDependency
func OperationDependency() DependencyOption {
return &dependencyOption{func(d *OpenAPIDependency) {
d.dependencies = append(d.dependencies, &operationDependency)
}}
}
// validate that the dependency deps/params/headers match the function
// signature or that the value is not a function.
func (d *Dependency) validate(returnType reflect.Type) {
func (d *OpenAPIDependency) validate(returnType reflect.Type) {
if d == &contextDependency || d == &ginContextDependency || d == &operationDependency {
// Hard-coded known dependencies. These are special and have no value.
return
}
if d.Value == nil {
panic(fmt.Errorf("value must be set: %w", ErrDependencyInvalid))
if d.handler == nil {
panic(fmt.Errorf("handler must be set: %w", ErrDependencyInvalid))
}
v := reflect.ValueOf(d.Value)
v := reflect.ValueOf(d.handler)
if v.Kind() != reflect.Func {
if returnType != nil && returnType != v.Type() {
@ -63,11 +106,11 @@ func (d *Dependency) validate(returnType reflect.Type) {
}
// This is just a static value. It shouldn't have params/headers/etc.
if len(d.Params) > 0 {
if len(d.params) > 0 {
panic(fmt.Errorf("global dependency should not have params: %w", ErrDependencyInvalid))
}
if len(d.ResponseHeaders) > 0 {
if len(d.responseHeaders) > 0 {
panic(fmt.Errorf("global dependency should not set headers: %w", ErrDependencyInvalid))
}
@ -75,43 +118,43 @@ func (d *Dependency) validate(returnType reflect.Type) {
}
fn := v.Type()
lenArgs := len(d.Dependencies) + len(d.Params)
lenArgs := len(d.dependencies) + len(d.params)
if fn.NumIn() != lenArgs {
// TODO: generate suggested func signature
panic(fmt.Errorf("function signature should have %d args but got %s: %w", lenArgs, fn, ErrDependencyInvalid))
}
for _, dep := range d.Dependencies {
for _, dep := range d.dependencies {
dep.validate(nil)
}
for i, p := range d.Params {
p.validate(fn.In(len(d.Dependencies) + i))
for i, p := range d.params {
p.validate(fn.In(len(d.dependencies) + i))
}
lenReturn := len(d.ResponseHeaders) + 2
lenReturn := len(d.responseHeaders) + 2
if fn.NumOut() != lenReturn {
panic(fmt.Errorf("function should return %d values but got %d: %w", lenReturn, fn.NumOut(), ErrDependencyInvalid))
}
for i, h := range d.ResponseHeaders {
for i, h := range d.responseHeaders {
h.validate(fn.Out(i))
}
}
// AllParams returns all parameters for all dependencies in the graph of this
// allParams returns all parameters for all dependencies in the graph of this
// dependency in depth-first order without duplicates.
func (d *Dependency) AllParams() []*Param {
params := []*Param{}
seen := map[*Param]bool{}
func (d *OpenAPIDependency) allParams() []*OpenAPIParam {
params := []*OpenAPIParam{}
seen := map[*OpenAPIParam]bool{}
for _, p := range d.Params {
for _, p := range d.params {
seen[p] = true
params = append(params, p)
}
for _, d := range d.Dependencies {
for _, p := range d.AllParams() {
for _, d := range d.dependencies {
for _, p := range d.allParams() {
if _, ok := seen[p]; !ok {
seen[p] = true
@ -123,19 +166,19 @@ func (d *Dependency) AllParams() []*Param {
return params
}
// AllResponseHeaders returns all response headers for all dependencies in
// allResponseHeaders returns all response headers for all dependencies in
// the graph of this dependency in depth-first order without duplicates.
func (d *Dependency) AllResponseHeaders() []*ResponseHeader {
headers := []*ResponseHeader{}
seen := map[*ResponseHeader]bool{}
func (d *OpenAPIDependency) allResponseHeaders() []*OpenAPIResponseHeader {
headers := []*OpenAPIResponseHeader{}
seen := map[*OpenAPIResponseHeader]bool{}
for _, h := range d.ResponseHeaders {
for _, h := range d.responseHeaders {
seen[h] = true
headers = append(headers, h)
}
for _, d := range d.Dependencies {
for _, h := range d.AllResponseHeaders() {
for _, d := range d.dependencies {
for _, h := range d.allResponseHeaders() {
if _, ok := seen[h]; !ok {
seen[h] = true
@ -147,8 +190,8 @@ func (d *Dependency) AllResponseHeaders() []*ResponseHeader {
return headers
}
// Resolve the value of the dependency. Returns (response headers, value, error).
func (d *Dependency) Resolve(c *gin.Context, op *Operation) (map[string]string, interface{}, error) {
// resolve the value of the dependency. Returns (response headers, value, error).
func (d *OpenAPIDependency) resolve(c *gin.Context, op *OpenAPIOperation) (map[string]string, interface{}, error) {
// Identity dependencies are first. Just return if it's one of them.
if d == &contextDependency {
return nil, c.Request.Context(), nil
@ -162,10 +205,10 @@ func (d *Dependency) Resolve(c *gin.Context, op *Operation) (map[string]string,
return nil, op, nil
}
v := reflect.ValueOf(d.Value)
v := reflect.ValueOf(d.handler)
if v.Kind() != reflect.Func {
// Not a function, just return the global value.
return nil, d.Value, nil
return nil, d.handler, nil
}
// Generate the input arguments
@ -173,8 +216,8 @@ func (d *Dependency) Resolve(c *gin.Context, op *Operation) (map[string]string,
headers := map[string]string{}
// Resolve each sub-dependency
for _, dep := range d.Dependencies {
dHeaders, dVal, err := dep.Resolve(c, op)
for _, dep := range d.dependencies {
dHeaders, dVal, err := dep.resolve(c, op)
if err != nil {
return nil, nil, err
}
@ -187,7 +230,7 @@ func (d *Dependency) Resolve(c *gin.Context, op *Operation) (map[string]string,
}
// Get each input parameter
for _, param := range d.Params {
for _, param := range d.params {
v, ok := getParamValue(c, param)
if !ok {
return nil, nil, fmt.Errorf("could not get param value")
@ -205,9 +248,9 @@ func (d *Dependency) Resolve(c *gin.Context, op *Operation) (map[string]string,
}
// Get the headers & response value.
for i, h := range d.ResponseHeaders {
for i, h := range d.responseHeaders {
headers[h.Name] = out[i].Interface().(string)
}
return headers, out[len(d.ResponseHeaders)].Interface(), nil
return headers, out[len(d.responseHeaders)].Interface(), nil
}

View file

@ -12,7 +12,7 @@ import (
)
func TestGlobalDepEmpty(t *testing.T) {
d := Dependency{}
d := OpenAPIDependency{}
typ := reflect.TypeOf(123)
@ -22,8 +22,8 @@ func TestGlobalDepEmpty(t *testing.T) {
}
func TestGlobalDepWrongType(t *testing.T) {
d := Dependency{
Value: "test",
d := OpenAPIDependency{
handler: "test",
}
typ := reflect.TypeOf(123)
@ -34,13 +34,12 @@ func TestGlobalDepWrongType(t *testing.T) {
}
func TestGlobalDepParams(t *testing.T) {
d := Dependency{
Params: []*Param{
HeaderParam("foo", "description", "hello"),
},
Value: "test",
d := OpenAPIDependency{
handler: "test",
}
HeaderParam("foo", "description", "hello").ApplyDependency(&d)
typ := reflect.TypeOf("test")
assert.Panics(t, func() {
@ -49,11 +48,12 @@ func TestGlobalDepParams(t *testing.T) {
}
func TestGlobalDepHeaders(t *testing.T) {
d := Dependency{
ResponseHeaders: []*ResponseHeader{Header("foo", "description")},
Value: "test",
d := OpenAPIDependency{
handler: "test",
}
ResponseHeader("foo", "description").ApplyDependency(&d)
typ := reflect.TypeOf("test")
assert.Panics(t, func() {
@ -62,11 +62,11 @@ func TestGlobalDepHeaders(t *testing.T) {
}
func TestDepContext(t *testing.T) {
d := Dependency{
Dependencies: []*Dependency{
ContextDependency(),
d := OpenAPIDependency{
dependencies: []*OpenAPIDependency{
&contextDependency,
},
Value: func(ctx context.Context) (context.Context, error) { return ctx, nil },
handler: func(ctx context.Context) (context.Context, error) { return ctx, nil },
}
mock, _ := gin.CreateTestContext(nil)
@ -75,17 +75,17 @@ func TestDepContext(t *testing.T) {
typ := reflect.TypeOf(mock)
d.validate(typ)
_, v, err := d.Resolve(mock, &Operation{})
_, v, err := d.resolve(mock, &OpenAPIOperation{})
assert.NoError(t, err)
assert.Equal(t, v, mock.Request.Context())
}
func TestDepGinContext(t *testing.T) {
d := Dependency{
Dependencies: []*Dependency{
GinContextDependency(),
d := OpenAPIDependency{
dependencies: []*OpenAPIDependency{
&ginContextDependency,
},
Value: func(c *gin.Context) (*gin.Context, error) { return c, nil },
handler: func(c *gin.Context) (*gin.Context, error) { return c, nil },
}
mock, _ := gin.CreateTestContext(nil)
@ -93,56 +93,54 @@ func TestDepGinContext(t *testing.T) {
typ := reflect.TypeOf(mock)
d.validate(typ)
_, v, err := d.Resolve(mock, &Operation{})
_, v, err := d.resolve(mock, &OpenAPIOperation{})
assert.NoError(t, err)
assert.Equal(t, v, mock)
}
func TestDepOperation(t *testing.T) {
d := Dependency{
Dependencies: []*Dependency{
OperationDependency(),
d := OpenAPIDependency{
dependencies: []*OpenAPIDependency{
&operationDependency,
},
Value: func(o *Operation) (*Operation, error) { return o, nil },
handler: func(o *OpenAPIOperation) (*OpenAPIOperation, error) { return o, nil },
}
mock := &Operation{}
mock := &OpenAPIOperation{}
typ := reflect.TypeOf(mock)
d.validate(typ)
_, v, err := d.Resolve(&gin.Context{}, mock)
_, v, err := d.resolve(&gin.Context{}, mock)
assert.NoError(t, err)
assert.Equal(t, v, mock)
}
func TestDepFuncWrongArgs(t *testing.T) {
d := Dependency{
Params: []*Param{
HeaderParam("foo", "desc", ""),
},
Value: func() (string, error) {
d := OpenAPIDependency{
handler: func() (string, error) {
return "", nil
},
}
HeaderParam("foo", "desc", "").ApplyDependency(&d)
assert.Panics(t, func() {
d.validate(reflect.TypeOf(""))
})
}
func TestDepFunc(t *testing.T) {
d := Dependency{
Params: []*Param{
HeaderParam("x-in", "desc", ""),
},
ResponseHeaders: []*ResponseHeader{
Header("x-out", "desc"),
},
Value: func(xin string) (string, string, error) {
d := OpenAPIDependency{
handler: func(xin string) (string, string, error) {
return "xout", "value", nil
},
}
DependencyOptions(
HeaderParam("x-in", "desc", ""),
ResponseHeader("x-out", "desc"),
).ApplyDependency(&d)
c := &gin.Context{
Request: &http.Request{
Header: http.Header{
@ -152,7 +150,7 @@ func TestDepFunc(t *testing.T) {
}
d.validate(reflect.TypeOf(""))
h, v, err := d.Resolve(c, &Operation{})
h, v, err := d.resolve(c, &OpenAPIOperation{})
assert.NoError(t, err)
assert.Equal(t, "xout", h["x-out"])
assert.Equal(t, "value", v)

View file

@ -6,6 +6,7 @@ import (
"time"
"github.com/danielgtaylor/huma"
"github.com/danielgtaylor/huma/schema"
)
// NoteSummary is used to list notes. It does not include the (potentially)
@ -46,9 +47,11 @@ func main() {
})
// Add an `id` path parameter to create a note resource.
note := notes.With(huma.PathParam("id", "Note ID", &huma.Schema{
Pattern: "^[a-zA-Z0-9._-]{1,32}$",
}))
note := notes.With(huma.PathParam("id", "Note ID",
huma.Schema(&schema.Schema{
Pattern: "^[a-zA-Z0-9._-]{1,32}$",
}),
))
notFound := huma.ResponseError(http.StatusNotFound, "Note not found")

View file

@ -24,18 +24,21 @@ var logLevel *zap.AtomicLevel
// panic when using the recovery middleware. Defaults to 10KB.
var MaxLogBodyBytes int64 = 10 * 1024
// BufferedReadCloser will read and buffer up to max bytes into buf. Additional
// Middleware TODO ...
type Middleware = gin.HandlerFunc
// bufferedReadCloser will read and buffer up to max bytes into buf. Additional
// reads bypass the buffer.
type BufferedReadCloser struct {
type bufferedReadCloser struct {
reader io.ReadCloser
buf *bytes.Buffer
max int64
}
// NewBufferedReadCloser returns a new BufferedReadCloser that wraps reader
// newBufferedReadCloser returns a new BufferedReadCloser that wraps reader
// and reads up to max bytes into the buffer.
func NewBufferedReadCloser(reader io.ReadCloser, buffer *bytes.Buffer, max int64) *BufferedReadCloser {
return &BufferedReadCloser{
func newBufferedReadCloser(reader io.ReadCloser, buffer *bytes.Buffer, max int64) *bufferedReadCloser {
return &bufferedReadCloser{
reader: reader,
buf: buffer,
max: max,
@ -43,7 +46,7 @@ func NewBufferedReadCloser(reader io.ReadCloser, buffer *bytes.Buffer, max int64
}
// Read data into p. Returns number of bytes read and an error, if any.
func (r *BufferedReadCloser) Read(p []byte) (n int, err error) {
func (r *bufferedReadCloser) Read(p []byte) (n int, err error) {
// Read from the underlying reader like normal.
n, err = r.reader.Read(p)
@ -61,12 +64,12 @@ func (r *BufferedReadCloser) Read(p []byte) (n int, err error) {
}
// Close the underlying reader.
func (r *BufferedReadCloser) Close() error {
func (r *bufferedReadCloser) Close() error {
return r.reader.Close()
}
// Recovery prints stack traces on panic when used with the logging middleware.
func Recovery() func(*gin.Context) {
func Recovery() Middleware {
bufPool := sync.Pool{
New: func() interface{} {
return new(bytes.Buffer)
@ -82,7 +85,7 @@ func Recovery() func(*gin.Context) {
buf = bufPool.Get().(*bytes.Buffer)
defer bufPool.Put(buf)
c.Request.Body = NewBufferedReadCloser(c.Request.Body, buf, MaxLogBodyBytes)
c.Request.Body = newBufferedReadCloser(c.Request.Body, buf, MaxLogBodyBytes)
}
// Recovering comes *after* the above so the buffer is not returned to
@ -137,7 +140,7 @@ func NewLogger() (*zap.Logger, error) {
// Gin context under the `log` key. It debug logs request info. If passed `nil`
// for the logger, then it creates one. If the current terminal is a TTY, it
// will try to use colored output automatically.
func LogMiddleware(l *zap.Logger, tags map[string]string) func(*gin.Context) {
func LogMiddleware(l *zap.Logger, tags map[string]string) Middleware {
var err error
if l == nil {
if l, err = NewLogger(); err != nil {
@ -180,19 +183,22 @@ func LogMiddleware(l *zap.Logger, tags map[string]string) func(*gin.Context) {
// LogDependency returns a dependency that resolves to a `*zap.SugaredLogger`
// for the current request. This dependency *requires* the use of
// `LogMiddleware` and will error if the logger is not in the request context.
func LogDependency() *Dependency {
return &Dependency{
Dependencies: []*Dependency{ContextDependency(), OperationDependency()},
Value: func(c *gin.Context, op *Operation) (*zap.SugaredLogger, error) {
l, ok := c.Get("log")
if !ok {
return nil, fmt.Errorf("missing logger in context")
}
sl := l.(*zap.SugaredLogger).With("operation", op.ID)
sl.Desugar()
return sl, nil
},
}
func LogDependency() DependencyOption {
dep := NewDependency(DependencyOptions(
GinContextDependency(),
OperationDependency(),
), func(c *gin.Context, op *OpenAPIOperation) (*zap.SugaredLogger, error) {
l, ok := c.Get("log")
if !ok {
return nil, fmt.Errorf("missing logger in context")
}
sl := l.(*zap.SugaredLogger).With("operation", op.id)
return sl, nil
})
return &dependencyOption{func(d *OpenAPIDependency) {
d.dependencies = append(d.dependencies, dep)
}}
}
// Handler404 will return JSON responses for 404 errors.
@ -226,7 +232,7 @@ func (w *minimalWriter) WriteHeader(statusCode int) {
// PreferMinimalMiddleware will remove the response body and return 204 No
// Content for any 2xx response where the request had the Prefer: return=minimal
// set on the request.
func PreferMinimalMiddleware() func(*gin.Context) {
func PreferMinimalMiddleware() Middleware {
return func(c *gin.Context) {
// Wrap the response writer
if c.GetHeader("prefer") == "return=minimal" {

View file

@ -15,14 +15,8 @@ func TestRecoveryMiddleware(t *testing.T) {
r := NewTestRouter(t)
r.GinEngine().Use(Recovery())
r.Register(http.MethodGet, "/panic", &Operation{
Description: "Panic recovery test",
Responses: []*Response{
ResponseText(http.StatusOK, "Success"),
},
Handler: func() string {
panic(fmt.Errorf("Some error"))
},
r.Resource("/panic").Get("Panic recovery test", func() string {
panic(fmt.Errorf("Some error"))
})
w := httptest.NewRecorder()
@ -36,14 +30,8 @@ func TestRecoveryMiddlewareLogBody(t *testing.T) {
r := NewTestRouter(t)
r.GinEngine().Use(Recovery())
r.Register(http.MethodPut, "/panic", &Operation{
Description: "Panic recovery test",
Responses: []*Response{
ResponseText(http.StatusOK, "Success"),
},
Handler: func(in map[string]string) string {
panic(fmt.Errorf("Some error"))
},
r.Resource("/panic").Put("Panic recovery test", func(in map[string]string) string {
panic(fmt.Errorf("Some error"))
})
w := httptest.NewRecorder()

View file

@ -8,6 +8,7 @@ import (
"time"
"github.com/Jeffail/gabs"
"github.com/danielgtaylor/huma/schema"
"github.com/gin-gonic/gin"
)
@ -21,216 +22,182 @@ const (
InHeader ParamLocation = "header"
)
// Param describes an OpenAPI 3 parameter
type Param struct {
Name string `json:"name"`
Description string `json:"description,omitempty"`
In ParamLocation `json:"in"`
Required bool `json:"required,omitempty"`
Schema *Schema `json:"schema,omitempty"`
Deprecated bool `json:"deprecated,omitempty"`
Example interface{} `json:"example,omitempty"`
// OpenAPIParam describes an OpenAPI 3 parameter
type OpenAPIParam struct {
Name string `json:"name"`
Description string `json:"description,omitempty"`
In ParamLocation `json:"in"`
Required bool `json:"required,omitempty"`
Schema *schema.Schema `json:"schema,omitempty"`
Deprecated bool `json:"deprecated,omitempty"`
Example interface{} `json:"example,omitempty"`
// Internal params are excluded from the OpenAPI document and can set up
// params sent between a load balander / proxy and the service internally.
internal bool
def interface{}
typ reflect.Type
Internal bool
def interface{}
typ reflect.Type
}
// PathParam returns a new required path parameter
func PathParam(name string, description string, schema ...*Schema) *Param {
return PathParamExample(name, description, nil, schema...)
}
// PathParamExample returns a new required path parameter with example
func PathParamExample(name string, description string, example interface{}, schema ...*Schema) *Param {
p := &Param{
// NewOpenAPIParam returns a new parameter instance.
func NewOpenAPIParam(name, description string, in ParamLocation, options ...ParamOption) *OpenAPIParam {
p := &OpenAPIParam{
Name: name,
Description: description,
In: InPath,
Required: true,
Example: example,
In: in,
}
if len(schema) > 0 {
p.Schema = schema[0]
for _, option := range options {
option.ApplyParam(p)
}
return p
}
// QueryParam returns a new optional query string parameter
func QueryParam(name string, description string, defaultValue interface{}, schema ...*Schema) *Param {
return QueryParamExample(name, description, defaultValue, nil, schema...)
}
// QueryParamExample returns a new optional query string parameter with example
func QueryParamExample(name string, description string, defaultValue interface{}, example interface{}, schema ...*Schema) *Param {
p := &Param{
Name: name,
Description: description,
In: InQuery,
Example: example,
def: defaultValue,
}
if len(schema) > 0 {
p.Schema = schema[0]
}
return p
}
// QueryParamInternal returns a new optional internal query string parameter
func QueryParamInternal(name string, description string, defaultValue interface{}) *Param {
return &Param{
Name: name,
Description: description,
In: InQuery,
internal: true,
def: defaultValue,
}
}
// HeaderParam returns a new optional header parameter
func HeaderParam(name string, description string, defaultValue interface{}, schema ...*Schema) *Param {
return HeaderParamExample(name, description, defaultValue, nil, schema...)
}
// HeaderParamExample returns a new optional header parameter with example
func HeaderParamExample(name string, description string, defaultValue interface{}, example interface{}, schema ...*Schema) *Param {
p := &Param{
Name: name,
Description: description,
In: InHeader,
Example: example,
def: defaultValue,
}
if len(schema) > 0 {
p.Schema = schema[0]
}
return p
}
// HeaderParamInternal returns a new optional internal header parameter
func HeaderParamInternal(name string, description string, defaultValue interface{}) *Param {
return &Param{
Name: name,
Description: description,
In: InHeader,
internal: true,
def: defaultValue,
}
}
// Response describes an OpenAPI 3 response
type Response struct {
// OpenAPIResponse describes an OpenAPI 3 response
type OpenAPIResponse struct {
Description string
ContentType string
StatusCode int
Schema *Schema
Schema *schema.Schema
Headers []string
empty bool
}
// ResponseEmpty creates a new response with no content type.
func ResponseEmpty(statusCode int, description string, headers ...string) *Response {
return &Response{
Description: description,
// NewOpenAPIResponse returns a new response instance.
func NewOpenAPIResponse(statusCode int, description string, options ...ResponseOption) *OpenAPIResponse {
r := &OpenAPIResponse{
StatusCode: statusCode,
Headers: headers,
empty: true,
}
}
// ResponseText creates a new string response model.
func ResponseText(statusCode int, description string, headers ...string) *Response {
return &Response{
Description: description,
ContentType: "text/plain",
StatusCode: statusCode,
Headers: headers,
}
}
// ResponseJSON creates a new JSON response model.
func ResponseJSON(statusCode int, description string, headers ...string) *Response {
return &Response{
Description: description,
ContentType: "application/json",
StatusCode: statusCode,
Headers: headers,
}
}
// ResponseError creates a new error response model. Alias for ResponseJSON
// mainly useful for documentation.
func ResponseError(status int, description string, headers ...string) *Response {
return ResponseJSON(status, description, headers...)
}
// ResponseHeader describes a response header
type ResponseHeader struct {
Name string `json:"-"`
Description string `json:"description,omitempty"`
Schema *Schema `json:"schema,omitempty"`
}
// Header returns a new header
func Header(name, description string) *ResponseHeader {
return &ResponseHeader{
Name: name,
Description: description,
}
for _, option := range options {
option.ApplyResponse(r)
}
return r
}
// SecurityRequirement defines the security schemes and scopes required to use
// OpenAPIResponseHeader describes a response header
type OpenAPIResponseHeader struct {
Name string `json:"-"`
Description string `json:"description,omitempty"`
Schema *schema.Schema `json:"schema,omitempty"`
}
// OpenAPISecurityRequirement defines the security schemes and scopes required to use
// an operation.
type SecurityRequirement map[string][]string
type OpenAPISecurityRequirement map[string][]string
// Operation describes an OpenAPI 3 operation on a path
type Operation struct {
ID string
Summary string
Description string
Tags []string
Security []SecurityRequirement
Dependencies []*Dependency
Params []*Param
RequestContentType string
RequestSchema *Schema
ResponseHeaders []*ResponseHeader
Responses []*Response
Handler interface{}
Extra map[string]interface{}
// OpenAPIOperation describes an OpenAPI 3 operation on a path
type OpenAPIOperation struct {
*OpenAPIDependency
id string
summary string
description string
tags []string
security []OpenAPISecurityRequirement
requestContentType string
requestSchema *schema.Schema
responses []*OpenAPIResponse
extra map[string]interface{}
// MaxBodyBytes limits the size of the request body that will be read before
// maxBodyBytes limits the size of the request body that will be read before
// an error is returned. Defaults to 1MiB if set to zero. Set to -1 for
// unlimited.
MaxBodyBytes int64
maxBodyBytes int64
// BodyReadTimeout sets the duration until reading the body is given up and
// bodyReadTimeout sets the duration until reading the body is given up and
// aborted with an error. Defaults to 15 seconds if the body is automatically
// read and parsed into a struct, otherwise unset. Set to -1 for unlimited.
BodyReadTimeout time.Duration
bodyReadTimeout time.Duration
}
// AllParams returns a list of all the parameters for this operation, including
// those for dependencies.
func (o *Operation) AllParams() []*Param {
params := []*Param{}
seen := map[*Param]bool{}
// ID returns the unique identifier for this operation. If not set manually,
// it is generated from the path and HTTP method.
func (o *OpenAPIOperation) ID() string {
return o.id
}
for _, p := range o.Params {
// NewOperation creates a new operation with the given options applied.
func NewOperation(options ...OperationOption) *OpenAPIOperation {
op := &OpenAPIOperation{
OpenAPIDependency: &OpenAPIDependency{
dependencies: make([]*OpenAPIDependency, 0),
params: make([]*OpenAPIParam, 0),
responseHeaders: make([]*OpenAPIResponseHeader, 0),
},
tags: make([]string, 0),
security: make([]OpenAPISecurityRequirement, 0),
responses: make([]*OpenAPIResponse, 0),
extra: make(map[string]interface{}),
}
for _, option := range options {
option.ApplyOperation(op)
}
return op
}
// Copy creates a new shallow copy of the operation. New arrays are created for
// e.g. parameters so they can be safely appended. Existing params are not
// deeply copied and should not be modified.
func (o *OpenAPIOperation) Copy() *OpenAPIOperation {
extraCopy := map[string]interface{}{}
for k, v := range o.extra {
extraCopy[k] = v
}
newOp := &OpenAPIOperation{
OpenAPIDependency: &OpenAPIDependency{
dependencies: append([]*OpenAPIDependency{}, o.dependencies...),
params: append([]*OpenAPIParam{}, o.params...),
responseHeaders: append([]*OpenAPIResponseHeader{}, o.responseHeaders...),
handler: o.handler,
},
id: o.id,
summary: o.summary,
description: o.description,
tags: append([]string{}, o.tags...),
security: append([]OpenAPISecurityRequirement{}, o.security...),
requestContentType: o.requestContentType,
requestSchema: o.requestSchema,
responses: append([]*OpenAPIResponse{}, o.responses...),
extra: extraCopy,
maxBodyBytes: o.maxBodyBytes,
bodyReadTimeout: o.bodyReadTimeout,
}
return newOp
}
// With applies options to the operation. It makes it easy to set up new params,
// responese headers, responses, etc. It always creates a new copy.
func (o *OpenAPIOperation) With(options ...OperationOption) *OpenAPIOperation {
copy := o.Copy()
for _, option := range options {
option.ApplyOperation(copy)
}
return copy
}
// allParams returns a list of all the parameters for this operation, including
// those for dependencies.
func (o *OpenAPIOperation) allParams() []*OpenAPIParam {
params := []*OpenAPIParam{}
seen := map[*OpenAPIParam]bool{}
for _, p := range o.params {
seen[p] = true
params = append(params, p)
}
for _, d := range o.Dependencies {
for _, p := range d.AllParams() {
for _, d := range o.dependencies {
for _, p := range d.allParams() {
if _, ok := seen[p]; !ok {
seen[p] = true
@ -242,19 +209,19 @@ func (o *Operation) AllParams() []*Param {
return params
}
// AllResponseHeaders returns a list of all the parameters for this operation,
// allResponseHeaders returns a list of all the parameters for this operation,
// including those for dependencies.
func (o *Operation) AllResponseHeaders() []*ResponseHeader {
headers := []*ResponseHeader{}
seen := map[*ResponseHeader]bool{}
func (o *OpenAPIOperation) allResponseHeaders() []*OpenAPIResponseHeader {
headers := []*OpenAPIResponseHeader{}
seen := map[*OpenAPIResponseHeader]bool{}
for _, h := range o.ResponseHeaders {
for _, h := range o.responseHeaders {
seen[h] = true
headers = append(headers, h)
}
for _, d := range o.Dependencies {
for _, h := range d.AllResponseHeaders() {
for _, d := range o.dependencies {
for _, h := range d.allResponseHeaders() {
if _, ok := seen[h]; !ok {
seen[h] = true
@ -266,57 +233,45 @@ func (o *Operation) AllResponseHeaders() []*ResponseHeader {
return headers
}
// Server describes an OpenAPI 3 API server location
type Server struct {
// OpenAPIServer describes an OpenAPI 3 API server location
type OpenAPIServer struct {
URL string `json:"url"`
Description string `json:"description,omitempty"`
}
// Contact information for this API.
type Contact struct {
// OpenAPIContact information for this API.
type OpenAPIContact struct {
Name string `json:"name"`
URL string `json:"url"`
Email string `json:"email"`
}
// OAuthFlow describes the URLs and scopes to get tokens via a specific flow.
type OAuthFlow struct {
// OpenAPIOAuthFlow describes the URLs and scopes to get tokens via a specific flow.
type OpenAPIOAuthFlow struct {
AuthorizationURL string `json:"authorizationUrl"`
TokenURL string `json:"tokenUrl"`
RefreshURL string `json:"refreshUrl,omitempty"`
Scopes map[string]string `json:"scopes"`
}
// OAuthFlows describes the configuration for each flow type.
type OAuthFlows struct {
Implicit *OAuthFlow `json:"implicit,omitempty"`
Password *OAuthFlow `json:"password,omitempty"`
ClientCredentials *OAuthFlow `json:"clientCredentials,omitempty"`
AuthorizationCode *OAuthFlow `json:"authorizationCode,omitempty"`
// OpenAPIOAuthFlows describes the configuration for each flow type.
type OpenAPIOAuthFlows struct {
Implicit *OpenAPIOAuthFlow `json:"implicit,omitempty"`
Password *OpenAPIOAuthFlow `json:"password,omitempty"`
ClientCredentials *OpenAPIOAuthFlow `json:"clientCredentials,omitempty"`
AuthorizationCode *OpenAPIOAuthFlow `json:"authorizationCode,omitempty"`
}
// SecurityScheme describes the auth mechanism(s) for this API.
type SecurityScheme struct {
Type string `json:"type"`
Description string `json:"description,omitempty"`
Name string `json:"name,omitempty"`
In string `json:"in,omitempty"`
Scheme string `json:"scheme,omitempty"`
BearerFormat string `json:"bearerFormat,omitempty"`
Flows *OAuthFlows `json:"flows,omitempty"`
OpenIDConnectURL string `json:"openIdConnectUrl,omitempty"`
}
// SecurityRef references a previously defined `SecurityScheme` by name along
// with any required scopes.
func SecurityRef(name string, scopes ...string) []SecurityRequirement {
if scopes == nil {
scopes = []string{}
}
return []SecurityRequirement{
{name: scopes},
}
// OpenAPISecurityScheme describes the auth mechanism(s) for this API.
type OpenAPISecurityScheme struct {
Type string `json:"type"`
Description string `json:"description,omitempty"`
Name string `json:"name,omitempty"`
In string `json:"in,omitempty"`
Scheme string `json:"scheme,omitempty"`
BearerFormat string `json:"bearerFormat,omitempty"`
Flows *OpenAPIOAuthFlows `json:"flows,omitempty"`
OpenIDConnectURL string `json:"openIdConnectUrl,omitempty"`
}
// OpenAPI describes the OpenAPI 3 API
@ -324,11 +279,11 @@ type OpenAPI struct {
Title string
Version string
Description string
Contact *Contact
Servers []*Server
SecuritySchemes map[string]*SecurityScheme
Security []SecurityRequirement
Paths map[string]map[string]*Operation
Contact *OpenAPIContact
Servers []*OpenAPIServer
SecuritySchemes map[string]*OpenAPISecurityScheme
Security []OpenAPISecurityRequirement
Paths map[string]map[string]*OpenAPIOperation
// Extra allows setting extra keys in the OpenAPI root structure.
Extra map[string]interface{}
@ -338,9 +293,9 @@ type OpenAPI struct {
Hook func(*gabs.Container)
}
// OpenAPIHandler returns a new handler function to generate an OpenAPI spec.
func OpenAPIHandler(api *OpenAPI) func(*gin.Context) {
respSchema400, _ := GenerateSchema(reflect.ValueOf(ErrorInvalidModel{}).Type())
// openAPIHandler returns a new handler function to generate an OpenAPI spec.
func openAPIHandler(api *OpenAPI) gin.HandlerFunc {
respSchema400, _ := schema.Generate(reflect.ValueOf(ErrorInvalidModel{}).Type())
return func(c *gin.Context) {
openapi := gabs.New()
@ -382,51 +337,51 @@ func OpenAPIHandler(api *OpenAPI) func(*gin.Context) {
for method, op := range methods {
method := strings.ToLower(method)
for k, v := range op.Extra {
for k, v := range op.extra {
openapi.Set(v, "paths", path, method, k)
}
openapi.Set(op.ID, "paths", path, method, "operationId")
if op.Summary != "" {
openapi.Set(op.Summary, "paths", path, method, "summary")
openapi.Set(op.id, "paths", path, method, "operationId")
if op.summary != "" {
openapi.Set(op.summary, "paths", path, method, "summary")
}
openapi.Set(op.Description, "paths", path, method, "description")
if len(op.Tags) > 0 {
openapi.Set(op.Tags, "paths", path, method, "tags")
openapi.Set(op.description, "paths", path, method, "description")
if len(op.tags) > 0 {
openapi.Set(op.tags, "paths", path, method, "tags")
}
if len(op.Security) > 0 {
openapi.Set(op.Security, "paths", path, method, "security")
if len(op.security) > 0 {
openapi.Set(op.security, "paths", path, method, "security")
}
for _, param := range op.AllParams() {
if param.internal {
for _, param := range op.allParams() {
if param.Internal {
// Skip internal-only parameters.
continue
}
openapi.ArrayAppend(param, "paths", path, method, "parameters")
}
if op.RequestSchema != nil {
ct := op.RequestContentType
if op.requestSchema != nil {
ct := op.requestContentType
if ct == "" {
ct = "application/json"
}
openapi.Set(op.RequestSchema, "paths", path, method, "requestBody", "content", ct, "schema")
openapi.Set(op.requestSchema, "paths", path, method, "requestBody", "content", ct, "schema")
}
responses := make([]*Response, 0, len(op.Responses))
responses := make([]*OpenAPIResponse, 0, len(op.responses))
found400 := false
for _, resp := range op.Responses {
for _, resp := range op.responses {
responses = append(responses, resp)
if resp.StatusCode == http.StatusBadRequest {
found400 = true
}
}
if op.RequestSchema != nil && !found400 {
if op.requestSchema != nil && !found400 {
// Add a 400-level response in case parsing the request fails.
responses = append(responses, &Response{
responses = append(responses, &OpenAPIResponse{
Description: "Invalid input",
ContentType: "application/json",
StatusCode: http.StatusBadRequest,
@ -434,12 +389,12 @@ func OpenAPIHandler(api *OpenAPI) func(*gin.Context) {
})
}
headerMap := map[string]*ResponseHeader{}
for _, header := range op.AllResponseHeaders() {
headerMap := map[string]*OpenAPIResponseHeader{}
for _, header := range op.allResponseHeaders() {
headerMap[header.Name] = header
}
for _, resp := range op.Responses {
for _, resp := range op.responses {
status := fmt.Sprintf("%v", resp.StatusCode)
openapi.Set(resp.Description, "paths", path, method, "responses", status, "description")
@ -449,8 +404,8 @@ func OpenAPIHandler(api *OpenAPI) func(*gin.Context) {
headers = append(headers, name)
seen[name] = true
}
for _, dep := range op.Dependencies {
for _, header := range dep.AllResponseHeaders() {
for _, dep := range op.dependencies {
for _, header := range dep.allResponseHeaders() {
if _, ok := seen[header.Name]; !ok {
headers = append(headers, header.Name)
seen[header.Name] = true

View file

@ -6,13 +6,14 @@ import (
"net/http/httptest"
"testing"
"github.com/danielgtaylor/huma/schema"
"github.com/getkin/kin-openapi/openapi3"
"github.com/stretchr/testify/assert"
)
var paramFuncsTable = []struct {
n string
param *Param
param OperationOption
name string
description string
in ParamLocation
@ -22,28 +23,30 @@ var paramFuncsTable = []struct {
example interface{}
}{
{"PathParam", PathParam("test", "desc"), "test", "desc", InPath, true, false, nil, nil},
{"PathParamSchema", PathParam("test", "desc", &Schema{}), "test", "desc", InPath, true, false, nil, nil},
{"PathParamExample", PathParamExample("test", "desc", 123), "test", "desc", InPath, true, false, nil, 123},
{"PathParamSchema", PathParam("test", "desc", Schema(&schema.Schema{})), "test", "desc", InPath, true, false, nil, nil},
{"PathParamExample", PathParam("test", "desc", Example(123)), "test", "desc", InPath, true, false, nil, 123},
{"QueryParam", QueryParam("test", "desc", "def"), "test", "desc", InQuery, false, false, "def", nil},
{"QueryParamSchema", QueryParam("test", "desc", "def", &Schema{}), "test", "desc", InQuery, false, false, "def", nil},
{"QueryParamExample", QueryParamExample("test", "desc", "def", "foo"), "test", "desc", InQuery, false, false, "def", "foo"},
{"QueryParamInternal", QueryParamInternal("test", "desc", "def"), "test", "desc", InQuery, false, true, "def", nil},
{"QueryParamSchema", QueryParam("test", "desc", "def", Schema(&schema.Schema{})), "test", "desc", InQuery, false, false, "def", nil},
{"QueryParamExample", QueryParam("test", "desc", "def", Example("foo")), "test", "desc", InQuery, false, false, "def", "foo"},
{"QueryParamInternal", QueryParam("test", "desc", "def", Internal()), "test", "desc", InQuery, false, true, "def", nil},
{"HeaderParam", HeaderParam("test", "desc", "def"), "test", "desc", InHeader, false, false, "def", nil},
{"HeaderParamSchema", HeaderParam("test", "desc", "def", &Schema{}), "test", "desc", InHeader, false, false, "def", nil},
{"HeaderParamExample", HeaderParamExample("test", "desc", "def", "foo"), "test", "desc", InHeader, false, false, "def", "foo"},
{"HeaderParamInternal", HeaderParamInternal("test", "desc", "def"), "test", "desc", InHeader, false, true, "def", nil},
{"HeaderParamSchema", HeaderParam("test", "desc", "def", Schema(&schema.Schema{})), "test", "desc", InHeader, false, false, "def", nil},
{"HeaderParamExample", HeaderParam("test", "desc", "def", Example("foo")), "test", "desc", InHeader, false, false, "def", "foo"},
{"HeaderParamInternal", HeaderParam("test", "desc", "def", Internal()), "test", "desc", InHeader, false, true, "def", nil},
}
func TestParamFuncs(outer *testing.T) {
for _, tt := range paramFuncsTable {
local := tt
outer.Run(fmt.Sprintf("%v", tt.n), func(t *testing.T) {
param := local.param
op := NewOperation()
local.param.ApplyOperation(op)
param := op.params[0]
assert.Equal(t, local.name, param.Name)
assert.Equal(t, local.description, param.Description)
assert.Equal(t, local.in, param.In)
assert.Equal(t, local.required, param.Required)
assert.Equal(t, local.internal, param.internal)
assert.Equal(t, local.internal, param.Internal)
assert.Equal(t, local.def, param.def)
assert.Equal(t, local.example, param.Example)
})
@ -52,23 +55,25 @@ func TestParamFuncs(outer *testing.T) {
var responseFuncsTable = []struct {
n string
resp *Response
resp OperationOption
statusCode int
description string
headers []string
contentType string
}{
{"ResponseEmpty", ResponseEmpty(204, "desc", "head1", "head2"), 204, "desc", []string{"head1", "head2"}, ""},
{"ResponseText", ResponseText(200, "desc", "head1", "head2"), 200, "desc", []string{"head1", "head2"}, "application/json"},
{"ResponseJSON", ResponseJSON(200, "desc", "head1", "head2"), 200, "desc", []string{"head1", "head2"}, "application/json"},
{"ResponseError", ResponseJSON(200, "desc", "head1", "head2"), 200, "desc", []string{"head1", "head2"}, "application/json"},
{"ResponseEmpty", Response(204, "desc", Headers("head1", "head2")), 204, "desc", []string{"head1", "head2"}, ""},
{"ResponseText", ResponseText(200, "desc", Headers("head1", "head2")), 200, "desc", []string{"head1", "head2"}, "application/json"},
{"ResponseJSON", ResponseJSON(200, "desc", Headers("head1", "head2")), 200, "desc", []string{"head1", "head2"}, "application/json"},
{"ResponseError", ResponseJSON(200, "desc", Headers("head1", "head2")), 200, "desc", []string{"head1", "head2"}, "application/json"},
}
func TestResponseFuncs(outer *testing.T) {
for _, tt := range responseFuncsTable {
local := tt
outer.Run(fmt.Sprintf("%v", tt.n), func(t *testing.T) {
resp := local.resp
op := NewOperation()
local.resp.ApplyOperation(op)
resp := op.responses[0]
assert.Equal(t, local.statusCode, resp.StatusCode)
assert.Equal(t, local.description, resp.Description)
assert.Equal(t, local.headers, resp.Headers)
@ -141,52 +146,29 @@ func TestOpenAPIHandler(t *testing.T) {
Extra("x-foo", "bar"),
)
dep1 := &Dependency{
Params: []*Param{
QueryParam("q", "Test query param", ""),
},
ResponseHeaders: []*ResponseHeader{
Header("dep", "description"),
},
Value: func(q string) (string, string, error) {
return "header", "foo", nil
},
}
dep1 := Dependency(DependencyOptions(
QueryParam("q", "Test query param", ""),
ResponseHeader("dep", "description"),
), func(q string) (string, string, error) {
return "header", "foo", nil
})
dep2 := &Dependency{
Dependencies: []*Dependency{dep1},
Value: func(q string) (string, error) {
return q, nil
},
}
dep2 := Dependency(dep1, func(q string) (string, error) {
return q, nil
})
r.Register(http.MethodPut, "/hello", &Operation{
ID: "put-hello",
Summary: "Summary message",
Description: "Get a welcome message",
Tags: []string{"Messages"},
Security: SecurityRef("basic"),
Dependencies: []*Dependency{
dep2,
},
Params: []*Param{
QueryParam("greet", "Whether to greet or not", false),
HeaderParamInternal("user", "User from auth token", ""),
},
ResponseHeaders: []*ResponseHeader{
Header("etag", "Content hash for caching"),
},
Responses: []*Response{
ResponseJSON(200, "Successful response", "etag"),
},
Extra: map[string]interface{}{
"x-foo": "bar",
},
Handler: func(q string, greet bool, user string, body *HelloRequest) (string, *HelloResponse) {
return "etag", &HelloResponse{
Message: "Hello",
}
},
r.Resource("/hello",
dep2,
SecurityRef("basic"),
QueryParam("greet", "Whether to greet or not", false),
HeaderParam("user", "User from auth token", "", Internal()),
ResponseHeader("etag", "Content hash for caching"),
ResponseJSON(200, "Successful response", Headers("etag")),
Extra("x-foo", "bar"),
).Put("Get a welcome message", func(q string, greet bool, user string, body *HelloRequest) (string, *HelloResponse) {
return "etag", &HelloResponse{
Message: "Hello",
}
})
w := httptest.NewRecorder()

View file

@ -1,9 +1,12 @@
package huma
import (
"fmt"
"net/http"
"time"
"github.com/Jeffail/gabs"
"github.com/danielgtaylor/huma/schema"
"github.com/gin-gonic/gin"
)
@ -12,47 +15,6 @@ type RouterOption interface {
ApplyRouter(r *Router)
}
// ResourceOption sets an option on the resource to be used in sub-resources
// and operations.
type ResourceOption interface {
ApplyResource(r *Resource)
}
// SharedOption sets an option on either a router/API or resource.
type SharedOption interface {
RouterOption
ResourceOption
}
type extraOption struct {
extra map[string]interface{}
}
func (o *extraOption) ApplyRouter(r *Router) {
for k, v := range o.extra {
r.api.Extra[k] = v
}
}
func (o *extraOption) ApplyResource(r *Resource) {
// for k, v := range o.extra {
// r.extra[k] = v
// }
}
// Extra sets extra values in the generated OpenAPI 3 spec.
func Extra(pairs ...interface{}) SharedOption {
extra := map[string]interface{}{}
for i := 0; i < len(pairs); i += 2 {
k := pairs[i].(string)
v := pairs[i+1]
extra[k] = v
}
return &extraOption{extra}
}
// routerOption is a shorthand struct used to create API options easily.
type routerOption struct {
handler func(*Router)
@ -62,45 +24,257 @@ func (o *routerOption) ApplyRouter(router *Router) {
o.handler(router)
}
// ResourceOption sets an option on the resource to be used in sub-resources
// and operations.
type ResourceOption interface {
ApplyResource(r *Resource)
}
// resourceOption is a shorthand struct used to create resource options easily.
type resourceOption struct {
handler func(*Resource)
}
func (o *resourceOption) ApplyResource(r *Resource) {
o.handler(r)
}
// OperationOption sets an option on an operation or resource object.
type OperationOption interface {
ResourceOption
ApplyOperation(o *OpenAPIOperation)
}
// operationOption is a shorthand struct used to create operation options
// easily. Options created with it can be applied to either operations or
// resources.
type operationOption struct {
handler func(*OpenAPIOperation)
}
func (o *operationOption) ApplyResource(r *Resource) {
o.handler(r.OpenAPIOperation)
}
func (o *operationOption) ApplyOperation(op *OpenAPIOperation) {
o.handler(op)
}
// DependencyOption sets an option on a dependency, operation, or resource
// object.
type DependencyOption interface {
OperationOption
ApplyDependency(d *OpenAPIDependency)
}
// dependencyOption is a shorthand struct used to create dependency options
// easily. Options created with it can be applied to dependencies, operations,
// and resources.
type dependencyOption struct {
handler func(*OpenAPIDependency)
}
func (o *dependencyOption) ApplyResource(r *Resource) {
o.handler(r.OpenAPIDependency)
}
func (o *dependencyOption) ApplyOperation(op *OpenAPIOperation) {
o.handler(op.OpenAPIDependency)
}
func (o *dependencyOption) ApplyDependency(d *OpenAPIDependency) {
o.handler(d)
}
// DependencyOptions composes together a set of options into one.
func DependencyOptions(options ...DependencyOption) DependencyOption {
return &dependencyOption{func(d *OpenAPIDependency) {
for _, option := range options {
option.ApplyDependency(d)
}
}}
}
// ParamOption sets an option on an OpenAPI parameter.
type ParamOption interface {
ApplyParam(*OpenAPIParam)
}
type paramOption struct {
apply func(*OpenAPIParam)
}
func (o *paramOption) ApplyParam(p *OpenAPIParam) {
o.apply(p)
}
// ResponseHeaderOption sets an option on an OpenAPI response header.
type ResponseHeaderOption interface {
ApplyResponseHeader(*OpenAPIResponseHeader)
}
// ResponseOption sets an option on an OpenAPI response.
type ResponseOption interface {
ApplyResponse(*OpenAPIResponse)
}
type responseOption struct {
apply func(*OpenAPIResponse)
}
func (o *responseOption) ApplyResponse(r *OpenAPIResponse) {
o.apply(r)
}
// sharedOption sets an option on any combination of objects.
type sharedOption struct {
Set func(v interface{})
}
func (o *sharedOption) ApplyRouter(r *Router) {
o.Set(r)
}
func (o *sharedOption) ApplyResource(r *Resource) {
o.Set(r)
}
func (o *sharedOption) ApplyOperation(op *OpenAPIOperation) {
o.Set(op)
}
func (o *sharedOption) ApplyParam(p *OpenAPIParam) {
o.Set(p)
}
func (o *sharedOption) ApplyResponseHeader(r *OpenAPIResponseHeader) {
o.Set(r)
}
func (o *sharedOption) ApplyResponse(r *OpenAPIResponse) {
o.Set(r)
}
// Schema manually sets a JSON Schema on the object. If the top-level `type` is
// blank then the type will be guessed from the handler function.
func Schema(s *schema.Schema) interface {
ParamOption
ResponseHeaderOption
ResponseOption
} {
return &sharedOption{func(v interface{}) {
switch cast := v.(type) {
case *OpenAPIParam:
cast.Schema = s
case *OpenAPIResponseHeader:
cast.Schema = s
case *OpenAPIResponse:
cast.Schema = s
}
}}
}
// SecurityRef adds a security reference by name with optional scopes.
func SecurityRef(name string, scopes ...string) interface {
RouterOption
OperationOption
} {
if scopes == nil {
scopes = []string{}
}
return &sharedOption{
Set: func(v interface{}) {
req := OpenAPISecurityRequirement{name: scopes}
switch cast := v.(type) {
case *Router:
cast.api.Security = append(cast.api.Security, req)
case *Resource:
cast.security = append(cast.security, req)
case *OpenAPIOperation:
cast.security = append(cast.security, req)
}
},
}
}
// Extra sets extra values in the generated OpenAPI 3 spec.
func Extra(pairs ...interface{}) interface {
RouterOption
OperationOption
} {
extra := map[string]interface{}{}
if len(pairs)%2 > 0 {
panic(fmt.Errorf("requires key-value pairs but got: %v", pairs))
}
for i := 0; i < len(pairs); i += 2 {
k := pairs[i].(string)
v := pairs[i+1]
extra[k] = v
}
return &sharedOption{
Set: func(v interface{}) {
var x map[string]interface{}
switch cast := v.(type) {
case *Router:
x = cast.api.Extra
case *Resource:
x = cast.extra
case *OpenAPIOperation:
x = cast.extra
}
for k, v := range extra {
x[k] = v
}
},
}
}
// ProdServer sets the production server URL on the API.
func ProdServer(url string) RouterOption {
return &routerOption{func(r *Router) {
r.api.Servers = append(r.api.Servers, &Server{url, "Production server"})
r.api.Servers = append(r.api.Servers, &OpenAPIServer{url, "Production server"})
}}
}
// DevServer sets the development server URL on the API.
func DevServer(url string) RouterOption {
return &routerOption{func(r *Router) {
r.api.Servers = append(r.api.Servers, &Server{url, "Development server"})
r.api.Servers = append(r.api.Servers, &OpenAPIServer{url, "Development server"})
}}
}
// ContactFull sets the API contact information.
func ContactFull(name, url, email string) RouterOption {
return &routerOption{func(r *Router) {
r.api.Contact = &Contact{name, url, email}
r.api.Contact = &OpenAPIContact{name, url, email}
}}
}
// ContactURL sets the API contact name & URL information.
func ContactURL(name, url string) RouterOption {
return &routerOption{func(r *Router) {
r.api.Contact = &Contact{Name: name, URL: url}
r.api.Contact = &OpenAPIContact{Name: name, URL: url}
}}
}
// ContactEmail sets the API contact name & email information.
func ContactEmail(name, email string) RouterOption {
return &routerOption{func(r *Router) {
r.api.Contact = &Contact{Name: name, Email: email}
r.api.Contact = &OpenAPIContact{Name: name, Email: email}
}}
}
// BasicAuth adds a named HTTP Basic Auth security scheme.
func BasicAuth(name string) RouterOption {
return &routerOption{func(r *Router) {
r.api.SecuritySchemes[name] = &SecurityScheme{
r.api.SecuritySchemes[name] = &OpenAPISecurityScheme{
Type: "http",
Scheme: "basic",
}
@ -112,7 +286,7 @@ func BasicAuth(name string) RouterOption {
// `header`, or `cookie`.
func APIKeyAuth(name, keyName, in string) RouterOption {
return &routerOption{func(r *Router) {
r.api.SecuritySchemes[name] = &SecurityScheme{
r.api.SecuritySchemes[name] = &OpenAPISecurityScheme{
Type: "apiKey",
Name: keyName,
In: in,
@ -124,7 +298,7 @@ func APIKeyAuth(name, keyName, in string) RouterOption {
// header.
func JWTBearerAuth(name string) RouterOption {
return &routerOption{func(r *Router) {
r.api.SecuritySchemes[name] = &SecurityScheme{
r.api.SecuritySchemes[name] = &OpenAPISecurityScheme{
Type: "http",
Scheme: "bearer",
BearerFormat: "JWT",
@ -177,3 +351,169 @@ func OpenAPIHook(f func(*gabs.Container)) RouterOption {
r.api.Hook = f
}}
}
// SimpleDependency adds a new dependency with just a value or function.
func SimpleDependency(handler interface{}) DependencyOption {
dep := &OpenAPIDependency{
handler: handler,
}
return &dependencyOption{func(d *OpenAPIDependency) {
d.dependencies = append(d.dependencies, dep)
}}
}
// Dependency adds a dependency.
func Dependency(option DependencyOption, handler interface{}) DependencyOption {
dep := NewDependency(option, handler)
return &dependencyOption{func(d *OpenAPIDependency) {
d.dependencies = append(d.dependencies, dep)
}}
}
// Example sets an example value, used for documentation and mocks.
func Example(value interface{}) ParamOption {
return &paramOption{func(p *OpenAPIParam) {
p.Example = value
}}
}
// Internal marks this parameter as internal-only, meaning it will not be
// included in the OpenAPI 3 JSON. Useful for things like auth headers set
// by a load balancer / gateway.
func Internal() ParamOption {
return &paramOption{func(p *OpenAPIParam) {
p.Internal = true
}}
}
// Deprecated marks this parameter as deprecated.
func Deprecated() ParamOption {
return &paramOption{func(p *OpenAPIParam) {
p.Deprecated = true
}}
}
func newParamOption(name, description string, required bool, def interface{}, in ParamLocation, options ...ParamOption) DependencyOption {
p := NewOpenAPIParam(name, description, in, options...)
p.Required = required
p.def = def
return &dependencyOption{func(d *OpenAPIDependency) {
d.params = append(d.params, p)
}}
}
// PathParam adds a new required path parameter
func PathParam(name string, description string, options ...ParamOption) DependencyOption {
return newParamOption(name, description, true, nil, InPath, options...)
}
// QueryParam returns a new optional query string parameter
func QueryParam(name string, description string, defaultValue interface{}, options ...ParamOption) DependencyOption {
return newParamOption(name, description, false, defaultValue, InQuery, options...)
}
// HeaderParam returns a new optional header parameter
func HeaderParam(name string, description string, defaultValue interface{}, options ...ParamOption) DependencyOption {
return newParamOption(name, description, false, defaultValue, InHeader, options...)
}
// ResponseHeader returns a new response header
func ResponseHeader(name, description string) DependencyOption {
r := &OpenAPIResponseHeader{
Name: name,
Description: description,
}
return &dependencyOption{func(d *OpenAPIDependency) {
d.responseHeaders = append(d.responseHeaders, r)
}}
}
// OperationID manually sets the operation's unique ID. If not set, it will
// be auto-generated from the resource path and operation verb.
func OperationID(id string) OperationOption {
return &operationOption{func(o *OpenAPIOperation) {
o.id = id
}}
}
// Tags sets one or more text tags on the operation.
func Tags(values ...string) OperationOption {
return &operationOption{func(o *OpenAPIOperation) {
o.tags = append(o.tags, values...)
}}
}
// RequestContentType sets the request content type on the operation.
func RequestContentType(name string) OperationOption {
return &operationOption{func(o *OpenAPIOperation) {
o.requestContentType = name
}}
}
// RequestSchema sets the request body schema on the operation.
func RequestSchema(schema *schema.Schema) OperationOption {
return &operationOption{func(o *OpenAPIOperation) {
o.requestSchema = schema
}}
}
// ContentType sets the content type for this response. If blank, an empty
// response is returned.
func ContentType(value string) ResponseOption {
return &responseOption{func(r *OpenAPIResponse) {
r.ContentType = value
}}
}
// Headers sets a list of allowed response headers.
func Headers(values ...string) ResponseOption {
return &responseOption{func(r *OpenAPIResponse) {
r.Headers = values
}}
}
// Response adds a new response to the operation.
func Response(statusCode int, description string, options ...ResponseOption) OperationOption {
r := NewOpenAPIResponse(statusCode, description, options...)
return &operationOption{func(o *OpenAPIOperation) {
o.responses = append(o.responses, r)
}}
}
// ResponseText adds a new string response to the operation.
func ResponseText(statusCode int, description string, options ...ResponseOption) OperationOption {
options = append(options, ContentType("text/plain"))
return Response(statusCode, description, options...)
}
// ResponseJSON adds a new JSON response model to the operation.
func ResponseJSON(statusCode int, description string, options ...ResponseOption) OperationOption {
options = append(options, ContentType("application/json"))
return Response(statusCode, description, options...)
}
// ResponseError adds a new error response model. Alias for ResponseJSON
// mainly useful for documentation purposes.
func ResponseError(statusCode int, description string, options ...ResponseOption) OperationOption {
return ResponseJSON(statusCode, description, options...)
}
// MaxBodyBytes sets the max number of bytes read from a request body before
// the handler aborts and returns an error. Applies to all sub-resources.
func MaxBodyBytes(value int64) OperationOption {
return &operationOption{func(o *OpenAPIOperation) {
o.maxBodyBytes = value
}}
}
// BodyReadTimeout sets the duration after which the read is aborted and an
// error is returned.
func BodyReadTimeout(value time.Duration) OperationOption {
return &operationOption{func(o *OpenAPIOperation) {
o.bodyReadTimeout = value
}}
}

View file

@ -1,40 +1,34 @@
package huma
import (
"fmt"
"net/http"
"reflect"
"strings"
"time"
)
// Resource describes a REST resource at a given URI path. Resources are
// typically created from a router or as a sub-resource of an existing resource.
type Resource struct {
router *Router
path string
deps []*Dependency
security []SecurityRequirement
params []*Param
responseHeaders []*ResponseHeader
responses []*Response
maxBodyBytes int64
bodyReadTimeout time.Duration
*OpenAPIOperation
router *Router
path string
}
// NewResource creates a new resource with the given router and path. All
// dependencies, security requirements, params, headers, and responses are
// empty.
func NewResource(router *Router, path string) *Resource {
return &Resource{
router: router,
path: path,
deps: make([]*Dependency, 0),
security: make([]SecurityRequirement, 0),
params: make([]*Param, 0),
responseHeaders: make([]*ResponseHeader, 0),
responses: make([]*Response, 0),
func NewResource(router *Router, path string, options ...ResourceOption) *Resource {
r := &Resource{
OpenAPIOperation: NewOperation(),
router: router,
path: path,
}
for _, option := range options {
option.ApplyResource(r)
}
return r
}
// Copy the resource. New arrays are created for dependencies, security
@ -42,60 +36,24 @@ func NewResource(router *Router, path string) *Resource {
// pointer values themselves are the same.
func (r *Resource) Copy() *Resource {
return &Resource{
router: r.router,
path: r.path,
deps: append([]*Dependency{}, r.deps...),
security: append([]SecurityRequirement{}, r.security...),
params: append([]*Param{}, r.params...),
responseHeaders: append([]*ResponseHeader{}, r.responseHeaders...),
responses: append([]*Response{}, r.responses...),
maxBodyBytes: r.maxBodyBytes,
bodyReadTimeout: r.bodyReadTimeout,
OpenAPIOperation: r.OpenAPIOperation.Copy(),
router: r.router,
path: r.path,
}
}
// With returns a copy of this resource with the given dependencies, security
// requirements, params, response headers, or responses added to it.
func (r *Resource) With(depsParamHeadersOrResponses ...interface{}) *Resource {
func (r *Resource) With(options ...ResourceOption) *Resource {
c := r.Copy()
// For each input, determine which type it is and store it.
for _, dph := range depsParamHeadersOrResponses {
switch v := dph.(type) {
case *Dependency:
c.deps = append(c.deps, v)
case []SecurityRequirement:
c.security = v
case SecurityRequirement:
c.security = append(c.security, v)
case *Param:
c.params = append(c.params, v)
case *ResponseHeader:
c.responseHeaders = append(c.responseHeaders, v)
case *Response:
c.responses = append(c.responses, v)
default:
panic(fmt.Errorf("unsupported type %v", v))
}
for _, option := range options {
option.ApplyResource(c)
}
return c
}
// MaxBodyBytes sets the max number of bytes read from a request body before
// the handler aborts and returns an error. Applies to all sub-resources.
func (r *Resource) MaxBodyBytes(value int64) *Resource {
r.maxBodyBytes = value
return r
}
// BodyReadTimeout sets the duration after which the read is aborted and an
// error is returned.
func (r *Resource) BodyReadTimeout(value time.Duration) *Resource {
r.bodyReadTimeout = value
return r
}
// Path returns the generated path including any path parameters.
func (r *Resource) Path() string {
generated := r.path
@ -117,7 +75,7 @@ func (r *Resource) Path() string {
// SubResource creates a new resource at the given path, which is appended
// to the existing resource path after adding any existing path parameters.
func (r *Resource) SubResource(path string, depsParamHeadersOrResponses ...interface{}) *Resource {
func (r *Resource) SubResource(path string, options ...ResourceOption) *Resource {
// Apply all existing params to the path.
newPath := r.Path()
@ -131,7 +89,7 @@ func (r *Resource) SubResource(path string, depsParamHeadersOrResponses ...inter
newPath += path
// Clone the resource and update the path.
c := r.With(depsParamHeadersOrResponses...)
c := r.With(options...)
c.path = newPath
return c
@ -139,16 +97,16 @@ func (r *Resource) SubResource(path string, depsParamHeadersOrResponses ...inter
// Operation adds the operation to this resource's router with all the
// combined deps, security requirements, params, headers, responses, etc.
func (r *Resource) Operation(method string, op *Operation) {
func (r *Resource) operation(method string, op *OpenAPIOperation) {
// Set params, etc
allDeps := append([]*Dependency{}, r.deps...)
allDeps = append(allDeps, op.Dependencies...)
op.Dependencies = allDeps
allDeps := append([]*OpenAPIDependency{}, r.dependencies...)
allDeps = append(allDeps, op.dependencies...)
op.dependencies = allDeps
// Combine resource and operation params. Update path with any required
// path parameters if they are not yet present.
allParams := append([]*Param{}, r.params...)
allParams = append(allParams, op.Params...)
allParams := append([]*OpenAPIParam{}, r.params...)
allParams = append(allParams, op.params...)
path := r.path
for _, p := range allParams {
if p.In == "path" {
@ -161,67 +119,47 @@ func (r *Resource) Operation(method string, op *Operation) {
}
}
}
op.Params = allParams
op.params = allParams
allHeaders := append([]*ResponseHeader{}, r.responseHeaders...)
allHeaders = append(allHeaders, op.ResponseHeaders...)
op.ResponseHeaders = allHeaders
allHeaders := append([]*OpenAPIResponseHeader{}, r.responseHeaders...)
allHeaders = append(allHeaders, op.responseHeaders...)
op.responseHeaders = allHeaders
allResponses := append([]*Response{}, r.responses...)
allResponses = append(allResponses, op.Responses...)
op.Responses = allResponses
allResponses := append([]*OpenAPIResponse{}, r.responses...)
allResponses = append(allResponses, op.responses...)
op.responses = allResponses
if op.Handler != nil {
t := reflect.TypeOf(op.Handler)
if t.NumOut() == len(op.ResponseHeaders)+len(op.Responses)+1 {
if op.handler != nil {
t := reflect.TypeOf(op.handler)
if t.NumOut() == len(op.responseHeaders)+len(op.responses)+1 {
rtype := t.Out(t.NumOut() - 1)
switch rtype.Kind() {
case reflect.Bool:
op.Responses = append(op.Responses, ResponseEmpty(http.StatusNoContent, "Success"))
op = op.With(Response(http.StatusNoContent, "Success"))
case reflect.String:
op.Responses = append(op.Responses, ResponseText(http.StatusOK, "Success"))
op = op.With(ResponseText(http.StatusOK, "Success"))
default:
op.Responses = append(op.Responses, ResponseJSON(http.StatusOK, "Success"))
op = op.With(ResponseJSON(http.StatusOK, "Success"))
}
}
}
if op.MaxBodyBytes == 0 {
op.MaxBodyBytes = r.maxBodyBytes
if op.maxBodyBytes == 0 {
op.maxBodyBytes = r.maxBodyBytes
}
if op.BodyReadTimeout == 0 {
op.BodyReadTimeout = r.bodyReadTimeout
if op.bodyReadTimeout == 0 {
op.bodyReadTimeout = r.bodyReadTimeout
}
r.router.Register(method, path, op)
}
// Text is shorthand for `r.With(huma.ResponseText(...))`.
func (r *Resource) Text(statusCode int, description string, headers ...string) *Resource {
return r.With(ResponseText(statusCode, description, headers...))
}
// JSON is shorthand for `r.With(huma.ResponseJSON(...))`.
func (r *Resource) JSON(statusCode int, description string, headers ...string) *Resource {
return r.With(ResponseJSON(statusCode, description, headers...))
}
// NoContent is shorthand for `r.With(huma.ResponseEmpty(http.StatusNoContent, ...)`
func (r *Resource) NoContent(description string, headers ...string) *Resource {
return r.With(ResponseEmpty(http.StatusNoContent, description, headers...))
}
// Empty is shorthand for `r.With(huma.ResponseEmpty(...))`.
func (r *Resource) Empty(statusCode int, description string, headers ...string) *Resource {
return r.With(ResponseEmpty(statusCode, description, headers...))
}
// Head creates an HTTP HEAD operation on the resource.
func (r *Resource) Head(description string, handler interface{}) {
r.Operation(http.MethodHead, &Operation{
Description: description,
Handler: handler,
r.operation(http.MethodHead, &OpenAPIOperation{
description: description,
OpenAPIDependency: &OpenAPIDependency{handler: handler},
})
}
@ -232,40 +170,40 @@ func (r *Resource) List(description string, handler interface{}) {
// Get creates an HTTP GET operation on the resource.
func (r *Resource) Get(description string, handler interface{}) {
r.Operation(http.MethodGet, &Operation{
Description: description,
Handler: handler,
r.operation(http.MethodGet, &OpenAPIOperation{
description: description,
OpenAPIDependency: &OpenAPIDependency{handler: handler},
})
}
// Post creates an HTTP POST operation on the resource.
func (r *Resource) Post(description string, handler interface{}) {
r.Operation(http.MethodPost, &Operation{
Description: description,
Handler: handler,
r.operation(http.MethodPost, &OpenAPIOperation{
description: description,
OpenAPIDependency: &OpenAPIDependency{handler: handler},
})
}
// Put creates an HTTP PUT operation on the resource.
func (r *Resource) Put(description string, handler interface{}) {
r.Operation(http.MethodPut, &Operation{
Description: description,
Handler: handler,
r.operation(http.MethodPut, &OpenAPIOperation{
description: description,
OpenAPIDependency: &OpenAPIDependency{handler: handler},
})
}
// Patch creates an HTTP PATCH operation on the resource.
func (r *Resource) Patch(description string, handler interface{}) {
r.Operation(http.MethodPatch, &Operation{
Description: description,
Handler: handler,
r.operation(http.MethodPatch, &OpenAPIOperation{
description: description,
OpenAPIDependency: &OpenAPIDependency{handler: handler},
})
}
// Delete creates an HTTP DELETE operation on the resource.
func (r *Resource) Delete(description string, handler interface{}) {
r.Operation(http.MethodDelete, &Operation{
Description: description,
Handler: handler,
r.operation(http.MethodDelete, &OpenAPIOperation{
description: description,
OpenAPIDependency: &OpenAPIDependency{handler: handler},
})
}

View file

@ -13,44 +13,38 @@ func TestResourceCopy(t *testing.T) {
r1 := NewResource(nil, "/test")
r2 := r1.Copy()
assert.NotSame(t, r1.deps, r2.deps)
assert.NotSame(t, r1.dependencies, r2.dependencies)
assert.NotSame(t, r1.params, r2.params)
assert.NotSame(t, r1.responseHeaders, r2.responseHeaders)
assert.NotSame(t, r1.responses, r2.responses)
}
func TestResourceWithBadInput(t *testing.T) {
assert.Panics(t, func() {
NewResource(nil, "/test").With("bad-value")
})
}
func TestResourceWithDep(t *testing.T) {
dep1 := &Dependency{Value: "dep1"}
dep2 := &Dependency{Value: "dep2"}
dep1 := SimpleDependency("dep1")
dep2 := SimpleDependency("dep2")
r1 := NewResource(nil, "/test")
r2 := r1.With(dep1)
r3 := r1.With(dep2)
assert.Contains(t, r2.deps, dep1)
assert.NotContains(t, r2.deps, dep2)
assert.Contains(t, r3.deps, dep2)
assert.NotContains(t, r3.deps, dep1)
assert.NotEmpty(t, r2.dependencies)
assert.NotEmpty(t, r3.dependencies)
assert.NotSame(t, r2.dependencies[0], r3.dependencies[0])
}
func TestResourceWithSecurity(t *testing.T) {
sec1 := SecurityRef("sec1")
sec2 := SecurityRef("sec2")[0]
sec2 := SecurityRef("sec2")
r1 := NewResource(nil, "/test")
r2 := r1.With(sec1)
r3 := r1.With(sec2)
assert.Equal(t, r2.security, sec1)
assert.NotContains(t, r2.security, sec2)
assert.Contains(t, r3.security, sec2)
assert.NotEqual(t, r3.security, sec1)
assert.NotEmpty(t, r2.security)
assert.NotEmpty(t, r3.security)
assert.NotSame(t, r2.security[0], r3.security[0])
}
func TestResourceWithParam(t *testing.T) {
@ -61,27 +55,27 @@ func TestResourceWithParam(t *testing.T) {
r2 := r1.With(param1)
r3 := r1.With(param2)
assert.Contains(t, r2.params, param1)
assert.NotContains(t, r2.params, param2)
assert.Contains(t, r3.params, param2)
assert.NotContains(t, r3.params, param1)
assert.NotEmpty(t, r2.params)
assert.NotEmpty(t, r3.params)
assert.NotSame(t, r2.params[0], r3.params[0])
assert.Equal(t, "/test/{p1}", r2.Path())
assert.Equal(t, "/test/{p2}", r3.Path())
}
func TestResourceWithHeader(t *testing.T) {
header1 := Header("h1", "desc")
header2 := Header("h2", "desc")
header1 := ResponseHeader("h1", "desc")
header2 := ResponseHeader("h2", "desc")
r1 := NewResource(nil, "/test")
r2 := r1.With(header1)
r3 := r1.With(header2)
assert.Contains(t, r2.responseHeaders, header1)
assert.NotContains(t, r2.responseHeaders, header2)
assert.Contains(t, r3.responseHeaders, header2)
assert.NotContains(t, r3.responseHeaders, header1)
assert.NotEmpty(t, r2.responseHeaders)
assert.NotEmpty(t, r3.responseHeaders)
assert.NotSame(t, r2.responseHeaders[0], r3.responseHeaders[0])
}
func TestResourceWithResponse(t *testing.T) {
@ -92,10 +86,10 @@ func TestResourceWithResponse(t *testing.T) {
r2 := r1.With(resp1)
r3 := r1.With(resp2)
assert.Contains(t, r2.responses, resp1)
assert.NotContains(t, r2.responses, resp2)
assert.Contains(t, r3.responses, resp2)
assert.NotContains(t, r3.responses, resp1)
assert.NotEmpty(t, r2.responses)
assert.NotEmpty(t, r3.responses)
assert.NotSame(t, r2.responses[0], r3.responses[0])
}
func TestSubResource(t *testing.T) {
@ -137,7 +131,7 @@ func TestResourceFuncs(outer *testing.T) {
local := tt
outer.Run(fmt.Sprintf("%v", tt), func(t *testing.T) {
r := NewTestRouter(t)
res := NewResource(r, "/test").Text(http.StatusOK, "desc")
res := NewResource(r, "/test")
var f func(string, interface{})
@ -180,34 +174,6 @@ var resourceShorthandFuncs = []struct {
{"Empty", http.StatusNotModified, "", "desc"},
}
func TestResourceShorthandFuncs(outer *testing.T) {
for _, tt := range resourceShorthandFuncs {
local := tt
outer.Run(fmt.Sprintf("%v", local.n), func(t *testing.T) {
r := NewTestRouter(t)
res := NewResource(r, "/test")
switch local.n {
case "Text":
res = res.Text(local.statusCode, local.desc, "header")
case "JSON":
res = res.JSON(local.statusCode, local.desc, "header")
case "NoContent":
res = res.NoContent(local.desc, "header")
case "Empty":
res = res.Empty(local.statusCode, local.desc, "header")
default:
panic("invalid case " + local.n)
}
resp := res.responses[0]
assert.Equal(t, local.statusCode, resp.StatusCode)
assert.Equal(t, local.contentType, resp.ContentType)
assert.Equal(t, local.desc, resp.Description)
})
}
}
func TestResourceAutoJSON(t *testing.T) {
r := NewTestRouter(t)
@ -218,8 +184,8 @@ func TestResourceAutoJSON(t *testing.T) {
return &MyResponse{}
})
assert.Equal(t, http.StatusOK, r.api.Paths["/test"][http.MethodGet].Responses[0].StatusCode)
assert.Equal(t, "application/json", r.api.Paths["/test"][http.MethodGet].Responses[0].ContentType)
assert.Equal(t, http.StatusOK, r.api.Paths["/test"][http.MethodGet].responses[0].StatusCode)
assert.Equal(t, "application/json", r.api.Paths["/test"][http.MethodGet].responses[0].ContentType)
}
func TestResourceAutoText(t *testing.T) {
@ -230,8 +196,8 @@ func TestResourceAutoText(t *testing.T) {
return "Hello, world"
})
assert.Equal(t, http.StatusOK, r.api.Paths["/test"][http.MethodGet].Responses[0].StatusCode)
assert.Equal(t, "text/plain", r.api.Paths["/test"][http.MethodGet].Responses[0].ContentType)
assert.Equal(t, http.StatusOK, r.api.Paths["/test"][http.MethodGet].responses[0].StatusCode)
assert.Equal(t, "text/plain", r.api.Paths["/test"][http.MethodGet].responses[0].ContentType)
}
func TestResourceAutoNoContent(t *testing.T) {
@ -242,7 +208,6 @@ func TestResourceAutoNoContent(t *testing.T) {
return true
})
assert.Equal(t, http.StatusNoContent, r.api.Paths["/test"][http.MethodGet].Responses[0].StatusCode)
assert.Equal(t, "", r.api.Paths["/test"][http.MethodGet].Responses[0].ContentType)
assert.Equal(t, true, r.api.Paths["/test"][http.MethodGet].Responses[0].empty)
assert.Equal(t, http.StatusNoContent, r.api.Paths["/test"][http.MethodGet].responses[0].StatusCode)
assert.Equal(t, "", r.api.Paths["/test"][http.MethodGet].responses[0].ContentType)
}

View file

@ -15,6 +15,7 @@ import (
"sync"
"time"
"github.com/danielgtaylor/huma/schema"
"github.com/gin-contrib/cors"
"github.com/gin-gonic/gin"
"github.com/spf13/cobra"
@ -30,6 +31,8 @@ var ErrInvalidParamLocation = errors.New("invalid parameter location")
// context value.
var ConnContextKey = struct{}{}
var timeType = reflect.TypeOf(time.Time{})
// GetConn gets the underlying `net.Conn` from a request.
func GetConn(r *http.Request) net.Conn {
conn := r.Context().Value(ConnContextKey)
@ -40,7 +43,7 @@ func GetConn(r *http.Request) net.Conn {
}
// Checks if data validates against the given schema. Returns false on failure.
func validAgainstSchema(c *gin.Context, label string, schema *Schema, data []byte) bool {
func validAgainstSchema(c *gin.Context, label string, schema *schema.Schema, data []byte) bool {
defer func() {
// Catch panics from the `gojsonschema` library.
if err := recover(); err != nil {
@ -149,7 +152,7 @@ func parseParamValue(c *gin.Context, name string, typ reflect.Type, pstr string)
return pv, true
}
func getParamValue(c *gin.Context, param *Param) (interface{}, bool) {
func getParamValue(c *gin.Context, param *OpenAPIParam) (interface{}, bool) {
var pstr string
switch param.In {
case InPath:
@ -198,18 +201,18 @@ func getParamValue(c *gin.Context, param *Param) (interface{}, bool) {
return pv, true
}
func getRequestBody(c *gin.Context, t reflect.Type, op *Operation) (interface{}, bool) {
func getRequestBody(c *gin.Context, t reflect.Type, op *OpenAPIOperation) (interface{}, bool) {
val := reflect.New(t).Interface()
if op.RequestSchema != nil {
if op.requestSchema != nil {
body, err := ioutil.ReadAll(c.Request.Body)
if err != nil {
if strings.Contains(err.Error(), "request body too large") {
c.AbortWithStatusJSON(http.StatusRequestEntityTooLarge, ErrorModel{
Message: fmt.Sprintf("Request body too large, limit = %d bytes", op.MaxBodyBytes),
Message: fmt.Sprintf("Request body too large, limit = %d bytes", op.maxBodyBytes),
})
} else if e, ok := err.(net.Error); ok && e.Timeout() {
c.AbortWithStatusJSON(http.StatusRequestTimeout, ErrorModel{
Message: fmt.Sprintf("Request body took too long to read: timed out after %v", op.BodyReadTimeout),
Message: fmt.Sprintf("Request body took too long to read: timed out after %v", op.bodyReadTimeout),
})
} else {
panic(err)
@ -219,7 +222,7 @@ func getRequestBody(c *gin.Context, t reflect.Type, op *Operation) (interface{},
c.Request.Body = ioutil.NopCloser(bytes.NewBuffer(body))
if !validAgainstSchema(c, "request body", op.RequestSchema, body) {
if !validAgainstSchema(c, "request body", op.requestSchema, body) {
// Error already handled, just return.
return nil, false
}
@ -274,10 +277,10 @@ func NewRouter(docs, version string, options ...RouterOption) *Router {
Title: title,
Description: desc,
Version: version,
Servers: make([]*Server, 0),
SecuritySchemes: make(map[string]*SecurityScheme, 0),
Security: make([]SecurityRequirement, 0),
Paths: make(map[string]map[string]*Operation),
Servers: make([]*OpenAPIServer, 0),
SecuritySchemes: make(map[string]*OpenAPISecurityScheme, 0),
Security: make([]OpenAPISecurityRequirement, 0),
Paths: make(map[string]map[string]*OpenAPIOperation),
Extra: make(map[string]interface{}),
},
engine: g,
@ -298,7 +301,7 @@ func NewRouter(docs, version string, options ...RouterOption) *Router {
}
// Set up handlers for the auto-generated spec and docs.
r.engine.GET("/openapi.json", OpenAPIHandler(r.api))
r.engine.GET("/openapi.json", openAPIHandler(r.api))
r.engine.GET("/docs", func(c *gin.Context) {
r.docsHandler(c, r.api)
@ -330,19 +333,19 @@ func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
// Resource creates a new resource at the given path with the given
// dependencies, parameters, response headers, and responses defined.
func (r *Router) Resource(path string, depsParamsHeadersOrResponses ...interface{}) *Resource {
return NewResource(r, path).With(depsParamsHeadersOrResponses...)
func (r *Router) Resource(path string, options ...ResourceOption) *Resource {
return NewResource(r, path).With(options...)
}
// Register a new operation.
func (r *Router) Register(method, path string, op *Operation) {
func (r *Router) Register(method, path string, op *OpenAPIOperation) {
// First, make sure the operation and handler make sense, as well as pre-
// generating any schemas for use later during request handling.
op.validate(method, path)
// Add the operation to the list of operations for the path entry.
if r.api.Paths[path] == nil {
r.api.Paths[path] = make(map[string]*Operation)
r.api.Paths[path] = make(map[string]*OpenAPIOperation)
}
r.api.Paths[path][method] = op
@ -376,12 +379,12 @@ func (r *Router) Register(method, path string, op *Operation) {
// Then call it to register our handler function.
f(path, func(c *gin.Context) {
method := reflect.ValueOf(op.Handler)
method := reflect.ValueOf(op.handler)
in := make([]reflect.Value, 0, method.Type().NumIn())
// Limit the body size
if c.Request.Body != nil {
maxBody := op.MaxBodyBytes
maxBody := op.maxBodyBytes
if maxBody == 0 {
// 1 MiB default
maxBody = 1024 * 1024
@ -394,8 +397,8 @@ func (r *Router) Register(method, path string, op *Operation) {
}
// Process any dependencies first.
for _, dep := range op.Dependencies {
headers, value, err := dep.Resolve(c, op)
for _, dep := range op.dependencies {
headers, value, err := dep.resolve(c, op)
if err != nil {
if !c.IsAborted() {
// Nothing else has handled the error, so treat it like a general
@ -413,7 +416,7 @@ func (r *Router) Register(method, path string, op *Operation) {
in = append(in, reflect.ValueOf(value))
}
for _, param := range op.Params {
for _, param := range op.params {
pv, ok := getParamValue(c, param)
if !ok {
// Error has already been handled.
@ -423,7 +426,7 @@ func (r *Router) Register(method, path string, op *Operation) {
in = append(in, reflect.ValueOf(pv))
}
readTimeout := op.BodyReadTimeout
readTimeout := op.bodyReadTimeout
if len(in) != method.Type().NumIn() {
if readTimeout == 0 {
// Default to 15s when reading/parsing/validating automatically.
@ -458,14 +461,14 @@ func (r *Router) Register(method, path string, op *Operation) {
// from the registered `huma.Response` struct.
// This breaks down with scalar types... so they need to be passed
// as a pointer and we'll dereference it automatically.
for i, o := range out[len(op.ResponseHeaders):] {
for i, o := range out[len(op.responseHeaders):] {
if !o.IsZero() {
body := o.Interface()
r := op.Responses[i]
r := op.responses[i]
// Set response headers
for j, header := range op.ResponseHeaders {
for j, header := range op.responseHeaders {
value := out[j]
found := false
@ -498,7 +501,7 @@ func (r *Router) Register(method, path string, op *Operation) {
}
}
if r.empty {
if r.ContentType == "" {
// No body allowed, e.g. for HTTP 204.
c.Status(r.StatusCode)
break

View file

@ -10,6 +10,7 @@ import (
"testing"
"time"
"github.com/danielgtaylor/huma/schema"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"go.uber.org/zap"
@ -53,16 +54,10 @@ func BenchmarkGin(b *testing.B) {
func BenchmarkHuma(b *testing.B) {
r := NewRouter("Benchmark test", "1.0.0", WithGin(gin.New()))
r.Register(http.MethodGet, "/hello", &Operation{
Description: "Greet the world",
Responses: []*Response{
ResponseJSON(200, "Return a greeting"),
},
Handler: func() *helloResponse {
return &helloResponse{
Message: "Hello, world",
}
},
r.Resource("/hello").Get("Greet the world", func() *helloResponse {
return &helloResponse{
Message: "Hello, world",
}
})
b.ResetTimer()
@ -120,56 +115,35 @@ func BenchmarkGinComplex(b *testing.B) {
func BenchmarkHumaComplex(b *testing.B) {
r := NewRouter("Benchmark test", "1.0.0", WithGin(gin.New()))
dep1 := &Dependency{
Value: "dep1",
}
dep1 := SimpleDependency("dep1")
dep2 := &Dependency{
Dependencies: []*Dependency{ContextDependency(), dep1},
Params: []*Param{
HeaderParam("x-foo", "desc", ""),
},
Value: func(c *gin.Context, d1 string, xfoo string) (string, error) {
return "dep2", nil
},
}
dep2 := Dependency(DependencyOptions(
ContextDependency(), dep1, HeaderParam("x-foo", "desc", ""),
), func(c *gin.Context, d1 string, xfoo string) (string, error) {
return "dep2", nil
})
dep3 := &Dependency{
Dependencies: []*Dependency{dep1},
ResponseHeaders: []*ResponseHeader{
Header("x-bar", "desc"),
},
Value: func(d1 string) (string, string, error) {
return "xbar", "dep3", nil
},
}
dep3 := Dependency(DependencyOptions(
dep1, ResponseHeader("x-bar", "desc"),
), func(d1 string) (string, string, error) {
return "xbar", "dep3", nil
})
r.Register(http.MethodGet, "/hello", &Operation{
Description: "Greet the world",
Dependencies: []*Dependency{
ContextDependency(), dep2, dep3,
},
Params: []*Param{
QueryParam("name", "desc", "world"),
},
ResponseHeaders: []*ResponseHeader{
Header("x-baz", "desc"),
},
Responses: []*Response{
ResponseJSON(200, "Return a greeting", "x-baz"),
ResponseError(500, "desc"),
},
Handler: func(c *gin.Context, d2, d3, name string) (string, *helloResponse, *ErrorModel) {
if name == "test" {
return "", nil, &ErrorModel{
Message: "Name cannot be test",
}
r.Resource("/hello", dep1, dep2, dep3,
QueryParam("name", "desc", "world"),
ResponseHeader("x-baz", "desc"),
ResponseJSON(200, "Return a greeting", Headers("x-baz")),
ResponseError(500, "desc"),
).Get("Greet the world", func(c *gin.Context, d2, d3, name string) (string, *helloResponse, *ErrorModel) {
if name == "test" {
return "", nil, &ErrorModel{
Message: "Name cannot be test",
}
}
return "xbaz", &helloResponse{
Message: "Hello, " + name,
}, nil
},
return "xbaz", &helloResponse{
Message: "Hello, " + name,
}, nil
})
b.ResetTimer()
@ -193,28 +167,22 @@ func TestRouter(t *testing.T) {
r := NewTestRouter(t)
r.Register(http.MethodPut, "/echo/{word}", &Operation{
Description: "Echo back an input word.",
Params: []*Param{
PathParam("word", "The word to echo back"),
QueryParam("greet", "Return a greeting", false),
},
Responses: []*Response{
ResponseJSON(http.StatusOK, "Successful echo response"),
ResponseError(http.StatusBadRequest, "Invalid input"),
},
Handler: func(word string, greet bool) (*EchoResponse, *ErrorModel) {
if word == "test" {
return nil, &ErrorModel{Message: "Value not allowed: test"}
}
r.Resource("/echo",
PathParam("word", "The word to echo back"),
QueryParam("greet", "Return a greeting", false),
ResponseJSON(http.StatusOK, "Successful echo response"),
ResponseError(http.StatusBadRequest, "Invalid input"),
).Put("Echo back an input word.", func(word string, greet bool) (*EchoResponse, *ErrorModel) {
if word == "test" {
return nil, &ErrorModel{Message: "Value not allowed: test"}
}
v := word
if greet {
v = "Hello, " + word
}
v := word
if greet {
v = "Hello, " + word
}
return &EchoResponse{Value: v}, nil
},
return &EchoResponse{Value: v}, nil
})
w := httptest.NewRecorder()
@ -257,14 +225,8 @@ func TestRouterRequestBody(t *testing.T) {
r := NewTestRouter(t)
r.Register(http.MethodPut, "/echo", &Operation{
Description: "Echo back an input word.",
Responses: []*Response{
ResponseJSON(http.StatusOK, "Successful echo response"),
},
Handler: func(in *EchoRequest) *EchoResponse {
return &EchoResponse{Value: in.Value}
},
r.Resource("/echo").Put("Echo back an input word.", func(in *EchoRequest) *EchoResponse {
return &EchoResponse{Value: in.Value}
})
w := httptest.NewRecorder()
@ -283,14 +245,8 @@ func TestRouterRequestBody(t *testing.T) {
func TestRouterScalarResponse(t *testing.T) {
r := NewTestRouter(t)
r.Register(http.MethodPut, "/hello", &Operation{
Description: "Say hello.",
Responses: []*Response{
ResponseText(http.StatusOK, "Successful hello response"),
},
Handler: func() string {
return "hello"
},
r.Resource("/hello").Put("Say hello", func() string {
return "hello"
})
w := httptest.NewRecorder()
@ -304,15 +260,9 @@ func TestRouterScalarResponse(t *testing.T) {
func TestRouterZeroScalarResponse(t *testing.T) {
r := NewTestRouter(t)
r.Register(http.MethodPut, "/bool", &Operation{
Description: "Say hello.",
Responses: []*Response{
ResponseText(http.StatusOK, "Successful zero bool response"),
},
Handler: func() *bool {
resp := false
return &resp
},
r.Resource("/bool").Put("Bool response", func() *bool {
resp := false
return &resp
})
w := httptest.NewRecorder()
@ -320,27 +270,21 @@ func TestRouterZeroScalarResponse(t *testing.T) {
r.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, "false", w.Body.String())
assert.Equal(t, "false\n", w.Body.String())
}
func TestRouterResponseHeaders(t *testing.T) {
r := NewTestRouter(t)
r.Register(http.MethodGet, "/test", &Operation{
Description: "Test operation",
ResponseHeaders: []*ResponseHeader{
Header("Etag", "Identifies a specific version of this resource"),
Header("X-Test", "Custom test header"),
Header("X-Missing", "Won't get sent"),
},
Responses: []*Response{
ResponseText(http.StatusOK, "Successful test", "Etag", "X-Test", "X-Missing"),
ResponseError(http.StatusBadRequest, "Error example", "X-Test"),
},
Handler: func() (etag string, xTest *string, xMissing string, success string, fail string) {
test := "test"
return "\"abc123\"", &test, "", "hello", ""
},
r.Resource("/test",
ResponseHeader("Etag", "Identifies a specific version of this resource"),
ResponseHeader("X-Test", "Custom test header"),
ResponseHeader("X-Missing", "Won't get sent"),
ResponseText(http.StatusOK, "Successful test", Headers("Etag", "X-Test", "X-Missing")),
ResponseError(http.StatusBadRequest, "Error example", Headers("X-Test")),
).Get("Test operation", func() (etag string, xTest *string, xMissing string, success string, fail string) {
test := "test"
return "\"abc123\"", &test, "", "hello", ""
})
w := httptest.NewRecorder()
@ -362,11 +306,9 @@ func TestRouterDependencies(t *testing.T) {
}
// Datastore is a global dependency, set by value.
db := &Dependency{
Value: &DB{
Get: func() string {
return "Hello, "
},
db := &DB{
Get: func() string {
return "Hello, "
},
}
@ -376,35 +318,25 @@ func TestRouterDependencies(t *testing.T) {
// Logger is a contextual instance from the gin request context.
captured := ""
log := &Dependency{
Dependencies: []*Dependency{
GinContextDependency(),
},
Value: func(c *gin.Context) (*Logger, error) {
return &Logger{
Log: func(msg string) {
captured = fmt.Sprintf("%s [uri:%s]", msg, c.FullPath())
},
}, nil
},
}
log := Dependency(GinContextDependency(), func(c *gin.Context) (*Logger, error) {
return &Logger{
Log: func(msg string) {
captured = fmt.Sprintf("%s [uri:%s]", msg, c.FullPath())
},
}, nil
})
r.Register(http.MethodGet, "/hello", &Operation{
Description: "Basic hello world",
Dependencies: []*Dependency{GinContextDependency(), db, log},
Params: []*Param{
QueryParam("name", "Your name", ""),
},
Responses: []*Response{
ResponseText(http.StatusOK, "Successful hello response"),
},
Handler: func(c *gin.Context, db *DB, l *Logger, name string) string {
if name == "" {
name = c.Request.RemoteAddr
}
l.Log("Hello logger!")
return db.Get() + name
},
r.Resource("/hello",
GinContextDependency(),
SimpleDependency(db),
log,
QueryParam("name", "Your name", ""),
).Get("Basic hello world", func(c *gin.Context, db *DB, l *Logger, name string) string {
if name == "" {
name = c.Request.RemoteAddr
}
l.Log("Hello logger!")
return db.Get() + name
})
w := httptest.NewRecorder()
@ -421,7 +353,7 @@ func TestRouterBadHeader(t *testing.T) {
g := gin.New()
g.Use(LogMiddleware(l, nil))
r := NewRouter("Test API", "1.0.0", WithGin(g))
r.Resource("/test", Header("foo", "desc"), ResponseError(http.StatusBadRequest, "desc", "foo")).Get("desc", func() (string, *ErrorModel, string) {
r.Resource("/test", ResponseHeader("foo", "desc"), ResponseError(http.StatusBadRequest, "desc", Headers("foo"))).Get("desc", func() (string, *ErrorModel, string) {
return "header-value", nil, "response"
})
@ -441,7 +373,7 @@ func TestRouterParams(t *testing.T) {
QueryParam("i", "desc", int16(0)),
QueryParam("f32", "desc", float32(0.0)),
QueryParam("f64", "desc", 0.0),
QueryParam("schema", "desc", "test", &Schema{Pattern: "^a-z+$"}),
QueryParam("schema", "desc", "test", Schema(&schema.Schema{Pattern: "^a-z+$"})),
QueryParam("items", "desc", []int{}),
QueryParam("start", "desc", time.Time{}),
).Get("desc", func(id string, i int16, f32 float32, f64 float64, schema string, items []int, start time.Time) string {
@ -501,8 +433,11 @@ func TestRouterParams(t *testing.T) {
func TestInvalidParamLocation(t *testing.T) {
r := NewTestRouter(t)
test := r.Resource("/test", PathParam("name", "desc"))
test.params[len(test.params)-1].In = "bad'"
assert.Panics(t, func() {
r.Resource("/test", &Param{Name: "test", In: "bad"}).Get("desc", func(test string) string {
test.Get("desc", func(test string) string {
return "Hello, test!"
})
})
@ -523,7 +458,7 @@ func TestTooBigBody(t *testing.T) {
ID string
}
r.Resource("/test").MaxBodyBytes(5).Put("desc", func(input *Input) string {
r.Resource("/test", MaxBodyBytes(5)).Put("desc", func(input *Input) string {
return "hello, " + input.ID
})
@ -561,7 +496,7 @@ func TestBodySlow(t *testing.T) {
ID string
}
r.Resource("/test").BodyReadTimeout(1).Put("desc", func(input *Input) string {
r.Resource("/test", BodyReadTimeout(1)).Put("desc", func(input *Input) string {
return "hello, " + input.ID
})

View file

@ -1,4 +1,6 @@
package huma
// Package schema implements OpenAPI 3 compatible JSON Schema which can be
// generated from structs.
package schema
import (
"encoding/json"
@ -16,18 +18,18 @@ import (
// ErrSchemaInvalid is sent when there is a problem building the schema.
var ErrSchemaInvalid = errors.New("schema is invalid")
// SchemaMode defines whether the schema is being generated for read or
// Mode defines whether the schema is being generated for read or
// write mode. Read-only fields are dropped when in write mode, for example.
type SchemaMode int
type Mode int
const (
// SchemaModeAll is for general purpose use and includes all fields.
SchemaModeAll SchemaMode = iota
// SchemaModeRead is for HTTP HEAD & GET and will hide write-only fields.
SchemaModeRead
// SchemaModeWrite is for HTTP POST, PUT, PATCH, DELETE and will hide
// ModeAll is for general purpose use and includes all fields.
ModeAll Mode = iota
// ModeRead is for HTTP HEAD & GET and will hide write-only fields.
ModeRead
// ModeWrite is for HTTP POST, PUT, PATCH, DELETE and will hide
// read-only fields.
SchemaModeWrite
ModeWrite
)
var (
@ -120,20 +122,20 @@ func (s *Schema) HasValidation() bool {
return false
}
// GenerateSchema creates a JSON schema for a Go type. Struct field tags
// Generate creates a JSON schema for a Go type. Struct field tags
// can be used to provide additional metadata such as descriptions and
// validation.
func GenerateSchema(t reflect.Type) (*Schema, error) {
return GenerateSchemaWithMode(t, SchemaModeAll, nil)
func Generate(t reflect.Type) (*Schema, error) {
return GenerateWithMode(t, ModeAll, nil)
}
// GenerateSchemaWithMode creates a JSON schema for a Go type. Struct field
// GenerateWithMode creates a JSON schema for a Go type. Struct field
// tags can be used to provide additional metadata such as descriptions and
// validation. The mode can be all, read, or write. In read or write mode
// any field that is marked as the opposite will be excluded, e.g. a
// write-only field would not be included in read mode. If a schema is given
// as input, add to it, otherwise creates a new schema.
func GenerateSchemaWithMode(t reflect.Type, mode SchemaMode, schema *Schema) (*Schema, error) {
func GenerateWithMode(t reflect.Type, mode Mode, schema *Schema) (*Schema, error) {
if schema == nil {
schema = &Schema{}
}
@ -167,7 +169,7 @@ func GenerateSchemaWithMode(t reflect.Type, mode SchemaMode, schema *Schema) (*S
name = jsonTags[0]
}
s, err := GenerateSchemaWithMode(f.Type, mode, nil)
s, err := GenerateWithMode(f.Type, mode, nil)
if err != nil {
return nil, err
}
@ -177,6 +179,10 @@ func GenerateSchemaWithMode(t reflect.Type, mode SchemaMode, schema *Schema) (*S
s.Description = tag
}
if tag, ok := f.Tag.Lookup("doc"); ok {
s.Description = tag
}
if tag, ok := f.Tag.Lookup("format"); ok {
s.Format = tag
}
@ -326,7 +332,7 @@ func GenerateSchemaWithMode(t reflect.Type, mode SchemaMode, schema *Schema) (*S
}
s.ReadOnly = tag == "true"
if s.ReadOnly && mode == SchemaModeWrite {
if s.ReadOnly && mode == ModeWrite {
delete(properties, name)
continue
}
@ -338,7 +344,7 @@ func GenerateSchemaWithMode(t reflect.Type, mode SchemaMode, schema *Schema) (*S
}
s.WriteOnly = tag == "true"
if s.WriteOnly && mode == SchemaModeRead {
if s.WriteOnly && mode == ModeRead {
delete(properties, name)
continue
}
@ -372,14 +378,14 @@ func GenerateSchemaWithMode(t reflect.Type, mode SchemaMode, schema *Schema) (*S
case reflect.Map:
schema.Type = "object"
s, err := GenerateSchemaWithMode(t.Elem(), mode, nil)
s, err := GenerateWithMode(t.Elem(), mode, nil)
if err != nil {
return nil, err
}
schema.AdditionalProperties = s
case reflect.Slice, reflect.Array:
schema.Type = "array"
s, err := GenerateSchemaWithMode(t.Elem(), mode, nil)
s, err := GenerateWithMode(t.Elem(), mode, nil)
if err != nil {
return nil, err
}
@ -410,7 +416,7 @@ func GenerateSchemaWithMode(t reflect.Type, mode SchemaMode, schema *Schema) (*S
case reflect.String:
schema.Type = "string"
case reflect.Ptr:
return GenerateSchemaWithMode(t.Elem(), mode, schema)
return GenerateWithMode(t.Elem(), mode, schema)
default:
return nil, fmt.Errorf("unsupported type %s from %s", t.Kind(), t)
}

View file

@ -1,4 +1,4 @@
package huma
package schema
import (
"fmt"
@ -11,6 +11,21 @@ import (
"github.com/stretchr/testify/assert"
)
func Example() {
type MyObject struct {
ID string `doc:"Object ID" readOnly:"true"`
Rate float64 `doc:"Rate of change" minimum:"0"`
Coords []int `doc:"X,Y coordinates" minItems:"2" maxItems:"2"`
}
generated, err := Generate(reflect.TypeOf(MyObject{}))
if err != nil {
panic(err)
}
fmt.Println(generated.Properties["id"].ReadOnly)
// output: true
}
var types = []struct {
in interface{}
out string
@ -34,7 +49,7 @@ func TestSchemaTypes(outer *testing.T) {
local := tt
outer.Run(fmt.Sprintf("%v", tt.in), func(t *testing.T) {
t.Parallel()
s, err := GenerateSchema(reflect.ValueOf(local.in).Type())
s, err := Generate(reflect.ValueOf(local.in).Type())
assert.NoError(t, err)
assert.Equal(t, local.out, s.Type)
assert.Equal(t, local.format, s.Format)
@ -48,7 +63,7 @@ func TestSchemaRequiredFields(t *testing.T) {
Required string `json:"required"`
}
s, err := GenerateSchema(reflect.ValueOf(Example{}).Type())
s, err := Generate(reflect.ValueOf(Example{}).Type())
assert.NoError(t, err)
assert.Len(t, s.Properties, 2)
assert.NotContains(t, s.Required, "optional")
@ -60,7 +75,7 @@ func TestSchemaRenameField(t *testing.T) {
Foo string `json:"bar"`
}
s, err := GenerateSchema(reflect.ValueOf(Example{}).Type())
s, err := Generate(reflect.ValueOf(Example{}).Type())
assert.NoError(t, err)
assert.Empty(t, s.Properties["foo"])
assert.NotEmpty(t, s.Properties["bar"])
@ -71,7 +86,7 @@ func TestSchemaDescription(t *testing.T) {
Foo string `json:"foo" description:"I am a test"`
}
s, err := GenerateSchema(reflect.ValueOf(Example{}).Type())
s, err := Generate(reflect.ValueOf(Example{}).Type())
assert.NoError(t, err)
assert.Equal(t, "I am a test", s.Properties["foo"].Description)
}
@ -81,7 +96,7 @@ func TestSchemaFormat(t *testing.T) {
Foo string `json:"foo" format:"date-time"`
}
s, err := GenerateSchema(reflect.ValueOf(Example{}).Type())
s, err := Generate(reflect.ValueOf(Example{}).Type())
assert.NoError(t, err)
assert.Equal(t, "date-time", s.Properties["foo"].Format)
}
@ -91,7 +106,7 @@ func TestSchemaEnum(t *testing.T) {
Foo string `json:"foo" enum:"one,two,three"`
}
s, err := GenerateSchema(reflect.ValueOf(Example{}).Type())
s, err := Generate(reflect.ValueOf(Example{}).Type())
assert.NoError(t, err)
assert.Equal(t, []interface{}{"one", "two", "three"}, s.Properties["foo"].Enum)
}
@ -101,7 +116,7 @@ func TestSchemaDefault(t *testing.T) {
Foo string `json:"foo" default:"def"`
}
s, err := GenerateSchema(reflect.ValueOf(Example{}).Type())
s, err := Generate(reflect.ValueOf(Example{}).Type())
assert.NoError(t, err)
assert.Equal(t, "def", s.Properties["foo"].Default)
}
@ -111,7 +126,7 @@ func TestSchemaExample(t *testing.T) {
Foo string `json:"foo" example:"ex"`
}
s, err := GenerateSchema(reflect.ValueOf(Example{}).Type())
s, err := Generate(reflect.ValueOf(Example{}).Type())
assert.NoError(t, err)
assert.Equal(t, "ex", s.Properties["foo"].Example)
}
@ -121,7 +136,7 @@ func TestSchemaNullable(t *testing.T) {
Foo string `json:"foo" nullable:"true"`
}
s, err := GenerateSchema(reflect.ValueOf(Example{}).Type())
s, err := Generate(reflect.ValueOf(Example{}).Type())
assert.NoError(t, err)
assert.Equal(t, true, s.Properties["foo"].Nullable)
}
@ -131,7 +146,7 @@ func TestSchemaNullableError(t *testing.T) {
Foo string `json:"foo" nullable:"bad"`
}
_, err := GenerateSchema(reflect.ValueOf(Example{}).Type())
_, err := Generate(reflect.ValueOf(Example{}).Type())
assert.Error(t, err)
}
@ -140,7 +155,7 @@ func TestSchemaReadOnly(t *testing.T) {
Foo string `json:"foo" readOnly:"true"`
}
s, err := GenerateSchema(reflect.ValueOf(Example{}).Type())
s, err := Generate(reflect.ValueOf(Example{}).Type())
assert.NoError(t, err)
assert.Equal(t, true, s.Properties["foo"].ReadOnly)
}
@ -150,7 +165,7 @@ func TestSchemaReadOnlyError(t *testing.T) {
Foo string `json:"foo" readOnly:"bad"`
}
_, err := GenerateSchema(reflect.ValueOf(Example{}).Type())
_, err := Generate(reflect.ValueOf(Example{}).Type())
assert.Error(t, err)
}
@ -159,7 +174,7 @@ func TestSchemaWriteOnly(t *testing.T) {
Foo string `json:"foo" writeOnly:"true"`
}
s, err := GenerateSchema(reflect.ValueOf(Example{}).Type())
s, err := Generate(reflect.ValueOf(Example{}).Type())
assert.NoError(t, err)
assert.Equal(t, true, s.Properties["foo"].WriteOnly)
}
@ -169,7 +184,7 @@ func TestSchemaWriteOnlyError(t *testing.T) {
Foo string `json:"foo" writeOnly:"bad"`
}
_, err := GenerateSchema(reflect.ValueOf(Example{}).Type())
_, err := Generate(reflect.ValueOf(Example{}).Type())
assert.Error(t, err)
}
@ -178,7 +193,7 @@ func TestSchemaDeprecated(t *testing.T) {
Foo string `json:"foo" deprecated:"true"`
}
s, err := GenerateSchema(reflect.ValueOf(Example{}).Type())
s, err := Generate(reflect.ValueOf(Example{}).Type())
assert.NoError(t, err)
assert.Equal(t, true, s.Properties["foo"].Deprecated)
}
@ -188,7 +203,7 @@ func TestSchemaDeprecatedError(t *testing.T) {
Foo string `json:"foo" deprecated:"bad"`
}
_, err := GenerateSchema(reflect.ValueOf(Example{}).Type())
_, err := Generate(reflect.ValueOf(Example{}).Type())
assert.Error(t, err)
}
@ -197,7 +212,7 @@ func TestSchemaMinimum(t *testing.T) {
Foo float64 `json:"foo" minimum:"1"`
}
s, err := GenerateSchema(reflect.ValueOf(Example{}).Type())
s, err := Generate(reflect.ValueOf(Example{}).Type())
assert.NoError(t, err)
assert.Equal(t, 1.0, *s.Properties["foo"].Minimum)
}
@ -207,7 +222,7 @@ func TestSchemaMinimumError(t *testing.T) {
Foo float64 `json:"foo" minimum:"bad"`
}
_, err := GenerateSchema(reflect.ValueOf(Example{}).Type())
_, err := Generate(reflect.ValueOf(Example{}).Type())
assert.Error(t, err)
}
@ -216,7 +231,7 @@ func TestSchemaExclusiveMinimum(t *testing.T) {
Foo float64 `json:"foo" exclusiveMinimum:"1"`
}
s, err := GenerateSchema(reflect.ValueOf(Example{}).Type())
s, err := Generate(reflect.ValueOf(Example{}).Type())
assert.NoError(t, err)
assert.Equal(t, 1.0, *s.Properties["foo"].ExclusiveMinimum)
}
@ -226,7 +241,7 @@ func TestSchemaExclusiveMinimumError(t *testing.T) {
Foo float64 `json:"foo" exclusiveMinimum:"bad"`
}
_, err := GenerateSchema(reflect.ValueOf(Example{}).Type())
_, err := Generate(reflect.ValueOf(Example{}).Type())
assert.Error(t, err)
}
@ -235,7 +250,7 @@ func TestSchemaMaximum(t *testing.T) {
Foo float64 `json:"foo" maximum:"0"`
}
s, err := GenerateSchema(reflect.ValueOf(Example{}).Type())
s, err := Generate(reflect.ValueOf(Example{}).Type())
assert.NoError(t, err)
assert.Equal(t, 0.0, *s.Properties["foo"].Maximum)
}
@ -245,7 +260,7 @@ func TestSchemaMaximumError(t *testing.T) {
Foo float64 `json:"foo" maximum:"bad"`
}
_, err := GenerateSchema(reflect.ValueOf(Example{}).Type())
_, err := Generate(reflect.ValueOf(Example{}).Type())
assert.Error(t, err)
}
@ -254,7 +269,7 @@ func TestSchemaExclusiveMaximum(t *testing.T) {
Foo float64 `json:"foo" exclusiveMaximum:"0"`
}
s, err := GenerateSchema(reflect.ValueOf(Example{}).Type())
s, err := Generate(reflect.ValueOf(Example{}).Type())
assert.NoError(t, err)
assert.Equal(t, 0.0, *s.Properties["foo"].ExclusiveMaximum)
}
@ -264,7 +279,7 @@ func TestSchemaExclusiveMaximumError(t *testing.T) {
Foo float64 `json:"foo" exclusiveMaximum:"bad"`
}
_, err := GenerateSchema(reflect.ValueOf(Example{}).Type())
_, err := Generate(reflect.ValueOf(Example{}).Type())
assert.Error(t, err)
}
@ -273,7 +288,7 @@ func TestSchemaMultipleOf(t *testing.T) {
Foo float64 `json:"foo" multipleOf:"10"`
}
s, err := GenerateSchema(reflect.ValueOf(Example{}).Type())
s, err := Generate(reflect.ValueOf(Example{}).Type())
assert.NoError(t, err)
assert.Equal(t, 10.0, s.Properties["foo"].MultipleOf)
}
@ -283,7 +298,7 @@ func TestSchemaMultipleOfError(t *testing.T) {
Foo float64 `json:"foo" multipleOf:"bad"`
}
_, err := GenerateSchema(reflect.ValueOf(Example{}).Type())
_, err := Generate(reflect.ValueOf(Example{}).Type())
assert.Error(t, err)
}
@ -292,7 +307,7 @@ func TestSchemaMinLength(t *testing.T) {
Foo string `json:"foo" minLength:"10"`
}
s, err := GenerateSchema(reflect.ValueOf(Example{}).Type())
s, err := Generate(reflect.ValueOf(Example{}).Type())
assert.NoError(t, err)
assert.Equal(t, uint64(10), *s.Properties["foo"].MinLength)
}
@ -302,7 +317,7 @@ func TestSchemaMinLengthError(t *testing.T) {
Foo string `json:"foo" minLength:"bad"`
}
_, err := GenerateSchema(reflect.ValueOf(Example{}).Type())
_, err := Generate(reflect.ValueOf(Example{}).Type())
assert.Error(t, err)
}
@ -311,7 +326,7 @@ func TestSchemaMaxLength(t *testing.T) {
Foo string `json:"foo" maxLength:"10"`
}
s, err := GenerateSchema(reflect.ValueOf(Example{}).Type())
s, err := Generate(reflect.ValueOf(Example{}).Type())
assert.NoError(t, err)
assert.Equal(t, uint64(10), *s.Properties["foo"].MaxLength)
}
@ -321,7 +336,7 @@ func TestSchemaMaxLengthError(t *testing.T) {
Foo string `json:"foo" maxLength:"bad"`
}
_, err := GenerateSchema(reflect.ValueOf(Example{}).Type())
_, err := Generate(reflect.ValueOf(Example{}).Type())
assert.Error(t, err)
}
@ -330,7 +345,7 @@ func TestSchemaPattern(t *testing.T) {
Foo string `json:"foo" pattern:"a-z+"`
}
s, err := GenerateSchema(reflect.ValueOf(Example{}).Type())
s, err := Generate(reflect.ValueOf(Example{}).Type())
assert.NoError(t, err)
assert.Equal(t, "a-z+", s.Properties["foo"].Pattern)
}
@ -340,7 +355,7 @@ func TestSchemaPatternError(t *testing.T) {
Foo string `json:"foo" pattern:"(.*"`
}
_, err := GenerateSchema(reflect.ValueOf(Example{}).Type())
_, err := Generate(reflect.ValueOf(Example{}).Type())
assert.Error(t, err)
}
@ -349,7 +364,7 @@ func TestSchemaMinItems(t *testing.T) {
Foo []string `json:"foo" minItems:"10"`
}
s, err := GenerateSchema(reflect.ValueOf(Example{}).Type())
s, err := Generate(reflect.ValueOf(Example{}).Type())
assert.NoError(t, err)
assert.Equal(t, uint64(10), *s.Properties["foo"].MinItems)
}
@ -359,7 +374,7 @@ func TestSchemaMinItemsError(t *testing.T) {
Foo []string `json:"foo" minItems:"bad"`
}
_, err := GenerateSchema(reflect.ValueOf(Example{}).Type())
_, err := Generate(reflect.ValueOf(Example{}).Type())
assert.Error(t, err)
}
@ -368,7 +383,7 @@ func TestSchemaMaxItems(t *testing.T) {
Foo []string `json:"foo" maxItems:"10"`
}
s, err := GenerateSchema(reflect.ValueOf(Example{}).Type())
s, err := Generate(reflect.ValueOf(Example{}).Type())
assert.NoError(t, err)
assert.Equal(t, uint64(10), *s.Properties["foo"].MaxItems)
}
@ -378,7 +393,7 @@ func TestSchemaMaxItemsError(t *testing.T) {
Foo []string `json:"foo" maxItems:"bad"`
}
_, err := GenerateSchema(reflect.ValueOf(Example{}).Type())
_, err := Generate(reflect.ValueOf(Example{}).Type())
assert.Error(t, err)
}
@ -387,7 +402,7 @@ func TestSchemaUniqueItems(t *testing.T) {
Foo []string `json:"foo" uniqueItems:"true"`
}
s, err := GenerateSchema(reflect.ValueOf(Example{}).Type())
s, err := Generate(reflect.ValueOf(Example{}).Type())
assert.NoError(t, err)
assert.Equal(t, true, s.Properties["foo"].UniqueItems)
}
@ -397,7 +412,7 @@ func TestSchemaUniqueItemsError(t *testing.T) {
Foo []string `json:"foo" uniqueItems:"bad"`
}
_, err := GenerateSchema(reflect.ValueOf(Example{}).Type())
_, err := Generate(reflect.ValueOf(Example{}).Type())
assert.Error(t, err)
}
@ -406,7 +421,7 @@ func TestSchemaMinProperties(t *testing.T) {
Foo []string `json:"foo" minProperties:"10"`
}
s, err := GenerateSchema(reflect.ValueOf(Example{}).Type())
s, err := Generate(reflect.ValueOf(Example{}).Type())
assert.NoError(t, err)
assert.Equal(t, uint64(10), *s.Properties["foo"].MinProperties)
}
@ -416,7 +431,7 @@ func TestSchemaMinPropertiesError(t *testing.T) {
Foo []string `json:"foo" minProperties:"bad"`
}
_, err := GenerateSchema(reflect.ValueOf(Example{}).Type())
_, err := Generate(reflect.ValueOf(Example{}).Type())
assert.Error(t, err)
}
@ -425,7 +440,7 @@ func TestSchemaMaxProperties(t *testing.T) {
Foo []string `json:"foo" maxProperties:"10"`
}
s, err := GenerateSchema(reflect.ValueOf(Example{}).Type())
s, err := Generate(reflect.ValueOf(Example{}).Type())
assert.NoError(t, err)
assert.Equal(t, uint64(10), *s.Properties["foo"].MaxProperties)
}
@ -435,12 +450,12 @@ func TestSchemaMaxPropertiesError(t *testing.T) {
Foo []string `json:"foo" maxProperties:"bad"`
}
_, err := GenerateSchema(reflect.ValueOf(Example{}).Type())
_, err := Generate(reflect.ValueOf(Example{}).Type())
assert.Error(t, err)
}
func TestSchemaMap(t *testing.T) {
s, err := GenerateSchema(reflect.TypeOf(map[string]string{}))
s, err := Generate(reflect.TypeOf(map[string]string{}))
assert.NoError(t, err)
assert.Equal(t, &Schema{
Type: "object",
@ -451,7 +466,7 @@ func TestSchemaMap(t *testing.T) {
}
func TestSchemaSlice(t *testing.T) {
s, err := GenerateSchema(reflect.TypeOf([]string{}))
s, err := Generate(reflect.TypeOf([]string{}))
assert.NoError(t, err)
assert.Equal(t, &Schema{
Type: "array",
@ -462,7 +477,7 @@ func TestSchemaSlice(t *testing.T) {
}
func TestSchemaUnsigned(t *testing.T) {
s, err := GenerateSchema(reflect.TypeOf(uint(10)))
s, err := Generate(reflect.TypeOf(uint(10)))
assert.NoError(t, err)
min := 0.0
assert.Equal(t, &Schema{
@ -477,7 +492,7 @@ func TestSchemaNonStringExample(t *testing.T) {
Foo uint32 `json:"foo" example:"10"`
}
s, err := GenerateSchema(reflect.ValueOf(Example{}).Type())
s, err := Generate(reflect.ValueOf(Example{}).Type())
assert.NoError(t, err)
assert.Equal(t, uint32(10), s.Properties["foo"].Example)
}
@ -487,7 +502,7 @@ func TestSchemaNonStringExampleErrorUnmarshal(t *testing.T) {
Foo uint32 `json:"foo" example:"bad"`
}
_, err := GenerateSchema(reflect.ValueOf(Example{}).Type())
_, err := Generate(reflect.ValueOf(Example{}).Type())
assert.Error(t, err)
}
@ -496,7 +511,7 @@ func TestSchemaNonStringExampleErrorCast(t *testing.T) {
Foo bool `json:"foo" example:"1"`
}
_, err := GenerateSchema(reflect.ValueOf(Example{}).Type())
_, err := Generate(reflect.ValueOf(Example{}).Type())
assert.Error(t, err)
}

View file

@ -8,6 +8,7 @@ import (
"regexp"
"strings"
"github.com/danielgtaylor/huma/schema"
"github.com/gosimple/slug"
)
@ -37,7 +38,7 @@ func (a *OpenAPI) validate() error {
}
// validate the parameter and generate schemas
func (p *Param) validate(t reflect.Type) {
func (p *OpenAPIParam) validate(t reflect.Type) {
switch p.In {
case InPath, InQuery, InHeader:
default:
@ -65,7 +66,7 @@ func (p *Param) validate(t reflect.Type) {
p.typ = t
if p.Schema == nil || p.Schema.Type == "" {
s, err := GenerateSchemaWithMode(p.typ, SchemaModeWrite, p.Schema)
s, err := schema.GenerateWithMode(p.typ, schema.ModeWrite, p.Schema)
if err != nil {
panic(fmt.Errorf("parameter %s schema generation error: %w", p.Name, err))
}
@ -84,10 +85,10 @@ func (p *Param) validate(t reflect.Type) {
}
// validate the header and generate schemas
func (h *ResponseHeader) validate(t reflect.Type) {
func (h *OpenAPIResponseHeader) validate(t reflect.Type) {
if h.Schema == nil {
// Generate the schema from the handler function types.
s, err := GenerateSchemaWithMode(t, SchemaModeRead, nil)
s, err := schema.GenerateWithMode(t, schema.ModeRead, nil)
if err != nil {
panic(fmt.Errorf("response header %s schema generation error: %w", h.Name, err))
}
@ -97,39 +98,39 @@ func (h *ResponseHeader) validate(t reflect.Type) {
// validate checks that the operation is well-formed (e.g. handler signature
// matches the given params) and generates schemas if needed.
func (o *Operation) validate(method, path string) {
func (o *OpenAPIOperation) validate(method, path string) {
prefix := method + " " + path + ":"
if o.Description == "" {
if o.description == "" {
panic(fmt.Errorf("%s description field required: %w", prefix, ErrOperationInvalid))
}
if len(o.Responses) == 0 {
if len(o.responses) == 0 {
panic(fmt.Errorf("%s at least one response is required: %w", prefix, ErrOperationInvalid))
}
if o.Handler == nil {
if o.handler == nil {
panic(fmt.Errorf("%s handler is required: %w", prefix, ErrOperationInvalid))
}
handler := reflect.ValueOf(o.Handler).Type()
handler := reflect.ValueOf(o.handler).Type()
totalIn := len(o.Dependencies) + len(o.Params)
totalOut := len(o.ResponseHeaders) + len(o.Responses)
totalIn := len(o.dependencies) + len(o.params)
totalOut := len(o.responseHeaders) + len(o.responses)
if !(handler.NumIn() == totalIn || (method != http.MethodGet && handler.NumIn() == totalIn+1)) || handler.NumOut() != totalOut {
expected := "func("
for _, dep := range o.Dependencies {
expected += "? " + reflect.ValueOf(dep.Value).Type().String() + ", "
for _, dep := range o.dependencies {
expected += "? " + reflect.ValueOf(dep.handler).Type().String() + ", "
}
for _, param := range o.Params {
for _, param := range o.params {
expected += param.Name + " ?, "
}
expected = strings.TrimRight(expected, ", ")
expected += ") ("
for _, h := range o.ResponseHeaders {
for _, h := range o.responseHeaders {
expected += h.Name + " ?, "
}
for _, r := range o.Responses {
for _, r := range o.responses {
expected += fmt.Sprintf("*Response%d, ", r.StatusCode)
}
expected = strings.TrimRight(expected, ", ")
@ -138,7 +139,7 @@ func (o *Operation) validate(method, path string) {
panic(fmt.Errorf("%s expected handler %s but found %s: %w", prefix, expected, handler, ErrOperationInvalid))
}
if o.ID == "" {
if o.id == "" {
verb := method
// Try to detect calls returning lists of things.
@ -152,10 +153,10 @@ func (o *Operation) validate(method, path string) {
// Remove variables from path so they aren't in the generated name.
path := paramRe.ReplaceAllString(path, "")
o.ID = slug.Make(verb + path)
o.id = slug.Make(verb + path)
}
for i, dep := range o.Dependencies {
for i, dep := range o.dependencies {
paramType := handler.In(i)
// Catch common errors.
@ -163,7 +164,7 @@ func (o *Operation) validate(method, path string) {
panic(fmt.Errorf("%s gin.Context should be pointer *gin.Context: %w", prefix, ErrOperationInvalid))
}
if paramType.String() == "huma.Operation" {
if paramType.String() == "huma.OpenAPIOperation" {
panic(fmt.Errorf("%s huma.Operation should be pointer *huma.Operation: %w", prefix, ErrOperationInvalid))
}
@ -171,13 +172,13 @@ func (o *Operation) validate(method, path string) {
}
types := []reflect.Type{}
for i := len(o.Dependencies); i < handler.NumIn(); i++ {
for i := len(o.dependencies); i < handler.NumIn(); i++ {
paramType := handler.In(i)
switch paramType.String() {
case "gin.Context", "*gin.Context":
panic(fmt.Errorf("%s expected param but found gin.Context: %w", prefix, ErrOperationInvalid))
case "huma.Operation", "*huma.Operation":
case "huma.Operation", "*huma.OpenAPIOperation":
panic(fmt.Errorf("%s expected param but found huma.Operation: %w", prefix, ErrOperationInvalid))
}
@ -185,37 +186,38 @@ func (o *Operation) validate(method, path string) {
}
requestBody := false
if len(types) == len(o.Params)+1 {
if len(types) == len(o.params)+1 {
requestBody = true
}
for i, paramType := range types {
if i == len(types)-1 && requestBody {
// The last item has no associated param. It is a request body.
if o.RequestSchema == nil {
s, err := GenerateSchemaWithMode(paramType, SchemaModeWrite, nil)
if o.requestSchema == nil {
s, err := schema.GenerateWithMode(paramType, schema.ModeWrite, nil)
if err != nil {
panic(fmt.Errorf("%s request body schema generation error: %w", prefix, err))
}
o.RequestSchema = s
o.requestSchema = s
}
continue
}
p := o.Params[i]
p := o.params[i]
p.validate(paramType)
}
for i, header := range o.ResponseHeaders {
for i, header := range o.responseHeaders {
header.validate(handler.Out(i))
}
for i, resp := range o.Responses {
respType := handler.Out(len(o.ResponseHeaders) + i)
// HTTP 204 explicitly forbids a response body.
if !resp.empty && resp.Schema == nil {
for i, resp := range o.responses {
respType := handler.Out(len(o.responseHeaders) + i)
// HTTP 204 explicitly forbids a response body. We model this with an
// empty content type.
if resp.ContentType != "" && resp.Schema == nil {
// Generate the schema from the handler function types.
s, err := GenerateSchemaWithMode(respType, SchemaModeRead, nil)
s, err := schema.GenerateWithMode(respType, schema.ModeRead, nil)
if err != nil {
panic(fmt.Errorf("%s response %d schema generation error: %w", prefix, resp.StatusCode, err))
}

View file

@ -12,7 +12,7 @@ func TestOperationDescriptionRequired(t *testing.T) {
r := NewTestRouter(t)
assert.Panics(t, func() {
r.Register(http.MethodGet, "/", &Operation{})
r.Register(http.MethodGet, "/", &OpenAPIOperation{})
})
}
@ -20,21 +20,8 @@ func TestOperationResponseRequired(t *testing.T) {
r := NewTestRouter(t)
assert.Panics(t, func() {
r.Register(http.MethodGet, "/", &Operation{
Description: "Test",
})
})
}
func TestOperationHandlerMissing(t *testing.T) {
r := NewTestRouter(t)
assert.Panics(t, func() {
r.Register(http.MethodGet, "/", &Operation{
Description: "Test",
Responses: []*Response{
ResponseText(200, "Test"),
},
r.Register(http.MethodGet, "/", &OpenAPIOperation{
description: "Test",
})
})
}
@ -42,26 +29,13 @@ func TestOperationHandlerMissing(t *testing.T) {
func TestOperationHandlerInput(t *testing.T) {
r := NewTestRouter(t)
d := &Dependency{
Value: func() (string, error) {
return "test", nil
},
}
assert.Panics(t, func() {
r.Register(http.MethodGet, "/", &Operation{
Description: "Test",
Dependencies: []*Dependency{d},
Params: []*Param{
QueryParam("foo", "Test", ""),
},
Responses: []*Response{
ResponseText(200, "Test"),
},
Handler: func() string {
// Wrong number of inputs!
return "fails"
},
r.Resource("/",
SimpleDependency("test"),
ResponseText(200, "Test"),
).Get("Test", func() string {
// Wrong number of inputs!
return "fails"
})
})
}
@ -70,18 +44,12 @@ func TestOperationHandlerOutput(t *testing.T) {
r := NewTestRouter(t)
assert.Panics(t, func() {
r.Register(http.MethodGet, "/", &Operation{
Description: "Test",
ResponseHeaders: []*ResponseHeader{
Header("x-test", "Test"),
},
Responses: []*Response{
ResponseText(200, "Test", "x-test"),
},
Handler: func() string {
// Wrong number of outputs!
return "fails"
},
r.Resource("/",
ResponseHeader("x-test", "Test"),
ResponseText(200, "Test", Headers("x-test")),
).Get("Test", func() string {
// Wrong number of outputs!
return "fails"
})
})
}
@ -89,36 +57,23 @@ func TestOperationHandlerOutput(t *testing.T) {
func TestOperationListAutoID(t *testing.T) {
r := NewTestRouter(t)
o := &Operation{
Description: "Test",
Responses: []*Response{
ResponseJSON(200, "Test"),
},
Handler: func() []string {
return []string{"test"}
},
}
r.Resource("/items").Get("Test", func() []string {
return []string{"test"}
})
r.Register(http.MethodGet, "/items", o)
o := r.OpenAPI().Paths["/items"][http.MethodGet]
assert.Equal(t, "list-items", o.ID)
assert.Equal(t, "list-items", o.id)
}
func TestOperationContextPointer(t *testing.T) {
r := NewTestRouter(t)
assert.Panics(t, func() {
r.Register(http.MethodGet, "/", &Operation{
Description: "Test",
Dependencies: []*Dependency{
ContextDependency(),
},
Responses: []*Response{
ResponseText(200, "Test"),
},
Handler: func(c gin.Context) string {
return "test"
},
r.Resource("/",
GinContextDependency(),
).Get("Test", func(c gin.Context) string {
return "test"
})
})
}
@ -127,17 +82,10 @@ func TestOperationOperationPointer(t *testing.T) {
r := NewTestRouter(t)
assert.Panics(t, func() {
r.Register(http.MethodGet, "/", &Operation{
Description: "Test",
Dependencies: []*Dependency{
OperationDependency(),
},
Responses: []*Response{
ResponseText(200, "Test"),
},
Handler: func(o Operation) string {
return "test"
},
r.Resource("/",
OperationDependency(),
).Get("Test", func(o OpenAPIOperation) string {
return "test"
})
})
}
@ -146,17 +94,10 @@ func TestOperationInvalidDep(t *testing.T) {
r := NewTestRouter(t)
assert.Panics(t, func() {
r.Register(http.MethodGet, "/", &Operation{
Description: "Test",
Dependencies: []*Dependency{
&Dependency{},
},
Responses: []*Response{
ResponseText(200, "Test"),
},
Handler: func(string) string {
return "test"
},
r.Resource("/",
SimpleDependency(nil),
).Get("Test", func(o OpenAPIOperation) string {
return "test"
})
})
}
@ -165,32 +106,18 @@ func TestOperationParamDep(t *testing.T) {
r := NewTestRouter(t)
assert.Panics(t, func() {
r.Register(http.MethodGet, "/", &Operation{
Description: "Test",
Params: []*Param{
QueryParam("foo", "Test", ""),
},
Responses: []*Response{
ResponseText(200, "Test"),
},
Handler: func(c *gin.Context) string {
return "test"
},
r.Resource("/",
QueryParam("foo", "Test", ""),
).Get("Test", func(c *gin.Context) string {
return "test"
})
})
assert.Panics(t, func() {
r.Register(http.MethodGet, "/", &Operation{
Description: "Test",
Params: []*Param{
QueryParam("foo", "Test", ""),
},
Responses: []*Response{
ResponseText(200, "Test"),
},
Handler: func(o *Operation) string {
return "test"
},
r.Resource("/",
QueryParam("foo", "Test", ""),
).Get("Test", func(c *OpenAPIOperation) string {
return "test"
})
})
}
@ -198,31 +125,13 @@ func TestOperationParamDep(t *testing.T) {
func TestOperationParamRedeclare(t *testing.T) {
r := NewTestRouter(t)
p := QueryParam("foo", "Test", 0)
param := QueryParam("foo", "Test", 0)
r.Register(http.MethodGet, "/", &Operation{
Description: "Test",
Params: []*Param{p},
Responses: []*Response{
ResponseText(200, "Test"),
},
Handler: func(p int) string {
return "test"
},
})
r.Resource("/a", param).Get("Test", func(p int) string { return "a" })
// Param p was declared as `int` above but is `string` here.
// Redeclare param `p` as a string while it was an int above.
assert.Panics(t, func() {
r.Register(http.MethodGet, "/", &Operation{
Description: "Test",
Params: []*Param{p},
Responses: []*Response{
ResponseText(200, "Test"),
},
Handler: func(p string) string {
return "test"
},
})
r.Resource("/b", param).Get("Test", func(p string) string { return "b" })
})
}
@ -230,17 +139,10 @@ func TestOperationParamExampleType(t *testing.T) {
r := NewTestRouter(t)
assert.Panics(t, func() {
r.Register(http.MethodGet, "/", &Operation{
Description: "Test",
Params: []*Param{
QueryParamExample("foo", "Test", "", 123),
},
Responses: []*Response{
ResponseText(200, "Test"),
},
Handler: func(p string) string {
return "test"
},
r.Resource("/",
QueryParam("foo", "Test", "", Example(123)),
).Get("Test", func(p string) string {
return "test"
})
})
}
@ -248,20 +150,13 @@ func TestOperationParamExampleType(t *testing.T) {
func TestOperationParamExampleSchema(t *testing.T) {
r := NewTestRouter(t)
p := QueryParamExample("foo", "Test", 0, 123)
p := QueryParam("foo", "Test", 0, Example(123))
r.Register(http.MethodGet, "/", &Operation{
Description: "Test",
Params: []*Param{
p,
},
Responses: []*Response{
ResponseText(200, "Test"),
},
Handler: func(p int) string {
return "test"
},
r.Resource("/", p).Get("Test", func(p int) string {
return "test"
})
assert.Equal(t, 123, p.Schema.Example)
param := r.OpenAPI().Paths["/"][http.MethodGet].params[0]
assert.Equal(t, 123, param.Schema.Example)
}