diff --git a/go.sum b/go.sum index e68118d..b195d9e 100644 --- a/go.sum +++ b/go.sum @@ -75,6 +75,7 @@ github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Z github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/go-cmp v0.2.0 h1:+dTQ8DZQJz0Mb/HjFlkptS1FeQ4cWSnN941F8aEG4SQ= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= +github.com/google/go-cmp v0.3.0 h1:crn/baboCvb5fXaQ0IJ1SGTsTVrWpDsCWC8EGETZijY= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= diff --git a/middleware/encoding_test.go b/middleware/encoding_test.go new file mode 100644 index 0000000..a792873 --- /dev/null +++ b/middleware/encoding_test.go @@ -0,0 +1,144 @@ +package middleware + +import ( + "compress/gzip" + "io/ioutil" + "net/http" + "net/http/httptest" + "testing" + + "github.com/andybalholm/brotli" + "github.com/danielgtaylor/huma" + "github.com/danielgtaylor/huma/responses" + "github.com/stretchr/testify/assert" +) + +func TestContentEncodingTooSmall(t *testing.T) { + app, _ := NewTestRouter(t) + app.Resource("/").Get("root", "test", + responses.OK().ContentType("text/plain"), + ).Run(func(ctx huma.Context) { + ctx.Write([]byte("Short string")) + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodGet, "/", nil) + req.Header.Add("Accept-Encoding", "gzip, br") + app.ServeHTTP(w, req) + + assert.Equal(t, w.Result().StatusCode, http.StatusOK) + assert.Equal(t, "", w.Result().Header.Get("Content-Encoding")) + assert.Equal(t, "Short string", w.Body.String()) +} + +func TestContentEncodingIgnoredPath(t *testing.T) { + app, _ := NewTestRouter(t) + app.Resource("/foo.png").Get("root", "test", + responses.OK().ContentType("image/png"), + ).Run(func(ctx huma.Context) { + ctx.Header().Set("Content-Type", "image/png") + ctx.Write([]byte("fake png")) + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodGet, "/foo.png", nil) + req.Header.Add("Accept-Encoding", "gzip, br") + app.ServeHTTP(w, req) + + assert.Equal(t, w.Result().StatusCode, http.StatusOK) + assert.Equal(t, "", w.Result().Header.Get("Content-Encoding")) + assert.Equal(t, "fake png", w.Body.String()) +} + +func TestContentEncodingCompressed(t *testing.T) { + app, _ := NewTestRouter(t) + app.Resource("/").Get("root", "test", + responses.OK(), + ).Run(func(ctx huma.Context) { + ctx.Write(make([]byte, 1500)) + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodGet, "/", nil) + req.Header.Add("Accept-Encoding", "gzip, br") + app.ServeHTTP(w, req) + + assert.Equal(t, w.Result().StatusCode, http.StatusOK) + assert.Equal(t, "br", w.Result().Header.Get("Content-Encoding")) + assert.Less(t, len(w.Body.String()), 1500) + + br := brotli.NewReader(w.Body) + decoded, _ := ioutil.ReadAll(br) + assert.Equal(t, 1500, len(decoded)) +} + +func TestContentEncodingCompressedPick(t *testing.T) { + app, _ := NewTestRouter(t) + app.Resource("/").Get("root", "test", + responses.OK(), + ).Run(func(ctx huma.Context) { + ctx.Write(make([]byte, 1500)) + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodGet, "/", nil) + req.Header.Add("Accept-Encoding", "gzip, br; q=0.9, deflate") + app.ServeHTTP(w, req) + + assert.Equal(t, w.Result().StatusCode, http.StatusOK) + assert.Equal(t, "gzip", w.Result().Header.Get("Content-Encoding")) + assert.Less(t, len(w.Body.String()), 1500) +} + +func TestContentEncodingCompressedMultiWrite(t *testing.T) { + app, _ := NewTestRouter(t) + app.Resource("/").Get("root", "test", + responses.OK(), + ).Run(func(ctx huma.Context) { + buf := make([]byte, 750) + ctx.Write(buf) + ctx.Write(buf) + ctx.Write(buf) + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodGet, "/", nil) + req.Header.Add("Accept-Encoding", "gzip") + app.ServeHTTP(w, req) + + assert.Equal(t, w.Result().StatusCode, http.StatusOK) + assert.Equal(t, "gzip", w.Result().Header.Get("Content-Encoding")) + assert.Less(t, len(w.Body.String()), 2250) + + gr, _ := gzip.NewReader(w.Body) + decoded, _ := ioutil.ReadAll(gr) + assert.Equal(t, 2250, len(decoded)) +} + +func TestContentEncodingError(t *testing.T) { + var status int + + app, _ := NewTestRouter(t) + app.Middleware(func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + wrapped := &statusRecorder{ResponseWriter: w} + next.ServeHTTP(wrapped, r) + status = wrapped.status + }) + }) + + app.Resource("/").Get("root", "test", + responses.OK(), + ).Run(func(ctx huma.Context) { + ctx.WriteHeader(http.StatusNotFound) + ctx.Write([]byte("some text")) + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodGet, "/", nil) + req.Header.Add("Accept-Encoding", "gzip") + app.ServeHTTP(w, req) + + assert.Equal(t, http.StatusNotFound, status) + assert.Equal(t, http.StatusNotFound, w.Result().StatusCode) +} diff --git a/middleware/logger_test.go b/middleware/logger_test.go new file mode 100644 index 0000000..688edb2 --- /dev/null +++ b/middleware/logger_test.go @@ -0,0 +1,14 @@ +package middleware + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNewLogger(t *testing.T) { + // Make sure it returns a logger + l, err := NewDefaultLogger() + assert.NoError(t, err) + assert.NotNil(t, l) +} diff --git a/middleware/middleware.go b/middleware/middleware.go index cd2d7c8..ec602a3 100644 --- a/middleware/middleware.go +++ b/middleware/middleware.go @@ -33,9 +33,7 @@ func DefaultChain(next http.Handler) http.Handler { // Defaults sets up the default middleware. This convenience function adds the // `DefaultChain` to the router and adds the `--debug` option for logging to // the CLI if app is a CLI. -func Defaults(app interface { - Middlewarer -}) { +func Defaults(app Middlewarer) { // Add the default middleware chain. app.Middleware(DefaultChain) diff --git a/middleware/middleware_test.go b/middleware/middleware_test.go new file mode 100644 index 0000000..8c807cf --- /dev/null +++ b/middleware/middleware_test.go @@ -0,0 +1,21 @@ +package middleware + +import ( + "net/http" + "testing" +) + +type fakeApp struct{} + +func (a *fakeApp) Middleware(middlewares ...func(next http.Handler) http.Handler) {} + +func (a *fakeApp) Flag(name string, short string, description string, defaultValue interface{}) {} + +func (a *fakeApp) PreStart(f func()) { + f() +} + +func TestDefaults(t *testing.T) { + app := &fakeApp{} + Defaults(app) +} diff --git a/middleware/minimal_test.go b/middleware/minimal_test.go new file mode 100644 index 0000000..db3dbb9 --- /dev/null +++ b/middleware/minimal_test.go @@ -0,0 +1,51 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/danielgtaylor/huma" + "github.com/danielgtaylor/huma/responses" + "github.com/stretchr/testify/assert" +) + +func TestPreferMinimalMiddleware(t *testing.T) { + app, _ := NewTestRouter(t) + + app.Resource("/test").Get("id", "desc", + responses.OK().ContentType("text/plain"), + ).Run(func(ctx huma.Context) { + ctx.Write([]byte("Hello, test")) + }) + + app.Resource("/non200").Get("id", "desc", + responses.BadRequest().ContentType("text/plain"), + ).Run(func(ctx huma.Context) { + ctx.WriteHeader(http.StatusBadRequest) + ctx.Write([]byte("Error details")) + }) + + // Normal request + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodGet, "/test", nil) + app.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code) + assert.NotEmpty(t, w.Body.String()) + + // Prefer minimal should return 204 No Content + w = httptest.NewRecorder() + req, _ = http.NewRequest(http.MethodGet, "/test", nil) + req.Header.Add("prefer", "return=minimal") + app.ServeHTTP(w, req) + assert.Equal(t, http.StatusNoContent, w.Code) + assert.Empty(t, w.Body.String()) + + // Prefer minimal which can still return non-200 response bodies + w = httptest.NewRecorder() + req, _ = http.NewRequest(http.MethodGet, "/non200", nil) + req.Header.Add("prefer", "return=minimal") + app.ServeHTTP(w, req) + assert.Equal(t, http.StatusBadRequest, w.Code) + assert.NotEmpty(t, w.Body.String()) +} diff --git a/middleware/recovery_test.go b/middleware/recovery_test.go new file mode 100644 index 0000000..9978b7a --- /dev/null +++ b/middleware/recovery_test.go @@ -0,0 +1,84 @@ +package middleware + +import ( + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/danielgtaylor/huma" + "github.com/danielgtaylor/huma/responses" + "github.com/stretchr/testify/assert" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" + "go.uber.org/zap/zaptest" + "go.uber.org/zap/zaptest/observer" +) + +func NewTestRouter(t testing.TB) (*huma.Router, *observer.ObservedLogs) { + core, logs := observer.New(zapcore.DebugLevel) + + router := huma.New("Test API", "1.0.0") + router.Middleware(DefaultChain) + + NewLogger = func() (*zap.Logger, error) { + l := zaptest.NewLogger(t, zaptest.WrapOptions(zap.WrapCore(func(zapcore.Core) zapcore.Core { return core }))) + return l, nil + } + + return router, logs +} + +func TestRecoveryMiddleware(t *testing.T) { + app, _ := NewTestRouter(t) + + app.Resource("/panic").Get("panic", "Panic recovery test", + responses.NoContent(), + ).Run(func(ctx huma.Context) { + panic(fmt.Errorf("Some error")) + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodGet, "/panic", nil) + app.ServeHTTP(w, req) + assert.Equal(t, http.StatusInternalServerError, w.Code) + assert.Equal(t, "application/problem+json", w.Result().Header.Get("content-type")) +} + +func TestRecoveryMiddlewareString(t *testing.T) { + app, _ := NewTestRouter(t) + + app.Resource("/panic").Get("panic", "Panic recovery test", + responses.NoContent(), + ).Run(func(ctx huma.Context) { + panic("Some error") + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodGet, "/panic", nil) + app.ServeHTTP(w, req) + assert.Equal(t, http.StatusInternalServerError, w.Code) + assert.Equal(t, "application/problem+json", w.Result().Header.Get("content-type")) +} + +func TestRecoveryMiddlewareLogBody(t *testing.T) { + app, log := NewTestRouter(t) + + app.Resource("/panic").Put("panic", "Panic recovery test", + responses.NoContent(), + ).Run(func(ctx huma.Context, input struct { + Body struct { + Foo string `json:"foo"` + } + }) { + panic(fmt.Errorf("Some error")) + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodPut, "/panic", strings.NewReader(`{"foo": "bar"}`)) + app.ServeHTTP(w, req) + assert.Equal(t, http.StatusInternalServerError, w.Code) + assert.Equal(t, "application/problem+json", w.Result().Header.Get("content-type")) + assert.Contains(t, log.All()[0].ContextMap()["request"], `{"foo": "bar"}`) +} diff --git a/negotiation/negotiation_test.go b/negotiation/negotiation_test.go new file mode 100644 index 0000000..0219c7a --- /dev/null +++ b/negotiation/negotiation_test.go @@ -0,0 +1,19 @@ +package negotiation + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestAccept(t *testing.T) { + assert.Equal(t, "b", SelectQValue("a; q=0.5, b;q=1.0,c; q=0.3", []string{"a", "b", "d"})) +} + +func TestAcceptBest(t *testing.T) { + assert.Equal(t, "b", SelectQValue("a; q=1.0, b;q=1.0,c; q=0.3", []string{"b", "a"})) +} + +func TestNoMatch(t *testing.T) { + assert.Equal(t, "", SelectQValue("a; q=1.0, b;q=1.0,c; q=0.3", []string{"d", "e"})) +} diff --git a/openapi.go b/openapi.go index 75fe60b..687aae6 100644 --- a/openapi.go +++ b/openapi.go @@ -35,7 +35,6 @@ type oaParam struct { Required bool `json:"required,omitempty"` Schema *schema.Schema `json:"schema,omitempty"` Deprecated bool `json:"deprecated,omitempty"` - Example interface{} `json:"example,omitempty"` Explode *bool `json:"explode,omitempty"` // Internal params are excluded from the OpenAPI document and can set up diff --git a/operation.go b/operation.go index 31e890c..6f3ae14 100644 --- a/operation.go +++ b/operation.go @@ -139,8 +139,6 @@ func (o *Operation) Run(handler interface{}) { register = o.resource.mux.Delete } - // TODO: get input param definitions? - t := reflect.TypeOf(handler) if t.Kind() == reflect.Func && t.NumIn() > 1 { var err error @@ -158,6 +156,9 @@ func (o *Operation) Run(handler interface{}) { } } + // Future improvement idea: use a sync.Pool for the input structure to save + // on allocations if the struct has a Reset() method. + register("/", func(w http.ResponseWriter, r *http.Request) { // Limit the request body size and set a read timeout. if r.Body != nil {