mirror of
https://github.com/Fishwaldo/huma.git
synced 2025-03-15 19:31:27 +00:00
feat: param validations, better error logging
This commit is contained in:
parent
01490e416a
commit
52c5742da9
8 changed files with 173 additions and 95 deletions
|
@ -185,9 +185,9 @@ func (d *Dependency) Resolve(c *gin.Context, op *Operation) (map[string]string,
|
|||
|
||||
// Get each input parameter
|
||||
for _, param := range d.Params {
|
||||
v, err := getParamValue(c, param)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
v, ok := getParamValue(c, param)
|
||||
if !ok {
|
||||
return nil, nil, fmt.Errorf("could not get param value")
|
||||
}
|
||||
|
||||
in = append(in, reflect.ValueOf(v))
|
||||
|
|
|
@ -57,10 +57,18 @@ func LogMiddleware(l *zap.Logger, tags map[string]string) func(*gin.Context) {
|
|||
|
||||
c.Next()
|
||||
|
||||
contextLog.Debug("Request",
|
||||
contextLog = contextLog.With(
|
||||
zap.Int("status", c.Writer.Status()),
|
||||
zap.Duration("duration", time.Since(start)),
|
||||
)
|
||||
|
||||
if len(c.Errors) > 0 {
|
||||
for _, e := range c.Errors {
|
||||
contextLog.Error("Error", zap.Error(e.Err))
|
||||
}
|
||||
}
|
||||
|
||||
contextLog.Debug("Request")
|
||||
}
|
||||
}
|
||||
|
||||
|
|
57
openapi.go
57
openapi.go
|
@ -28,45 +28,47 @@ type Param struct {
|
|||
}
|
||||
|
||||
// PathParam returns a new required path parameter
|
||||
func PathParam(name string, description string) *Param {
|
||||
return &Param{
|
||||
Name: name,
|
||||
Description: description,
|
||||
In: "path",
|
||||
Required: true,
|
||||
}
|
||||
func PathParam(name string, description string, schema ...*Schema) *Param {
|
||||
return PathParamExample(name, description, nil, schema...)
|
||||
}
|
||||
|
||||
// PathParamExample returns a new required path parameter with example
|
||||
func PathParamExample(name string, description string, example interface{}) *Param {
|
||||
return &Param{
|
||||
func PathParamExample(name string, description string, example interface{}, schema ...*Schema) *Param {
|
||||
p := &Param{
|
||||
Name: name,
|
||||
Description: description,
|
||||
In: "path",
|
||||
Required: true,
|
||||
Example: example,
|
||||
}
|
||||
|
||||
if len(schema) > 0 {
|
||||
p.Schema = schema[0]
|
||||
}
|
||||
|
||||
return p
|
||||
}
|
||||
|
||||
// QueryParam returns a new optional query string parameter
|
||||
func QueryParam(name string, description string, defaultValue interface{}) *Param {
|
||||
return &Param{
|
||||
Name: name,
|
||||
Description: description,
|
||||
In: "query",
|
||||
def: defaultValue,
|
||||
}
|
||||
func QueryParam(name string, description string, defaultValue interface{}, schema ...*Schema) *Param {
|
||||
return QueryParamExample(name, description, defaultValue, nil, schema...)
|
||||
}
|
||||
|
||||
// QueryParamExample returns a new optional query string parameter with example
|
||||
func QueryParamExample(name string, description string, defaultValue interface{}, example interface{}) *Param {
|
||||
return &Param{
|
||||
func QueryParamExample(name string, description string, defaultValue interface{}, example interface{}, schema ...*Schema) *Param {
|
||||
p := &Param{
|
||||
Name: name,
|
||||
Description: description,
|
||||
In: "query",
|
||||
Example: example,
|
||||
def: defaultValue,
|
||||
}
|
||||
|
||||
if len(schema) > 0 {
|
||||
p.Schema = schema[0]
|
||||
}
|
||||
|
||||
return p
|
||||
}
|
||||
|
||||
// QueryParamInternal returns a new optional internal query string parameter
|
||||
|
@ -81,24 +83,25 @@ func QueryParamInternal(name string, description string, defaultValue interface{
|
|||
}
|
||||
|
||||
// HeaderParam returns a new optional header parameter
|
||||
func HeaderParam(name string, description string, defaultValue interface{}) *Param {
|
||||
return &Param{
|
||||
Name: name,
|
||||
Description: description,
|
||||
In: "header",
|
||||
def: defaultValue,
|
||||
}
|
||||
func HeaderParam(name string, description string, defaultValue interface{}, schema ...*Schema) *Param {
|
||||
return HeaderParamExample(name, description, defaultValue, nil, schema...)
|
||||
}
|
||||
|
||||
// HeaderParamExample returns a new optional header parameter with example
|
||||
func HeaderParamExample(name string, description string, defaultValue interface{}, example interface{}) *Param {
|
||||
return &Param{
|
||||
func HeaderParamExample(name string, description string, defaultValue interface{}, example interface{}, schema ...*Schema) *Param {
|
||||
p := &Param{
|
||||
Name: name,
|
||||
Description: description,
|
||||
In: "header",
|
||||
Example: example,
|
||||
def: defaultValue,
|
||||
}
|
||||
|
||||
if len(schema) > 0 {
|
||||
p.Schema = schema[0]
|
||||
}
|
||||
|
||||
return p
|
||||
}
|
||||
|
||||
// HeaderParamInternal returns a new optional internal header parameter
|
||||
|
|
|
@ -23,11 +23,14 @@ var paramFuncsTable = []struct {
|
|||
example interface{}
|
||||
}{
|
||||
{"PathParam", PathParam("test", "desc"), "test", "desc", "path", true, false, nil, nil},
|
||||
{"PathParamSchema", PathParam("test", "desc", &Schema{}), "test", "desc", "path", true, false, nil, nil},
|
||||
{"PathParamExample", PathParamExample("test", "desc", 123), "test", "desc", "path", true, false, nil, 123},
|
||||
{"QueryParam", QueryParam("test", "desc", "def"), "test", "desc", "query", false, false, "def", nil},
|
||||
{"QueryParamSchema", QueryParam("test", "desc", "def", &Schema{}), "test", "desc", "query", false, false, "def", nil},
|
||||
{"QueryParamExample", QueryParamExample("test", "desc", "def", "foo"), "test", "desc", "query", false, false, "def", "foo"},
|
||||
{"QueryParamInternal", QueryParamInternal("test", "desc", "def"), "test", "desc", "query", false, true, "def", nil},
|
||||
{"HeaderParam", HeaderParam("test", "desc", "def"), "test", "desc", "header", false, false, "def", nil},
|
||||
{"HeaderParamSchema", HeaderParam("test", "desc", "def", &Schema{}), "test", "desc", "header", false, false, "def", nil},
|
||||
{"HeaderParamExample", HeaderParamExample("test", "desc", "def", "foo"), "test", "desc", "header", false, false, "def", "foo"},
|
||||
{"HeaderParamInternal", HeaderParamInternal("test", "desc", "def"), "test", "desc", "header", false, true, "def", nil},
|
||||
}
|
||||
|
|
115
router.go
115
router.go
|
@ -21,7 +21,35 @@ import (
|
|||
// is not a valid value.
|
||||
var ErrInvalidParamLocation = errors.New("invalid parameter location")
|
||||
|
||||
func getParamValue(c *gin.Context, param *Param) (interface{}, error) {
|
||||
// Checks if data validates against the given schema. Returns false on failure.
|
||||
func validAgainstSchema(c *gin.Context, schema *Schema, data []byte) bool {
|
||||
loader := gojsonschema.NewGoLoader(schema)
|
||||
doc := gojsonschema.NewBytesLoader(data)
|
||||
s, err := gojsonschema.NewSchema(loader)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
result, err := s.Validate(doc)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if !result.Valid() {
|
||||
errors := []string{}
|
||||
for _, desc := range result.Errors() {
|
||||
errors = append(errors, fmt.Sprintf("%s", desc))
|
||||
}
|
||||
c.AbortWithStatusJSON(400, &ErrorInvalidModel{
|
||||
Message: "Invalid input",
|
||||
Errors: errors,
|
||||
})
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func getParamValue(c *gin.Context, param *Param) (interface{}, bool) {
|
||||
var pstr string
|
||||
switch param.In {
|
||||
case "path":
|
||||
|
@ -29,20 +57,33 @@ func getParamValue(c *gin.Context, param *Param) (interface{}, error) {
|
|||
case "query":
|
||||
pstr = c.Query(param.Name)
|
||||
if pstr == "" {
|
||||
return param.def, nil
|
||||
return param.def, true
|
||||
}
|
||||
case "header":
|
||||
pstr = c.GetHeader(param.Name)
|
||||
if pstr == "" {
|
||||
return param.def, nil
|
||||
return param.def, true
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("%s: %w", param.In, ErrInvalidParamLocation)
|
||||
panic(fmt.Errorf("%s: %w", param.In, ErrInvalidParamLocation))
|
||||
}
|
||||
|
||||
if pstr == "" && !param.Required {
|
||||
// Optional and not passed, so set it to its zero value.
|
||||
return reflect.New(param.typ).Elem().Interface(), nil
|
||||
return reflect.New(param.typ).Elem().Interface(), true
|
||||
}
|
||||
|
||||
if param.Schema.HasValidation() {
|
||||
data := pstr
|
||||
if param.Schema.Type == "string" {
|
||||
// Strings are special in that we don't expect users to provide them
|
||||
// with quotes, so wrap them here for the parser that does the
|
||||
// validation step below.
|
||||
data = `"` + data + `"`
|
||||
}
|
||||
if !validAgainstSchema(c, param.Schema, []byte(data)) {
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
|
||||
var pv interface{}
|
||||
|
@ -50,33 +91,45 @@ func getParamValue(c *gin.Context, param *Param) (interface{}, error) {
|
|||
case reflect.Bool:
|
||||
converted, err := strconv.ParseBool(pstr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, &ErrorModel{
|
||||
Message: fmt.Sprintf("cannot parse boolean for param %s: %s", param.Name, pstr),
|
||||
})
|
||||
return nil, false
|
||||
}
|
||||
pv = converted
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
|
||||
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
converted, err := strconv.Atoi(pstr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, &ErrorModel{
|
||||
Message: fmt.Sprintf("cannot parse integer for param %s: %s", param.Name, pstr),
|
||||
})
|
||||
return nil, false
|
||||
}
|
||||
pv = reflect.ValueOf(converted).Convert(param.typ).Interface()
|
||||
case reflect.Float32:
|
||||
converted, err := strconv.ParseFloat(pstr, 32)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, &ErrorModel{
|
||||
Message: fmt.Sprintf("cannot parse float for param %s: %s", param.Name, pstr),
|
||||
})
|
||||
return nil, false
|
||||
}
|
||||
pv = converted
|
||||
case reflect.Float64:
|
||||
converted, err := strconv.ParseFloat(pstr, 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, &ErrorModel{
|
||||
Message: fmt.Sprintf("cannot parse float for param %s: %s", param.Name, pstr),
|
||||
})
|
||||
return nil, false
|
||||
}
|
||||
pv = converted
|
||||
default:
|
||||
pv = pstr
|
||||
}
|
||||
|
||||
return pv, nil
|
||||
return pv, true
|
||||
}
|
||||
|
||||
func getRequestBody(c *gin.Context, t reflect.Type, op *Operation) (interface{}, bool) {
|
||||
|
@ -90,28 +143,7 @@ func getRequestBody(c *gin.Context, t reflect.Type, op *Operation) (interface{},
|
|||
|
||||
c.Request.Body = ioutil.NopCloser(bytes.NewBuffer(body))
|
||||
|
||||
loader := gojsonschema.NewGoLoader(op.RequestSchema)
|
||||
doc := gojsonschema.NewBytesLoader(body)
|
||||
s, err := gojsonschema.NewSchema(loader)
|
||||
if err != nil {
|
||||
c.AbortWithError(500, err)
|
||||
return nil, false
|
||||
}
|
||||
result, err := s.Validate(doc)
|
||||
if err != nil {
|
||||
c.AbortWithError(500, err)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
if !result.Valid() {
|
||||
errors := []string{}
|
||||
for _, desc := range result.Errors() {
|
||||
errors = append(errors, fmt.Sprintf("%s", desc))
|
||||
}
|
||||
c.AbortWithStatusJSON(400, &ErrorInvalidModel{
|
||||
Message: "Invalid input",
|
||||
Errors: errors,
|
||||
})
|
||||
if !validAgainstSchema(c, op.RequestSchema, body) {
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
|
@ -255,10 +287,14 @@ func (r *Router) Register(op *Operation) {
|
|||
headers, value, err := dep.Resolve(c, op)
|
||||
if err != nil {
|
||||
// TODO: better error handling
|
||||
c.AbortWithStatusJSON(500, ErrorModel{
|
||||
Message: "Couldn't get dependency",
|
||||
//Errors: []error{err},
|
||||
})
|
||||
if !c.IsAborted() {
|
||||
// Nothing else has handled the error, so treat it like a general
|
||||
// internal server error.
|
||||
c.AbortWithStatusJSON(500, &ErrorModel{
|
||||
Message: "Couldn't get dependency",
|
||||
//Errors: []error{err},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
for k, v := range headers {
|
||||
|
@ -269,10 +305,9 @@ func (r *Router) Register(op *Operation) {
|
|||
}
|
||||
|
||||
for _, param := range op.Params {
|
||||
pv, err := getParamValue(c, param)
|
||||
if err != nil {
|
||||
// TODO expose error to user
|
||||
c.AbortWithError(400, err)
|
||||
pv, ok := getParamValue(c, param)
|
||||
if !ok {
|
||||
// Error has already been handled.
|
||||
return
|
||||
}
|
||||
|
||||
|
|
59
schema.go
59
schema.go
|
@ -37,6 +37,18 @@ var (
|
|||
byteSliceType = reflect.TypeOf([]byte(nil))
|
||||
)
|
||||
|
||||
// I returns a pointer to the given int. Useful helper function for pointer
|
||||
// schema validators like MaxLength or MinItems.
|
||||
func I(value uint64) *uint64 {
|
||||
return &value
|
||||
}
|
||||
|
||||
// F returns a pointer to the given float64. Useful helper function for pointer
|
||||
// schema validators like Maximum or Minimum.
|
||||
func F(value float64) *float64 {
|
||||
return &value
|
||||
}
|
||||
|
||||
// getTagValue returns a value of the schema's type for the given tag string.
|
||||
// Uses JSON parsing if the schema is not a string.
|
||||
func getTagValue(s *Schema, t reflect.Type, value string) (interface{}, error) {
|
||||
|
@ -97,20 +109,34 @@ type Schema struct {
|
|||
Deprecated bool `json:"deprecated,omitempty"`
|
||||
}
|
||||
|
||||
// HasValidation returns true if at least one validator is set on the schema.
|
||||
// This excludes the schema's type but includes most other fields and can be
|
||||
// used to trigger additional slow validation steps when needed.
|
||||
func (s *Schema) HasValidation() bool {
|
||||
if s.Items != nil || len(s.Properties) > 0 || s.AdditionalProperties != nil || len(s.PatternProperties) > 0 || len(s.Required) > 0 || len(s.Enum) > 0 || s.Minimum != nil || s.ExclusiveMinimum != nil || s.Maximum != nil || s.ExclusiveMaximum != nil || s.MultipleOf != 0 || s.MinLength != nil || s.MaxLength != nil || s.Pattern != "" || s.MinItems != nil || s.MaxItems != nil || s.UniqueItems || s.MinProperties != nil || s.MaxProperties != nil || len(s.AllOf) > 0 || len(s.AnyOf) > 0 || len(s.OneOf) > 0 || s.Not != nil {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// GenerateSchema creates a JSON schema for a Go type. Struct field tags
|
||||
// can be used to provide additional metadata such as descriptions and
|
||||
// validation.
|
||||
func GenerateSchema(t reflect.Type) (*Schema, error) {
|
||||
return GenerateSchemaWithMode(t, SchemaModeAll)
|
||||
return GenerateSchemaWithMode(t, SchemaModeAll, nil)
|
||||
}
|
||||
|
||||
// GenerateSchemaWithMode creates a JSON schema for a Go type. Struct field
|
||||
// tags can be used to provide additional metadata such as descriptions and
|
||||
// validation. The mode can be all, read, or write. In read or write mode
|
||||
// any field that is marked as the opposite will be excluded, e.g. a
|
||||
// write-only field would not be included in read mode.
|
||||
func GenerateSchemaWithMode(t reflect.Type, mode SchemaMode) (*Schema, error) {
|
||||
schema := &Schema{}
|
||||
// write-only field would not be included in read mode. If a schema is given
|
||||
// as input, add to it, otherwise creates a new schema.
|
||||
func GenerateSchemaWithMode(t reflect.Type, mode SchemaMode, schema *Schema) (*Schema, error) {
|
||||
if schema == nil {
|
||||
schema = &Schema{}
|
||||
}
|
||||
|
||||
if t == ipType {
|
||||
// Special case: IP address.
|
||||
|
@ -142,7 +168,7 @@ func GenerateSchemaWithMode(t reflect.Type, mode SchemaMode) (*Schema, error) {
|
|||
name = jsonTags[0]
|
||||
}
|
||||
|
||||
s, err := GenerateSchemaWithMode(f.Type, mode)
|
||||
s, err := GenerateSchemaWithMode(f.Type, mode, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -347,37 +373,32 @@ func GenerateSchemaWithMode(t reflect.Type, mode SchemaMode) (*Schema, error) {
|
|||
|
||||
case reflect.Map:
|
||||
schema.Type = "object"
|
||||
s, err := GenerateSchemaWithMode(t.Elem(), mode)
|
||||
s, err := GenerateSchemaWithMode(t.Elem(), mode, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
schema.AdditionalProperties = s
|
||||
case reflect.Slice, reflect.Array:
|
||||
schema.Type = "array"
|
||||
s, err := GenerateSchemaWithMode(t.Elem(), mode)
|
||||
s, err := GenerateSchemaWithMode(t.Elem(), mode, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
schema.Items = s
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
return &Schema{
|
||||
Type: "integer",
|
||||
}, nil
|
||||
schema.Type = "integer"
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
// Unsigned integers can't be negative.
|
||||
min := 0.0
|
||||
return &Schema{
|
||||
Type: "integer",
|
||||
Minimum: &min,
|
||||
}, nil
|
||||
schema.Type = "integer"
|
||||
schema.Minimum = F(0.0)
|
||||
case reflect.Float32, reflect.Float64:
|
||||
return &Schema{Type: "number"}, nil
|
||||
schema.Type = "number"
|
||||
case reflect.Bool:
|
||||
return &Schema{Type: "boolean"}, nil
|
||||
schema.Type = "boolean"
|
||||
case reflect.String:
|
||||
return &Schema{Type: "string"}, nil
|
||||
schema.Type = "string"
|
||||
case reflect.Ptr:
|
||||
return GenerateSchemaWithMode(t.Elem(), mode)
|
||||
return GenerateSchemaWithMode(t.Elem(), mode, schema)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported type %s from %s", t.Kind(), t)
|
||||
}
|
||||
|
|
|
@ -498,3 +498,11 @@ func TestSchemaNonStringExampleErrorCast(t *testing.T) {
|
|||
_, err := GenerateSchema(reflect.ValueOf(Example{}).Type())
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestPointerHelpers(t *testing.T) {
|
||||
// Just confirm this compiles.
|
||||
_ = Schema{
|
||||
Minimum: F(98.6),
|
||||
MinLength: I(5),
|
||||
}
|
||||
}
|
||||
|
|
10
validate.go
10
validate.go
|
@ -51,8 +51,8 @@ func (p *Param) validate(t reflect.Type) error {
|
|||
|
||||
p.typ = t
|
||||
|
||||
if p.Schema == nil {
|
||||
s, err := GenerateSchemaWithMode(p.typ, SchemaModeWrite)
|
||||
if p.Schema == nil || p.Schema.Type == "" {
|
||||
s, err := GenerateSchemaWithMode(p.typ, SchemaModeWrite, p.Schema)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -76,7 +76,7 @@ func (p *Param) validate(t reflect.Type) error {
|
|||
func (h *ResponseHeader) validate(t reflect.Type) error {
|
||||
if h.Schema == nil {
|
||||
// Generate the schema from the handler function types.
|
||||
s, err := GenerateSchemaWithMode(t, SchemaModeRead)
|
||||
s, err := GenerateSchemaWithMode(t, SchemaModeRead, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -197,7 +197,7 @@ func (o *Operation) validate() error {
|
|||
if i == len(types)-1 && requestBody {
|
||||
// The last item has no associated param. It is a request body.
|
||||
if o.RequestSchema == nil {
|
||||
s, err := GenerateSchemaWithMode(paramType, SchemaModeWrite)
|
||||
s, err := GenerateSchemaWithMode(paramType, SchemaModeWrite, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -223,7 +223,7 @@ func (o *Operation) validate() error {
|
|||
// HTTP 204 explicitly forbids a response body.
|
||||
if resp.StatusCode != 204 && resp.Schema == nil {
|
||||
// Generate the schema from the handler function types.
|
||||
s, err := GenerateSchemaWithMode(respType, SchemaModeRead)
|
||||
s, err := GenerateSchemaWithMode(respType, SchemaModeRead, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue