mirror of
https://github.com/Fishwaldo/huma.git
synced 2025-03-15 11:21:42 +00:00
404 lines
11 KiB
Go
404 lines
11 KiB
Go
package huma
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"reflect"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/danielgtaylor/huma/negotiation"
|
|
"github.com/fxamacker/cbor/v2"
|
|
"github.com/goccy/go-yaml"
|
|
)
|
|
|
|
// allowedHeaders holds the lower-cased names of response headers that may
// always be sent, even when an operation does not declare them. The built-in
// entries are low-level HTTP headers dealing with CORS, connection handling
// and caching variance.
var allowedHeaders = map[string]bool{
	// Connection-level headers.
	"connection": true,
	"keep-alive": true,
	"vary":       true,

	// CORS headers.
	"access-control-allow-origin":  true,
	"access-control-allow-methods": true,
	"access-control-allow-headers": true,
	"access-control-max-age":       true,
}

// AddAllowedHeaders registers additional header names that are always
// permitted on responses. Names are stored lower-cased, so matching is
// case-insensitive. This is useful for headers that middleware or other
// frameworks add on their own.
func AddAllowedHeaders(name ...string) {
	for i := range name {
		key := strings.ToLower(name[i])
		allowedHeaders[key] = true
	}
}
|
|
|
|
|
|
// ContextFromRequest returns a Huma context for a request, useful for
|
|
// accessing high-level convenience functions from e.g. middleware.
|
|
func ContextFromRequest(w http.ResponseWriter, r *http.Request) Context {
|
|
return &hcontext{
|
|
Context: r.Context(),
|
|
ResponseWriter: w,
|
|
r: r,
|
|
}
|
|
}
|
|
|
|
// Context combines a request context with a response writer and adds helpers
// for collecting errors and writing marshalled responses from handlers.
type Context interface {
	context.Context
	http.ResponseWriter

	// WithValue returns a shallow copy of this context that additionally
	// carries the given key/value pair. The receiver is left unchanged.
	WithValue(key, value interface{}) Context

	// SetValue stores a key/value pair on this context. The Huma context is
	// modified in place while the underlying request context is copied, which
	// makes it convenient to use from input resolver functions.
	SetValue(key, value interface{})

	// AddError records an error against the current request.
	AddError(err error)

	// HasError reports whether at least one error has been recorded.
	HasError() bool

	// WriteError writes an HTTP status code and a friendly error message,
	// along with the details of any errors recorded via `AddError` and/or
	// passed in here.
	WriteError(status int, message string, errors ...error)

	// WriteModel writes an HTTP status code and the marshalled model, using
	// content negotiation to pick the format (e.g. JSON or CBOR). The status
	// and model type must match a response registered for the operation.
	WriteModel(status int, model interface{})

	// WriteContent wraps http.ServeContent to serve a seekable stream. It
	// takes care of Range requests and of the modification-time conditional
	// headers (such as If-Unmodified-Since).
	WriteContent(name string, content io.ReadSeeker, lastModified time.Time)
}
|
|
|
|
type hcontext struct {
|
|
context.Context
|
|
http.ResponseWriter
|
|
r *http.Request
|
|
errors []error
|
|
errorCode int
|
|
op *Operation
|
|
closed bool
|
|
docsPath string
|
|
schemasPath string
|
|
specPath string
|
|
urlPrefix string
|
|
disableSchemaProperty bool
|
|
}
|
|
|
|
func (c *hcontext) WithValue(key, value interface{}) Context {
|
|
r := c.r.WithContext(context.WithValue(c.r.Context(), key, value))
|
|
return &hcontext{
|
|
Context: context.WithValue(c.Context, key, value),
|
|
ResponseWriter: c.ResponseWriter,
|
|
r: r,
|
|
errors: append([]error{}, c.errors...),
|
|
op: c.op,
|
|
closed: c.closed,
|
|
disableSchemaProperty: c.disableSchemaProperty,
|
|
}
|
|
}
|
|
|
|
func (c *hcontext) SetValue(key, value interface{}) {
|
|
c.r = c.r.WithContext(context.WithValue(c.r.Context(), key, value))
|
|
c.Context = c.r.Context()
|
|
}
|
|
|
|
func (c *hcontext) AddError(err error) {
|
|
c.errors = append(c.errors, err)
|
|
}
|
|
|
|
func (c *hcontext) HasError() bool {
|
|
return len(c.errors) > 0
|
|
}
|
|
|
|
func (c *hcontext) WriteHeader(status int) {
|
|
if c.op != nil {
|
|
allowed := []string{}
|
|
for _, r := range c.op.responses {
|
|
if r.status == status || r.status == 0 {
|
|
for _, h := range r.headers {
|
|
allowed = append(allowed, h)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Check that all headers were allowed to be sent.
|
|
for name := range c.Header() {
|
|
if allowedHeaders[strings.ToLower(name)] {
|
|
continue
|
|
}
|
|
|
|
found := false
|
|
|
|
for _, h := range allowed {
|
|
if strings.ToLower(name) == strings.ToLower(h) {
|
|
found = true
|
|
break
|
|
}
|
|
}
|
|
|
|
if !found {
|
|
panic(fmt.Errorf("Response header %s is not declared for %s %s with status code %d (allowed: %s)", name, c.r.Method, c.r.URL.Path, status, allowed))
|
|
}
|
|
}
|
|
}
|
|
|
|
c.ResponseWriter.WriteHeader(status)
|
|
}
|
|
|
|
func (c *hcontext) Write(data []byte) (int, error) {
|
|
if c.closed {
|
|
panic(fmt.Errorf("Trying to write to response after WriteModel, WriteError, or WriteContent for %s %s", c.r.Method, c.r.URL.Path))
|
|
}
|
|
|
|
return c.ResponseWriter.Write(data)
|
|
}
|
|
|
|
func (c *hcontext) WriteError(status int, message string, errors ...error) {
|
|
if c.closed {
|
|
panic(fmt.Errorf("Trying to write to response after WriteModel, WriteError, or WriteContent for %s %s", c.r.Method, c.r.URL.Path))
|
|
}
|
|
|
|
details := []*ErrorDetail{}
|
|
|
|
c.errors = append(c.errors, errors...)
|
|
for _, err := range c.errors {
|
|
if d, ok := err.(ErrorDetailer); ok {
|
|
details = append(details, d.ErrorDetail())
|
|
} else {
|
|
details = append(details, &ErrorDetail{Message: err.Error()})
|
|
}
|
|
}
|
|
|
|
model := &ErrorModel{
|
|
Title: http.StatusText(status),
|
|
Status: status,
|
|
Detail: message,
|
|
Errors: details,
|
|
}
|
|
|
|
// Select content type and transform it to the appropriate error type.
|
|
ct := selectContentType(c.r)
|
|
switch ct {
|
|
case "application/cbor":
|
|
ct = "application/problem+cbor"
|
|
case "", "application/json":
|
|
ct = "application/problem+json"
|
|
case "application/yaml", "application/x-yaml":
|
|
ct = "application/problem+yaml"
|
|
}
|
|
|
|
c.writeModel(ct, status, model)
|
|
}
|
|
|
|
func (c *hcontext) WriteModel(status int, model interface{}) {
|
|
if c.closed {
|
|
panic(fmt.Errorf("Trying to write to response after WriteModel, WriteError, or WriteContent for %s %s", c.r.Method, c.r.URL.Path))
|
|
}
|
|
|
|
// Get the negotiated content type the client wants and we are willing to
|
|
// provide.
|
|
ct := selectContentType(c.r)
|
|
|
|
c.writeModel(ct, status, model)
|
|
}
|
|
|
|
// URLPrefix returns the prefix to use for non-relative URL links.
|
|
func (c *hcontext) URLPrefix() string {
|
|
if c.urlPrefix != "" {
|
|
return c.urlPrefix
|
|
}
|
|
|
|
scheme := "https"
|
|
if strings.HasPrefix(c.r.Host, "localhost") {
|
|
scheme = "http"
|
|
}
|
|
|
|
return scheme + "://" + c.r.Host
|
|
}
|
|
|
|
// shallowStructToMap copies the exported fields of the struct held by v into
// result, similar to how encoding/json would name them, but only one level
// deep so that the map may be modified before serialization.
//
// Rules applied to each field:
//   - Pointers are followed; a nil pointer (top-level or embedded) adds
//     nothing instead of panicking.
//   - Unexported fields are skipped.
//   - Embedded structs (or pointers to structs) have their fields merged into
//     result as if they were declared on the outer struct. Tags on the
//     embedded field itself are ignored.
//   - Embedded non-struct types are treated like ordinary fields.
//   - A `json` tag supplies the key name; a name of "-" skips the field.
//   - The `omitempty` option, in any position of the tag's option list, skips
//     zero values as well as slices and maps with no items.
//
// A v that is not a struct once pointers are followed is ignored.
func shallowStructToMap(v reflect.Value, result map[string]interface{}) {
	// Follow pointers down to the struct; nil has no fields to contribute.
	for v.Kind() == reflect.Ptr {
		if v.IsNil() {
			return
		}
		v = v.Elem()
	}
	if v.Kind() != reflect.Struct {
		return
	}

	t := v.Type()
	for i := 0; i < t.NumField(); i++ {
		f := t.Field(i)
		if f.PkgPath != "" {
			// Unexported field: its value cannot be read via Interface().
			continue
		}
		fv := v.Field(i)

		if f.Anonymous {
			// Only embedded structs are flattened; an embedded named
			// non-struct type (e.g. `type Label string`) has no fields and
			// is emitted under its own name below.
			embedded := f.Type
			if embedded.Kind() == reflect.Ptr {
				embedded = embedded.Elem()
			}
			if embedded.Kind() == reflect.Struct {
				shallowStructToMap(fv, result)
				continue
			}
		}

		name := f.Name
		if tag := f.Tag.Get("json"); tag != "" {
			parts := strings.Split(tag, ",")
			if parts[0] != "" {
				name = parts[0]
			}
			if name == "-" {
				continue
			}

			// Options may come in any order, e.g. `json:"n,string,omitempty"`.
			omitEmpty := false
			for _, opt := range parts[1:] {
				if opt == "omitempty" {
					omitEmpty = true
					break
				}
			}
			if omitEmpty {
				empty := fv.IsZero()
				if k := fv.Kind(); k == reflect.Slice || k == reflect.Map {
					// Special case: omit if they have no items in them to
					// match the JSON encoder.
					empty = fv.Len() == 0
				}
				if empty {
					continue
				}
			}
		}

		result[name] = fv.Interface()
	}
}
|
|
|
|
func (c *hcontext) writeModel(ct string, status int, model interface{}) {
|
|
// Is this allowed? Find the right response.
|
|
modelRef := ""
|
|
modelType := reflect.TypeOf(model)
|
|
if c.op != nil {
|
|
responses := []Response{}
|
|
names := []string{}
|
|
statuses := []string{}
|
|
for _, r := range c.op.responses {
|
|
statuses = append(statuses, fmt.Sprintf("%d", r.status))
|
|
if r.status == status || r.status == 0 {
|
|
responses = append(responses, r)
|
|
if r.model != nil {
|
|
names = append(names, r.model.Name())
|
|
}
|
|
}
|
|
}
|
|
|
|
if len(responses) == 0 {
|
|
panic(fmt.Errorf("HTTP status %d not allowed for %s %s, expected one of %s", status, c.r.Method, c.r.URL.Path, statuses))
|
|
}
|
|
|
|
found := false
|
|
for _, r := range responses {
|
|
if r.model == modelType {
|
|
found = true
|
|
modelRef = r.modelRef
|
|
break
|
|
}
|
|
}
|
|
|
|
if !found {
|
|
panic(fmt.Errorf("Invalid model %s, expecting %s for %s %s", modelType, strings.Join(names, ", "), c.r.Method, c.r.URL.Path))
|
|
}
|
|
}
|
|
|
|
// If possible, insert a link relation header to the JSON Schema describing
|
|
// this response. If it's an object (not an array), then we can also try
|
|
// inserting the `$schema` key to make editing & validation easier.
|
|
parts := strings.Split(modelRef, "/")
|
|
if len(parts) > 0 {
|
|
id := parts[len(parts)-1]
|
|
|
|
link := c.Header().Get("Link")
|
|
if link != "" {
|
|
link += ", "
|
|
}
|
|
link += "<" + c.schemasPath + "/" + id + ".json>; rel=\"describedby\""
|
|
c.Header().Set("Link", link)
|
|
|
|
if modelType.Kind() == reflect.Ptr {
|
|
modelType = modelType.Elem()
|
|
}
|
|
if !c.disableSchemaProperty && modelType != nil && modelType.Kind() == reflect.Struct && modelType != timeType {
|
|
tmp := map[string]interface{}{}
|
|
shallowStructToMap(reflect.ValueOf(model), tmp)
|
|
if tmp["$schema"] == nil {
|
|
tmp["$schema"] = c.URLPrefix() + c.schemasPath + "/" + id + ".json"
|
|
}
|
|
model = tmp
|
|
}
|
|
}
|
|
|
|
// Do the appropriate encoding.
|
|
var encoded []byte
|
|
var err error
|
|
if strings.HasPrefix(ct, "application/json") || strings.HasSuffix(ct, "+json") {
|
|
encoded, err = json.Marshal(model)
|
|
if err != nil {
|
|
panic(fmt.Errorf("Unable to marshal JSON: %w", err))
|
|
}
|
|
} else if strings.HasPrefix(ct, "application/yaml") || strings.HasPrefix(ct, "application/x-yaml") || strings.HasSuffix(ct, "+yaml") {
|
|
encoded, err = yaml.Marshal(model)
|
|
if err != nil {
|
|
panic(fmt.Errorf("Unable to marshal YAML: %w", err))
|
|
}
|
|
} else if strings.HasPrefix(ct, "application/cbor") || strings.HasSuffix(ct, "+cbor") {
|
|
opts := cbor.CanonicalEncOptions()
|
|
opts.Time = cbor.TimeRFC3339Nano
|
|
opts.TimeTag = cbor.EncTagRequired
|
|
mode, err := opts.EncMode()
|
|
if err != nil {
|
|
panic(fmt.Errorf("Unable to marshal CBOR: %w", err))
|
|
}
|
|
encoded, err = mode.Marshal(model)
|
|
if err != nil {
|
|
panic(fmt.Errorf("Unable to marshal CBOR: %w", err))
|
|
}
|
|
}
|
|
|
|
// Encoding succeeded, write the data!
|
|
c.Header().Set("Content-Type", ct)
|
|
c.WriteHeader(status)
|
|
c.Write(encoded)
|
|
c.closed = true
|
|
}
|
|
|
|
// selectContentType selects the best availalable content type via content
|
|
// negotiation with the client, defaulting to JSON.
|
|
func selectContentType(r *http.Request) string {
|
|
ct := "application/json"
|
|
|
|
if accept := r.Header.Get("Accept"); accept != "" {
|
|
best := negotiation.SelectQValue(accept, []string{
|
|
"application/cbor",
|
|
"application/json",
|
|
"application/yaml",
|
|
"application/x-yaml",
|
|
})
|
|
|
|
if best != "" {
|
|
ct = best
|
|
}
|
|
}
|
|
|
|
return ct
|
|
}
|
|
|
|
func (c *hcontext) WriteContent(name string, content io.ReadSeeker, lastModified time.Time) {
|
|
if c.closed {
|
|
panic(fmt.Errorf("Trying to write to response after WriteModel, WriteError, or WriteContent for %s %s", c.r.Method, c.r.URL.Path))
|
|
}
|
|
|
|
http.ServeContent(c.ResponseWriter, c.r, name, lastModified, content)
|
|
c.closed = true
|
|
}
|