test: add middleware/negotiation tests

This commit is contained in:
Daniel G. Taylor 2020-08-29 14:29:37 -07:00
parent 0c8ba518be
commit 6797f3a848
No known key found for this signature in database
GPG key ID: 7BD6DC99C9A87E22
10 changed files with 338 additions and 6 deletions

1
go.sum
View file

@ -75,6 +75,7 @@ github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Z
github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
github.com/google/go-cmp v0.2.0 h1:+dTQ8DZQJz0Mb/HjFlkptS1FeQ4cWSnN941F8aEG4SQ=
github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
github.com/google/go-cmp v0.3.0 h1:crn/baboCvb5fXaQ0IJ1SGTsTVrWpDsCWC8EGETZijY=
github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs=
github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc=

144
middleware/encoding_test.go Normal file
View file

@ -0,0 +1,144 @@
package middleware
import (
"compress/gzip"
"io/ioutil"
"net/http"
"net/http/httptest"
"testing"
"github.com/andybalholm/brotli"
"github.com/danielgtaylor/huma"
"github.com/danielgtaylor/huma/responses"
"github.com/stretchr/testify/assert"
)
func TestContentEncodingTooSmall(t *testing.T) {
app, _ := NewTestRouter(t)
app.Resource("/").Get("root", "test",
responses.OK().ContentType("text/plain"),
).Run(func(ctx huma.Context) {
ctx.Write([]byte("Short string"))
})
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/", nil)
req.Header.Add("Accept-Encoding", "gzip, br")
app.ServeHTTP(w, req)
assert.Equal(t, w.Result().StatusCode, http.StatusOK)
assert.Equal(t, "", w.Result().Header.Get("Content-Encoding"))
assert.Equal(t, "Short string", w.Body.String())
}
func TestContentEncodingIgnoredPath(t *testing.T) {
app, _ := NewTestRouter(t)
app.Resource("/foo.png").Get("root", "test",
responses.OK().ContentType("image/png"),
).Run(func(ctx huma.Context) {
ctx.Header().Set("Content-Type", "image/png")
ctx.Write([]byte("fake png"))
})
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/foo.png", nil)
req.Header.Add("Accept-Encoding", "gzip, br")
app.ServeHTTP(w, req)
assert.Equal(t, w.Result().StatusCode, http.StatusOK)
assert.Equal(t, "", w.Result().Header.Get("Content-Encoding"))
assert.Equal(t, "fake png", w.Body.String())
}
func TestContentEncodingCompressed(t *testing.T) {
app, _ := NewTestRouter(t)
app.Resource("/").Get("root", "test",
responses.OK(),
).Run(func(ctx huma.Context) {
ctx.Write(make([]byte, 1500))
})
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/", nil)
req.Header.Add("Accept-Encoding", "gzip, br")
app.ServeHTTP(w, req)
assert.Equal(t, w.Result().StatusCode, http.StatusOK)
assert.Equal(t, "br", w.Result().Header.Get("Content-Encoding"))
assert.Less(t, len(w.Body.String()), 1500)
br := brotli.NewReader(w.Body)
decoded, _ := ioutil.ReadAll(br)
assert.Equal(t, 1500, len(decoded))
}
func TestContentEncodingCompressedPick(t *testing.T) {
app, _ := NewTestRouter(t)
app.Resource("/").Get("root", "test",
responses.OK(),
).Run(func(ctx huma.Context) {
ctx.Write(make([]byte, 1500))
})
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/", nil)
req.Header.Add("Accept-Encoding", "gzip, br; q=0.9, deflate")
app.ServeHTTP(w, req)
assert.Equal(t, w.Result().StatusCode, http.StatusOK)
assert.Equal(t, "gzip", w.Result().Header.Get("Content-Encoding"))
assert.Less(t, len(w.Body.String()), 1500)
}
func TestContentEncodingCompressedMultiWrite(t *testing.T) {
app, _ := NewTestRouter(t)
app.Resource("/").Get("root", "test",
responses.OK(),
).Run(func(ctx huma.Context) {
buf := make([]byte, 750)
ctx.Write(buf)
ctx.Write(buf)
ctx.Write(buf)
})
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/", nil)
req.Header.Add("Accept-Encoding", "gzip")
app.ServeHTTP(w, req)
assert.Equal(t, w.Result().StatusCode, http.StatusOK)
assert.Equal(t, "gzip", w.Result().Header.Get("Content-Encoding"))
assert.Less(t, len(w.Body.String()), 2250)
gr, _ := gzip.NewReader(w.Body)
decoded, _ := ioutil.ReadAll(gr)
assert.Equal(t, 2250, len(decoded))
}
func TestContentEncodingError(t *testing.T) {
var status int
app, _ := NewTestRouter(t)
app.Middleware(func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
wrapped := &statusRecorder{ResponseWriter: w}
next.ServeHTTP(wrapped, r)
status = wrapped.status
})
})
app.Resource("/").Get("root", "test",
responses.OK(),
).Run(func(ctx huma.Context) {
ctx.WriteHeader(http.StatusNotFound)
ctx.Write([]byte("some text"))
})
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/", nil)
req.Header.Add("Accept-Encoding", "gzip")
app.ServeHTTP(w, req)
assert.Equal(t, http.StatusNotFound, status)
assert.Equal(t, http.StatusNotFound, w.Result().StatusCode)
}

14
middleware/logger_test.go Normal file
View file

@ -0,0 +1,14 @@
package middleware
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestNewLogger(t *testing.T) {
// Make sure it returns a logger
l, err := NewDefaultLogger()
assert.NoError(t, err)
assert.NotNil(t, l)
}

View file

@ -33,9 +33,7 @@ func DefaultChain(next http.Handler) http.Handler {
// Defaults sets up the default middleware. This convenience function adds the
// `DefaultChain` to the router and adds the `--debug` option for logging to
// the CLI if app is a CLI.
func Defaults(app interface {
Middlewarer
}) {
func Defaults(app Middlewarer) {
// Add the default middleware chain.
app.Middleware(DefaultChain)

View file

@ -0,0 +1,21 @@
package middleware
import (
"net/http"
"testing"
)
type fakeApp struct{}
func (a *fakeApp) Middleware(middlewares ...func(next http.Handler) http.Handler) {}
func (a *fakeApp) Flag(name string, short string, description string, defaultValue interface{}) {}
func (a *fakeApp) PreStart(f func()) {
f()
}
func TestDefaults(t *testing.T) {
app := &fakeApp{}
Defaults(app)
}

View file

@ -0,0 +1,51 @@
package middleware
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/danielgtaylor/huma"
"github.com/danielgtaylor/huma/responses"
"github.com/stretchr/testify/assert"
)
func TestPreferMinimalMiddleware(t *testing.T) {
app, _ := NewTestRouter(t)
app.Resource("/test").Get("id", "desc",
responses.OK().ContentType("text/plain"),
).Run(func(ctx huma.Context) {
ctx.Write([]byte("Hello, test"))
})
app.Resource("/non200").Get("id", "desc",
responses.BadRequest().ContentType("text/plain"),
).Run(func(ctx huma.Context) {
ctx.WriteHeader(http.StatusBadRequest)
ctx.Write([]byte("Error details"))
})
// Normal request
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/test", nil)
app.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.NotEmpty(t, w.Body.String())
// Prefer minimal should return 204 No Content
w = httptest.NewRecorder()
req, _ = http.NewRequest(http.MethodGet, "/test", nil)
req.Header.Add("prefer", "return=minimal")
app.ServeHTTP(w, req)
assert.Equal(t, http.StatusNoContent, w.Code)
assert.Empty(t, w.Body.String())
// Prefer minimal which can still return non-200 response bodies
w = httptest.NewRecorder()
req, _ = http.NewRequest(http.MethodGet, "/non200", nil)
req.Header.Add("prefer", "return=minimal")
app.ServeHTTP(w, req)
assert.Equal(t, http.StatusBadRequest, w.Code)
assert.NotEmpty(t, w.Body.String())
}

View file

@ -0,0 +1,84 @@
package middleware
import (
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/danielgtaylor/huma"
"github.com/danielgtaylor/huma/responses"
"github.com/stretchr/testify/assert"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
"go.uber.org/zap/zaptest"
"go.uber.org/zap/zaptest/observer"
)
func NewTestRouter(t testing.TB) (*huma.Router, *observer.ObservedLogs) {
core, logs := observer.New(zapcore.DebugLevel)
router := huma.New("Test API", "1.0.0")
router.Middleware(DefaultChain)
NewLogger = func() (*zap.Logger, error) {
l := zaptest.NewLogger(t, zaptest.WrapOptions(zap.WrapCore(func(zapcore.Core) zapcore.Core { return core })))
return l, nil
}
return router, logs
}
func TestRecoveryMiddleware(t *testing.T) {
app, _ := NewTestRouter(t)
app.Resource("/panic").Get("panic", "Panic recovery test",
responses.NoContent(),
).Run(func(ctx huma.Context) {
panic(fmt.Errorf("Some error"))
})
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/panic", nil)
app.ServeHTTP(w, req)
assert.Equal(t, http.StatusInternalServerError, w.Code)
assert.Equal(t, "application/problem+json", w.Result().Header.Get("content-type"))
}
func TestRecoveryMiddlewareString(t *testing.T) {
app, _ := NewTestRouter(t)
app.Resource("/panic").Get("panic", "Panic recovery test",
responses.NoContent(),
).Run(func(ctx huma.Context) {
panic("Some error")
})
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/panic", nil)
app.ServeHTTP(w, req)
assert.Equal(t, http.StatusInternalServerError, w.Code)
assert.Equal(t, "application/problem+json", w.Result().Header.Get("content-type"))
}
func TestRecoveryMiddlewareLogBody(t *testing.T) {
app, log := NewTestRouter(t)
app.Resource("/panic").Put("panic", "Panic recovery test",
responses.NoContent(),
).Run(func(ctx huma.Context, input struct {
Body struct {
Foo string `json:"foo"`
}
}) {
panic(fmt.Errorf("Some error"))
})
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodPut, "/panic", strings.NewReader(`{"foo": "bar"}`))
app.ServeHTTP(w, req)
assert.Equal(t, http.StatusInternalServerError, w.Code)
assert.Equal(t, "application/problem+json", w.Result().Header.Get("content-type"))
assert.Contains(t, log.All()[0].ContextMap()["request"], `{"foo": "bar"}`)
}

View file

@ -0,0 +1,19 @@
package negotiation
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestAccept(t *testing.T) {
assert.Equal(t, "b", SelectQValue("a; q=0.5, b;q=1.0,c; q=0.3", []string{"a", "b", "d"}))
}
func TestAcceptBest(t *testing.T) {
assert.Equal(t, "b", SelectQValue("a; q=1.0, b;q=1.0,c; q=0.3", []string{"b", "a"}))
}
func TestNoMatch(t *testing.T) {
assert.Equal(t, "", SelectQValue("a; q=1.0, b;q=1.0,c; q=0.3", []string{"d", "e"}))
}

View file

@ -35,7 +35,6 @@ type oaParam struct {
Required bool `json:"required,omitempty"`
Schema *schema.Schema `json:"schema,omitempty"`
Deprecated bool `json:"deprecated,omitempty"`
Example interface{} `json:"example,omitempty"`
Explode *bool `json:"explode,omitempty"`
// Internal params are excluded from the OpenAPI document and can set up

View file

@ -139,8 +139,6 @@ func (o *Operation) Run(handler interface{}) {
register = o.resource.mux.Delete
}
// TODO: get input param definitions?
t := reflect.TypeOf(handler)
if t.Kind() == reflect.Func && t.NumIn() > 1 {
var err error
@ -158,6 +156,9 @@ func (o *Operation) Run(handler interface{}) {
}
}
// Future improvement idea: use a sync.Pool for the input structure to save
// on allocations if the struct has a Reset() method.
register("/", func(w http.ResponseWriter, r *http.Request) {
// Limit the request body size and set a read timeout.
if r.Body != nil {