mirror of
https://github.com/Fishwaldo/huma.git
synced 2025-03-15 11:21:42 +00:00
test: add middleware/negotiation tests
This commit is contained in:
parent
0c8ba518be
commit
6797f3a848
10 changed files with 338 additions and 6 deletions
1
go.sum
1
go.sum
|
@ -75,6 +75,7 @@ github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Z
|
|||
github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
|
||||
github.com/google/go-cmp v0.2.0 h1:+dTQ8DZQJz0Mb/HjFlkptS1FeQ4cWSnN941F8aEG4SQ=
|
||||
github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
|
||||
github.com/google/go-cmp v0.3.0 h1:crn/baboCvb5fXaQ0IJ1SGTsTVrWpDsCWC8EGETZijY=
|
||||
github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
|
||||
github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs=
|
||||
github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc=
|
||||
|
|
144
middleware/encoding_test.go
Normal file
144
middleware/encoding_test.go
Normal file
|
@ -0,0 +1,144 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"compress/gzip"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/andybalholm/brotli"
|
||||
"github.com/danielgtaylor/huma"
|
||||
"github.com/danielgtaylor/huma/responses"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestContentEncodingTooSmall(t *testing.T) {
|
||||
app, _ := NewTestRouter(t)
|
||||
app.Resource("/").Get("root", "test",
|
||||
responses.OK().ContentType("text/plain"),
|
||||
).Run(func(ctx huma.Context) {
|
||||
ctx.Write([]byte("Short string"))
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Add("Accept-Encoding", "gzip, br")
|
||||
app.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, w.Result().StatusCode, http.StatusOK)
|
||||
assert.Equal(t, "", w.Result().Header.Get("Content-Encoding"))
|
||||
assert.Equal(t, "Short string", w.Body.String())
|
||||
}
|
||||
|
||||
func TestContentEncodingIgnoredPath(t *testing.T) {
|
||||
app, _ := NewTestRouter(t)
|
||||
app.Resource("/foo.png").Get("root", "test",
|
||||
responses.OK().ContentType("image/png"),
|
||||
).Run(func(ctx huma.Context) {
|
||||
ctx.Header().Set("Content-Type", "image/png")
|
||||
ctx.Write([]byte("fake png"))
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest(http.MethodGet, "/foo.png", nil)
|
||||
req.Header.Add("Accept-Encoding", "gzip, br")
|
||||
app.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, w.Result().StatusCode, http.StatusOK)
|
||||
assert.Equal(t, "", w.Result().Header.Get("Content-Encoding"))
|
||||
assert.Equal(t, "fake png", w.Body.String())
|
||||
}
|
||||
|
||||
func TestContentEncodingCompressed(t *testing.T) {
|
||||
app, _ := NewTestRouter(t)
|
||||
app.Resource("/").Get("root", "test",
|
||||
responses.OK(),
|
||||
).Run(func(ctx huma.Context) {
|
||||
ctx.Write(make([]byte, 1500))
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Add("Accept-Encoding", "gzip, br")
|
||||
app.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, w.Result().StatusCode, http.StatusOK)
|
||||
assert.Equal(t, "br", w.Result().Header.Get("Content-Encoding"))
|
||||
assert.Less(t, len(w.Body.String()), 1500)
|
||||
|
||||
br := brotli.NewReader(w.Body)
|
||||
decoded, _ := ioutil.ReadAll(br)
|
||||
assert.Equal(t, 1500, len(decoded))
|
||||
}
|
||||
|
||||
func TestContentEncodingCompressedPick(t *testing.T) {
|
||||
app, _ := NewTestRouter(t)
|
||||
app.Resource("/").Get("root", "test",
|
||||
responses.OK(),
|
||||
).Run(func(ctx huma.Context) {
|
||||
ctx.Write(make([]byte, 1500))
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Add("Accept-Encoding", "gzip, br; q=0.9, deflate")
|
||||
app.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, w.Result().StatusCode, http.StatusOK)
|
||||
assert.Equal(t, "gzip", w.Result().Header.Get("Content-Encoding"))
|
||||
assert.Less(t, len(w.Body.String()), 1500)
|
||||
}
|
||||
|
||||
func TestContentEncodingCompressedMultiWrite(t *testing.T) {
|
||||
app, _ := NewTestRouter(t)
|
||||
app.Resource("/").Get("root", "test",
|
||||
responses.OK(),
|
||||
).Run(func(ctx huma.Context) {
|
||||
buf := make([]byte, 750)
|
||||
ctx.Write(buf)
|
||||
ctx.Write(buf)
|
||||
ctx.Write(buf)
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Add("Accept-Encoding", "gzip")
|
||||
app.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, w.Result().StatusCode, http.StatusOK)
|
||||
assert.Equal(t, "gzip", w.Result().Header.Get("Content-Encoding"))
|
||||
assert.Less(t, len(w.Body.String()), 2250)
|
||||
|
||||
gr, _ := gzip.NewReader(w.Body)
|
||||
decoded, _ := ioutil.ReadAll(gr)
|
||||
assert.Equal(t, 2250, len(decoded))
|
||||
}
|
||||
|
||||
func TestContentEncodingError(t *testing.T) {
|
||||
var status int
|
||||
|
||||
app, _ := NewTestRouter(t)
|
||||
app.Middleware(func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
wrapped := &statusRecorder{ResponseWriter: w}
|
||||
next.ServeHTTP(wrapped, r)
|
||||
status = wrapped.status
|
||||
})
|
||||
})
|
||||
|
||||
app.Resource("/").Get("root", "test",
|
||||
responses.OK(),
|
||||
).Run(func(ctx huma.Context) {
|
||||
ctx.WriteHeader(http.StatusNotFound)
|
||||
ctx.Write([]byte("some text"))
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Add("Accept-Encoding", "gzip")
|
||||
app.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusNotFound, status)
|
||||
assert.Equal(t, http.StatusNotFound, w.Result().StatusCode)
|
||||
}
|
14
middleware/logger_test.go
Normal file
14
middleware/logger_test.go
Normal file
|
@ -0,0 +1,14 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestNewLogger(t *testing.T) {
|
||||
// Make sure it returns a logger
|
||||
l, err := NewDefaultLogger()
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, l)
|
||||
}
|
|
@ -33,9 +33,7 @@ func DefaultChain(next http.Handler) http.Handler {
|
|||
// Defaults sets up the default middleware. This convenience function adds the
|
||||
// `DefaultChain` to the router and adds the `--debug` option for logging to
|
||||
// the CLI if app is a CLI.
|
||||
func Defaults(app interface {
|
||||
Middlewarer
|
||||
}) {
|
||||
func Defaults(app Middlewarer) {
|
||||
// Add the default middleware chain.
|
||||
app.Middleware(DefaultChain)
|
||||
|
||||
|
|
21
middleware/middleware_test.go
Normal file
21
middleware/middleware_test.go
Normal file
|
@ -0,0 +1,21 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type fakeApp struct{}
|
||||
|
||||
func (a *fakeApp) Middleware(middlewares ...func(next http.Handler) http.Handler) {}
|
||||
|
||||
func (a *fakeApp) Flag(name string, short string, description string, defaultValue interface{}) {}
|
||||
|
||||
func (a *fakeApp) PreStart(f func()) {
|
||||
f()
|
||||
}
|
||||
|
||||
func TestDefaults(t *testing.T) {
|
||||
app := &fakeApp{}
|
||||
Defaults(app)
|
||||
}
|
51
middleware/minimal_test.go
Normal file
51
middleware/minimal_test.go
Normal file
|
@ -0,0 +1,51 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/danielgtaylor/huma"
|
||||
"github.com/danielgtaylor/huma/responses"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestPreferMinimalMiddleware(t *testing.T) {
|
||||
app, _ := NewTestRouter(t)
|
||||
|
||||
app.Resource("/test").Get("id", "desc",
|
||||
responses.OK().ContentType("text/plain"),
|
||||
).Run(func(ctx huma.Context) {
|
||||
ctx.Write([]byte("Hello, test"))
|
||||
})
|
||||
|
||||
app.Resource("/non200").Get("id", "desc",
|
||||
responses.BadRequest().ContentType("text/plain"),
|
||||
).Run(func(ctx huma.Context) {
|
||||
ctx.WriteHeader(http.StatusBadRequest)
|
||||
ctx.Write([]byte("Error details"))
|
||||
})
|
||||
|
||||
// Normal request
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest(http.MethodGet, "/test", nil)
|
||||
app.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.NotEmpty(t, w.Body.String())
|
||||
|
||||
// Prefer minimal should return 204 No Content
|
||||
w = httptest.NewRecorder()
|
||||
req, _ = http.NewRequest(http.MethodGet, "/test", nil)
|
||||
req.Header.Add("prefer", "return=minimal")
|
||||
app.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusNoContent, w.Code)
|
||||
assert.Empty(t, w.Body.String())
|
||||
|
||||
// Prefer minimal which can still return non-200 response bodies
|
||||
w = httptest.NewRecorder()
|
||||
req, _ = http.NewRequest(http.MethodGet, "/non200", nil)
|
||||
req.Header.Add("prefer", "return=minimal")
|
||||
app.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusBadRequest, w.Code)
|
||||
assert.NotEmpty(t, w.Body.String())
|
||||
}
|
84
middleware/recovery_test.go
Normal file
84
middleware/recovery_test.go
Normal file
|
@ -0,0 +1,84 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/danielgtaylor/huma"
|
||||
"github.com/danielgtaylor/huma/responses"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
"go.uber.org/zap/zaptest"
|
||||
"go.uber.org/zap/zaptest/observer"
|
||||
)
|
||||
|
||||
func NewTestRouter(t testing.TB) (*huma.Router, *observer.ObservedLogs) {
|
||||
core, logs := observer.New(zapcore.DebugLevel)
|
||||
|
||||
router := huma.New("Test API", "1.0.0")
|
||||
router.Middleware(DefaultChain)
|
||||
|
||||
NewLogger = func() (*zap.Logger, error) {
|
||||
l := zaptest.NewLogger(t, zaptest.WrapOptions(zap.WrapCore(func(zapcore.Core) zapcore.Core { return core })))
|
||||
return l, nil
|
||||
}
|
||||
|
||||
return router, logs
|
||||
}
|
||||
|
||||
func TestRecoveryMiddleware(t *testing.T) {
|
||||
app, _ := NewTestRouter(t)
|
||||
|
||||
app.Resource("/panic").Get("panic", "Panic recovery test",
|
||||
responses.NoContent(),
|
||||
).Run(func(ctx huma.Context) {
|
||||
panic(fmt.Errorf("Some error"))
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest(http.MethodGet, "/panic", nil)
|
||||
app.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusInternalServerError, w.Code)
|
||||
assert.Equal(t, "application/problem+json", w.Result().Header.Get("content-type"))
|
||||
}
|
||||
|
||||
func TestRecoveryMiddlewareString(t *testing.T) {
|
||||
app, _ := NewTestRouter(t)
|
||||
|
||||
app.Resource("/panic").Get("panic", "Panic recovery test",
|
||||
responses.NoContent(),
|
||||
).Run(func(ctx huma.Context) {
|
||||
panic("Some error")
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest(http.MethodGet, "/panic", nil)
|
||||
app.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusInternalServerError, w.Code)
|
||||
assert.Equal(t, "application/problem+json", w.Result().Header.Get("content-type"))
|
||||
}
|
||||
|
||||
func TestRecoveryMiddlewareLogBody(t *testing.T) {
|
||||
app, log := NewTestRouter(t)
|
||||
|
||||
app.Resource("/panic").Put("panic", "Panic recovery test",
|
||||
responses.NoContent(),
|
||||
).Run(func(ctx huma.Context, input struct {
|
||||
Body struct {
|
||||
Foo string `json:"foo"`
|
||||
}
|
||||
}) {
|
||||
panic(fmt.Errorf("Some error"))
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest(http.MethodPut, "/panic", strings.NewReader(`{"foo": "bar"}`))
|
||||
app.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusInternalServerError, w.Code)
|
||||
assert.Equal(t, "application/problem+json", w.Result().Header.Get("content-type"))
|
||||
assert.Contains(t, log.All()[0].ContextMap()["request"], `{"foo": "bar"}`)
|
||||
}
|
19
negotiation/negotiation_test.go
Normal file
19
negotiation/negotiation_test.go
Normal file
|
@ -0,0 +1,19 @@
|
|||
package negotiation
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestAccept(t *testing.T) {
|
||||
assert.Equal(t, "b", SelectQValue("a; q=0.5, b;q=1.0,c; q=0.3", []string{"a", "b", "d"}))
|
||||
}
|
||||
|
||||
func TestAcceptBest(t *testing.T) {
|
||||
assert.Equal(t, "b", SelectQValue("a; q=1.0, b;q=1.0,c; q=0.3", []string{"b", "a"}))
|
||||
}
|
||||
|
||||
func TestNoMatch(t *testing.T) {
|
||||
assert.Equal(t, "", SelectQValue("a; q=1.0, b;q=1.0,c; q=0.3", []string{"d", "e"}))
|
||||
}
|
|
@ -35,7 +35,6 @@ type oaParam struct {
|
|||
Required bool `json:"required,omitempty"`
|
||||
Schema *schema.Schema `json:"schema,omitempty"`
|
||||
Deprecated bool `json:"deprecated,omitempty"`
|
||||
Example interface{} `json:"example,omitempty"`
|
||||
Explode *bool `json:"explode,omitempty"`
|
||||
|
||||
// Internal params are excluded from the OpenAPI document and can set up
|
||||
|
|
|
@ -139,8 +139,6 @@ func (o *Operation) Run(handler interface{}) {
|
|||
register = o.resource.mux.Delete
|
||||
}
|
||||
|
||||
// TODO: get input param definitions?
|
||||
|
||||
t := reflect.TypeOf(handler)
|
||||
if t.Kind() == reflect.Func && t.NumIn() > 1 {
|
||||
var err error
|
||||
|
@ -158,6 +156,9 @@ func (o *Operation) Run(handler interface{}) {
|
|||
}
|
||||
}
|
||||
|
||||
// Future improvement idea: use a sync.Pool for the input structure to save
|
||||
// on allocations if the struct has a Reset() method.
|
||||
|
||||
register("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
// Limit the request body size and set a read timeout.
|
||||
if r.Body != nil {
|
||||
|
|
Loading…
Add table
Reference in a new issue