diff --git a/openapi.go b/openapi.go index 687aae6..522a9f6 100644 --- a/openapi.go +++ b/openapi.go @@ -1,6 +1,9 @@ package huma import ( + "fmt" + "reflect" + "github.com/danielgtaylor/huma/schema" ) @@ -41,3 +44,59 @@ type oaParam struct { // params sent between a load balander / proxy and the service internally. Internal bool `json:"-"` } + +type oaComponents struct { + Schemas map[string]*schema.Schema `json:"schemas,omitempty"` +} + +func (c *oaComponents) AddSchema(t reflect.Type, mode schema.Mode, hint string) string { + // Try to determine the type's name. + name := t.Name() + if name == "" && t.Kind() == reflect.Ptr { + // Take the name of the pointed-to type. + name = t.Elem().Name() + } + if name == "" && t.Kind() == reflect.Slice { + // Take the name of the type in the array and append "List" to it. + tmp := t.Elem() + if tmp.Kind() == reflect.Ptr { + tmp = tmp.Elem() + } + name = tmp.Name() + if name != "" { + name += "List" + } + } + if name == "" { + // No luck, fall back to the passed-in hint. Better than nothing. + name = hint + } + + s, err := schema.GenerateWithMode(t, mode, nil) + if err != nil { + panic(err) + } + + orig := name + num := 1 + for { + if c.Schemas[name] == nil { + // No existing schema, we are the first! + break + } + + if reflect.DeepEqual(c.Schemas[name], s) { + // Existing schema matches! + break + } + + // If we are here, then an existing schema doesn't match and this is a new + // type. So we will rename it in a deterministic fashion. + num++ + name = fmt.Sprintf("%s%d", orig, num) + } + + c.Schemas[name] = s + + return "#/components/schemas/" + name +} diff --git a/openapi_test.go b/openapi_test.go new file mode 100644 index 0000000..5b72367 --- /dev/null +++ b/openapi_test.go @@ -0,0 +1,48 @@ +package huma + +import ( + "reflect" + "testing" + + "github.com/danielgtaylor/huma/schema" + "github.com/stretchr/testify/assert" +) + +type componentFoo struct { + Field string `json:"field"` + Another string `json:"another" readOnly:"true"` +} + +type componentBar struct { + Field string `json:"field"` +} + +func TestComponentSchemas(t *testing.T) { + components := oaComponents{ + Schemas: map[string]*schema.Schema{}, + } + + // Adding two different versions of the same component. + ref := components.AddSchema(reflect.TypeOf(&componentFoo{}), schema.ModeRead, "hint") + assert.Equal(t, ref, "#/components/schemas/componentFoo") + assert.NotNil(t, components.Schemas["componentFoo"]) + + ref = components.AddSchema(reflect.TypeOf(&componentFoo{}), schema.ModeWrite, "hint") + assert.Equal(t, ref, "#/components/schemas/componentFoo2") + assert.NotNil(t, components.Schemas["componentFoo2"]) + + // Re-adding the second should not create a third. + ref = components.AddSchema(reflect.TypeOf(&componentFoo{}), schema.ModeWrite, "hint") + assert.Equal(t, ref, "#/components/schemas/componentFoo2") + assert.Nil(t, components.Schemas["componentFoo3"]) + + // Adding a list of pointers to a struct. + ref = components.AddSchema(reflect.TypeOf([]*componentBar{}), schema.ModeAll, "hint") + assert.Equal(t, ref, "#/components/schemas/componentBarList") + assert.NotNil(t, components.Schemas["componentBarList"]) + + // Adding an anonymous empty struct, should use the hint. + ref = components.AddSchema(reflect.TypeOf(struct{}{}), schema.ModeAll, "hint") + assert.Equal(t, ref, "#/components/schemas/hint") + assert.NotNil(t, components.Schemas["hint"]) +} diff --git a/operation.go b/operation.go index e0c4a08..8b4b354 100644 --- a/operation.go +++ b/operation.go @@ -43,7 +43,7 @@ func newOperation(resource *Resource, method, id, docs string, responses []Respo } } -func (o *Operation) toOpenAPI() *gabs.Container { +func (o *Operation) toOpenAPI(components *oaComponents) *gabs.Container { doc := gabs.New() doc.Set(o.id, "operationId") @@ -98,11 +98,8 @@ func (o *Operation) toOpenAPI() *gabs.Container { } if resp.model != nil { - schema, err := schema.GenerateWithMode(resp.model, schema.ModeRead, nil) - if err != nil { - panic(err) - } - doc.Set(schema, "responses", status, "content", resp.contentType, "schema") + ref := components.AddSchema(resp.model, schema.ModeRead, o.id) + doc.Set(ref, "responses", status, "content", resp.contentType, "schema", "$ref") } } diff --git a/resource.go b/resource.go index 1eede6b..b949842 100644 --- a/resource.go +++ b/resource.go @@ -22,15 +22,15 @@ type Resource struct { tags []string } -func (r *Resource) toOpenAPI() *gabs.Container { +func (r *Resource) toOpenAPI(components *oaComponents) *gabs.Container { doc := gabs.New() for _, sub := range r.subResources { - doc.Merge(sub.toOpenAPI()) + doc.Merge(sub.toOpenAPI(components)) } for _, op := range r.operations { - opValue := op.toOpenAPI() + opValue := op.toOpenAPI(components) if len(r.tags) > 0 { opValue.Set(r.tags, "tags") diff --git a/router.go b/router.go index 4995f1d..2121ac6 100644 --- a/router.go +++ b/router.go @@ -9,6 +9,7 @@ import ( "time" "github.com/Jeffail/gabs/v2" + "github.com/danielgtaylor/huma/schema" "github.com/go-chi/chi" ) @@ -70,11 +71,17 @@ func (r *Router) OpenAPI() *gabs.Container { doc.Set(r.description, "info", "description") } + components := &oaComponents{ + Schemas: map[string]*schema.Schema{}, + } + paths, _ := doc.Object("paths") for _, res := range r.resources { - paths.Merge(res.toOpenAPI()) + paths.Merge(res.toOpenAPI(components)) } + doc.Set(components, "components") + if r.openapiHook != nil { r.openapiHook(doc) }