From 1e3569536b4f5d74f35947d78a4f33bc8b02d7dd Mon Sep 17 00:00:00 2001 From: "Daniel G. Taylor" Date: Mon, 21 Mar 2022 17:56:00 -0700 Subject: [PATCH] fix: serialization of time.Time and struct pointers --- context.go | 64 ++++++++++++++++++++++++++++++++++++++---------- go.mod | 1 - resolver_test.go | 5 +++- router_test.go | 40 ++++++++++++++++++++++++++++++ 4 files changed, 95 insertions(+), 15 deletions(-) diff --git a/context.go b/context.go index 5682ef3..d9e5ce6 100644 --- a/context.go +++ b/context.go @@ -11,7 +11,6 @@ import ( "github.com/danielgtaylor/huma/negotiation" "github.com/fxamacker/cbor/v2" "github.com/goccy/go-yaml" - "github.com/mitchellh/mapstructure" ) // allowedHeaders is a list of built-in headers that are always allowed without @@ -211,6 +210,52 @@ func (c *hcontext) URLPrefix() string { return scheme + "://" + c.r.Host } +// shallowStructToMap converts a struct to a map similar to how encoding/json +// would do it, but only one level deep so that the map may be modified before +// serialization. +func shallowStructToMap(v reflect.Value, result map[string]interface{}) { + t := v.Type() + if t.Kind() == reflect.Ptr { + shallowStructToMap(v.Elem(), result) + return + } + for i := 0; i < t.NumField(); i++ { + f := t.Field(i) + name := f.Name + if len(name) > 0 && strings.ToUpper(name)[0] != name[0] { + // Private field we somehow have access to? + continue + } + if f.Anonymous { + // Anonymous embedded struct, process its fields as our own. + shallowStructToMap(v.Field(i), result) + continue + } + if json := f.Tag.Get("json"); json != "" { + parts := strings.Split(json, ",") + if parts[0] != "" { + name = parts[0] + } + if name == "-" { + continue + } + if len(parts) == 2 && parts[1] == "omitempty" && v.Field(i).IsZero() { + vf := v.Field(i) + zero := vf.IsZero() + if vf.Kind() == reflect.Slice || vf.Kind() == reflect.Map { + // Special case: omit if they have no items in them to match the + // JSON encoder. + zero = vf.Len() > 0 + } + if zero { + continue + } + } + } + result[name] = v.Field(i).Interface() + } +} + func (c *hcontext) writeModel(ct string, status int, model interface{}) { // Is this allowed? Find the right response. modelRef := "" @@ -261,19 +306,12 @@ func (c *hcontext) writeModel(ct string, status int, model interface{}) { link += "<" + c.docsPrefix + "/schemas/" + id + ".json>; rel=\"describedby\"" c.Header().Set("Link", link) - if !c.disableSchemaProperty && modelType != nil && modelType.Kind() == reflect.Struct { + if modelType.Kind() == reflect.Ptr { + modelType = modelType.Elem() + } + if !c.disableSchemaProperty && modelType != nil && modelType.Kind() == reflect.Struct && modelType != timeType { tmp := map[string]interface{}{} - decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{ - TagName: "json", - Result: &tmp, - }) - if err != nil { - panic(fmt.Errorf("Unable to initialize struct decoder: %w", err)) - } - err = decoder.Decode(model) - if err != nil { - panic(fmt.Errorf("Unable to convert struct to map: %w", err)) - } + shallowStructToMap(reflect.ValueOf(model), tmp) if tmp["$schema"] == nil { tmp["$schema"] = c.URLPrefix() + c.docsPrefix + "/schemas/" + id + ".json" } diff --git a/go.mod b/go.mod index 0996c7a..003af2b 100644 --- a/go.mod +++ b/go.mod @@ -16,7 +16,6 @@ require ( github.com/koron-go/gqlcost v0.2.2 github.com/magiconair/properties v1.8.6 // indirect github.com/mattn/go-isatty v0.0.14 - github.com/mitchellh/mapstructure v1.4.3 github.com/opentracing/opentracing-go v1.2.0 github.com/spf13/afero v1.8.2 // indirect github.com/spf13/cobra v1.4.0 diff --git a/resolver_test.go b/resolver_test.go index 7dd7ae8..88b7700 100644 --- a/resolver_test.go +++ b/resolver_test.go @@ -31,9 +31,10 @@ func TestExhaustiveErrors(t *testing.T) { w := httptest.NewRecorder() r, _ := http.NewRequest(http.MethodGet, "/?bool=bad&int=bad&float32=bad&float64=bad&tags=1,2,bad&time=bad", strings.NewReader(`{"test": 1}`)) + r.Host = "example.com" app.ServeHTTP(w, r) - assert.JSONEq(t, `{"title":"Bad Request","status":400,"detail":"Error while parsing input parameters","errors":[{"message":"cannot parse boolean","location":"query.bool","value":"bad"},{"message":"cannot parse integer","location":"query.int","value":"bad"},{"message":"cannot parse float","location":"query.float32","value":"bad"},{"message":"cannot parse float","location":"query.float64","value":"bad"},{"message":"cannot parse integer","location":"query[2].tags","value":"bad"},{"message":"unable to validate against schema: invalid character 'b' looking for beginning of value","location":"query.tags","value":"[1,2,bad]"},{"message":"cannot parse time","location":"query.time","value":"bad"},{"message":"Must be greater than or equal to 5","location":"body.test","value":1}]}`, w.Body.String()) + assert.JSONEq(t, `{"$schema": "https://example.com/schemas/ErrorModel.json", "title":"Bad Request","status":400,"detail":"Error while parsing input parameters","errors":[{"message":"cannot parse boolean","location":"query.bool","value":"bad"},{"message":"cannot parse integer","location":"query.int","value":"bad"},{"message":"cannot parse float","location":"query.float32","value":"bad"},{"message":"cannot parse float","location":"query.float64","value":"bad"},{"message":"cannot parse integer","location":"query[2].tags","value":"bad"},{"message":"unable to validate against schema: invalid character 'b' looking for beginning of value","location":"query.tags","value":"[1,2,bad]"},{"message":"cannot parse time","location":"query.time","value":"bad"},{"message":"Must be greater than or equal to 5","location":"body.test","value":1}]}`, w.Body.String()) } type Dep1 struct { @@ -104,9 +105,11 @@ func TestNestedResolverError(t *testing.T) { ] } }`)) + r.Host = "example.com" app.ServeHTTP(w, r) assert.JSONEq(t, `{ + "$schema": "https://example.com/schemas/ErrorModel.json", "status": 400, "title": "Bad Request", "detail": "Error while parsing input parameters", diff --git a/router_test.go b/router_test.go index a628dcf..199b502 100644 --- a/router_test.go +++ b/router_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "encoding/json" + "fmt" "io" "io/ioutil" "net/http" @@ -138,6 +139,45 @@ func TestModelInputOutput(t *testing.T) { assert.Equal(t, http.StatusNotFound, w.Code) } +func TestRouterEmbeddedStructOutput(t *testing.T) { + type CreatedField struct { + Created time.Time `json:"created,omitempty"` + } + + type Resp struct { + CreatedField + Another string `json:"another"` + Ignored string `json:"-"` + } + + now := time.Now() + + r := New("Test", "1.0.0") + r.Resource("/test").Get("test", "Test", + NewResponse(http.StatusOK, "test").Model(&Resp{}), + ).Run(func(ctx Context) { + ctx.WriteModel(http.StatusOK, &Resp{ + CreatedField: CreatedField{ + Created: now, + }, + Another: "foo", + }) + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodGet, "/test", nil) + req.Host = "example.com" + r.ServeHTTP(w, req) + + // Assert the response is as expected. + assert.Equal(t, http.StatusOK, w.Code) + assert.JSONEq(t, fmt.Sprintf(`{ + "$schema": "https://example.com/schemas/Resp.json", + "created": "%s", + "another": "foo" + }`, now.Format(time.RFC3339Nano)), w.Body.String()) +} + func TestTooBigBody(t *testing.T) { app := newTestRouter()