huma/router_test.go
2022-04-21 23:29:44 -07:00

480 lines
12 KiB
Go

package huma
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"io/ioutil"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/danielgtaylor/huma/schema"
"github.com/stretchr/testify/assert"
)
func newTestRouter() *Router {
app := New("Test API", "1.0.0")
return app
}
func TestRouterServiceLink(t *testing.T) {
r := newTestRouter()
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/", nil)
r.ServeHTTP(w, req)
assert.Contains(t, w.Header().Get("Link"), `</openapi.json>; rel="service-desc"`)
assert.Contains(t, w.Header().Get("Link"), `</docs>; rel="service-doc"`)
}
func TestRouterHello(t *testing.T) {
r := New("Test", "1.0.0")
r.Resource("/test").Get("test", "Test",
NewResponse(http.StatusNoContent, "test"),
).Run(func(ctx Context) {
ctx.WriteHeader(http.StatusNoContent)
})
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/test", nil)
r.ServeHTTP(w, req)
// Assert the response is as expected.
assert.Equal(t, http.StatusNoContent, w.Code)
}
func TestStreamingInput(t *testing.T) {
r := New("Test", "1.0.0")
r.Resource("/stream").Post("stream", "Stream test",
NewResponse(http.StatusNoContent, "test"),
NewResponse(http.StatusInternalServerError, "error"),
).Run(func(ctx Context, input struct {
Body io.Reader
}) {
_, err := ioutil.ReadAll(input.Body)
if err != nil {
ctx.WriteError(http.StatusInternalServerError, "Problem reading input", err)
}
ctx.WriteHeader(http.StatusNoContent)
})
w := httptest.NewRecorder()
body := bytes.NewReader(make([]byte, 1024))
req, _ := http.NewRequest(http.MethodPost, "/stream", body)
r.ServeHTTP(w, req)
assert.Equal(t, http.StatusNoContent, w.Code)
}
func TestModelInputOutput(t *testing.T) {
type Response struct {
Category string `json:"category"`
Hidden bool `json:"hidden"`
Auth string `json:"auth"`
ID string `json:"id"`
Age int `json:"age"`
}
r := New("Test", "1.0.0")
r.Resource("/players/{category}").Post("player", "Create player",
NewResponse(http.StatusOK, "test").Model(Response{}),
).Run(func(ctx Context, input struct {
Category string `path:"category"`
Hidden bool `query:"hidden"`
Auth string `header:"Authorization"`
Body struct {
ID string `json:"id"`
Age int `json:"age" minimum:"16"`
}
}) {
ctx.WriteModel(http.StatusOK, Response{
Category: input.Category,
Hidden: input.Hidden,
Auth: input.Auth,
ID: input.Body.ID,
Age: input.Body.Age,
})
})
w := httptest.NewRecorder()
body := bytes.NewReader([]byte(`{"id": "abc123", "age": 25}`))
req, _ := http.NewRequest(http.MethodPost, "/players/fps?hidden=true", body)
req.Header.Set("Authorization", "dummy")
req.Host = "example.com"
r.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.JSONEq(t, `{
"$schema": "https://example.com/schemas/Response.json",
"category": "fps",
"hidden": true,
"auth": "dummy",
"id": "abc123",
"age": 25
}`, w.Body.String())
// Should be able to get OpenAPI describing this API with its resource,
// operation, schema, etc.
w = httptest.NewRecorder()
req, _ = http.NewRequest(http.MethodGet, "/openapi.json", nil)
r.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
// Should be able to get model referenced by the response.
w = httptest.NewRecorder()
req, _ = http.NewRequest(http.MethodGet, "/schemas/Response.json", nil)
r.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
// Missing schema should return 404
w = httptest.NewRecorder()
req, _ = http.NewRequest(http.MethodGet, "/schemas/Missing.json", nil)
r.ServeHTTP(w, req)
assert.Equal(t, http.StatusNotFound, w.Code)
}
func TestRouterEmbeddedStructOutput(t *testing.T) {
type CreatedField struct {
Created time.Time `json:"created,omitempty"`
}
type Resp struct {
CreatedField
Another string `json:"another"`
Ignored string `json:"-"`
}
now := time.Now()
r := New("Test", "1.0.0")
r.Resource("/test").Get("test", "Test",
NewResponse(http.StatusOK, "test").Model(&Resp{}),
).Run(func(ctx Context) {
ctx.WriteModel(http.StatusOK, &Resp{
CreatedField: CreatedField{
Created: now,
},
Another: "foo",
})
})
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/test", nil)
req.Host = "example.com"
r.ServeHTTP(w, req)
// Assert the response is as expected.
assert.Equal(t, http.StatusOK, w.Code)
assert.JSONEq(t, fmt.Sprintf(`{
"$schema": "https://example.com/schemas/Resp.json",
"created": "%s",
"another": "foo"
}`, now.Format(time.RFC3339Nano)), w.Body.String())
}
func TestTooBigBody(t *testing.T) {
app := newTestRouter()
type Input struct {
Body struct {
ID string `json:"id"`
}
}
op := app.Resource("/test").Put("put", "desc",
NewResponse(http.StatusNoContent, "desc"),
)
op.MaxBodyBytes(5)
op.Run(func(ctx Context, input Input) {
// Do nothing...
})
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodPut, "/test", strings.NewReader(`{"id": "foo"}`))
app.ServeHTTP(w, req)
assert.Equal(t, http.StatusRequestEntityTooLarge, w.Code)
assert.Contains(t, w.Body.String(), "Request body too large")
// With content length
w = httptest.NewRecorder()
req, _ = http.NewRequest(http.MethodPut, "/test", strings.NewReader(`{"id": "foo"}`))
req.Header.Set("Content-Length", "13")
app.ServeHTTP(w, req)
assert.Equal(t, http.StatusRequestEntityTooLarge, w.Code)
assert.Contains(t, w.Body.String(), "Request body too large")
}
type timeoutError struct{}
func (e *timeoutError) Error() string {
return "timed out"
}
func (e *timeoutError) Timeout() bool {
return true
}
func (e *timeoutError) Temporary() bool {
return false
}
type slowReader struct{}
func (r *slowReader) Read(p []byte) (int, error) {
return 0, &timeoutError{}
}
func TestBodySlow(t *testing.T) {
app := newTestRouter()
type Input struct {
Body struct {
ID string
}
}
op := app.Resource("/test").Put("put", "desc",
NewResponse(http.StatusNoContent, "desc"),
)
op.BodyReadTimeout(1 * time.Millisecond)
op.Run(func(ctx Context, input Input) {
// Do nothing...
})
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodPut, "/test", &slowReader{})
app.ServeHTTP(w, req)
assert.Equal(t, http.StatusRequestTimeout, w.Code)
assert.Contains(t, w.Body.String(), "timed out")
}
func TestErrorHandlers(t *testing.T) {
app := newTestRouter()
app.Resource("/").Get("root", "desc",
NewResponse(http.StatusNoContent, "desc"),
).Run(func(ctx Context) {
// Do nothing
})
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/notfound", nil)
app.ServeHTTP(w, req)
assert.Equal(t, http.StatusNotFound, w.Code)
assert.Equal(t, "application/problem+json", w.Header().Get("Content-Type"))
assert.Contains(t, w.Body.String(), "/notfound")
w = httptest.NewRecorder()
req, _ = http.NewRequest(http.MethodPut, "/", nil)
app.ServeHTTP(w, req)
assert.Equal(t, http.StatusMethodNotAllowed, w.Code)
assert.Equal(t, "application/problem+json", w.Header().Get("Content-Type"))
assert.Contains(t, w.Body.String(), "PUT")
}
func TestInvalidPathParam(t *testing.T) {
type Input struct {
ThingID string `path:"thing-if"`
}
app := newTestRouter()
// The router has no middleware, so no panic recovery will happen. This lets
// us test via a simple assertion that it would panic, and the actual test
// to ensure a 5xx error happens in the `middleware` package instead.
assert.Panics(t, func() {
app.Resource("/things/{thing-id}").Get("get", "Test",
NewResponse(http.StatusNoContent, "desc"),
).Run(func(ctx Context, input Input) {
// Do nothing
})
})
}
func TestRouterSecurity(t *testing.T) {
app := newTestRouter()
// Document that the API gateway handles auth via OAuth2 Authorization Code.
app.GatewayAuthCode("default", "https://example.com/authorize", "https://example.com/token", nil)
app.GatewayClientCredentials("m2m", "https://example.com/token", nil)
app.GatewayBasicAuth("basic")
// Every call must be authenticated using the default auth mechanism
// registered above.
app.SecurityRequirement("default")
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/openapi.json", nil)
app.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
var parsed map[string]interface{}
err := json.Unmarshal(w.Body.Bytes(), &parsed)
assert.Nil(t, err)
assert.Equal(t, parsed["security"], []interface{}{
map[string]interface{}{
"default": []interface{}{},
}})
assert.Equal(t, parsed["components"].(map[string]interface{})["securitySchemes"], map[string]interface{}{
"default": map[string]interface{}{
"type": "oauth2",
"flows": map[string]interface{}{
"authorizationCode": map[string]interface{}{
"authorizationUrl": "https://example.com/authorize",
"tokenUrl": "https://example.com/token",
},
},
},
"m2m": map[string]interface{}{
"type": "oauth2",
"flows": map[string]interface{}{
"clientCredentials": map[string]interface{}{
"tokenUrl": "https://example.com/token",
},
},
},
"basic": map[string]interface{}{
"type": "http",
"scheme": "basic",
"flows": map[string]interface{}{},
},
})
}
// TODO: test app.AutoConfig
func TestRouterAutoConfig(t *testing.T) {
app := newTestRouter()
app.GatewayAuthCode("authcode", "https://example.com/authorize", "https://example.com/token", nil)
app.SecurityRequirement("authcode")
app.AutoConfig(AutoConfig{
Security: "authcode",
Prompt: map[string]AutoConfigVar{
"extra": {
Description: "Some extra value",
Example: "abc123",
},
},
Params: map[string]string{
"another": "https://example.com/extras/{extra}",
},
})
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/openapi.json", nil)
app.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
var parsed map[string]interface{}
err := json.Unmarshal(w.Body.Bytes(), &parsed)
assert.Nil(t, err)
assert.Equal(t, parsed["x-cli-config"], map[string]interface{}{
"security": "authcode",
"prompt": map[string]interface{}{
"extra": map[string]interface{}{
"description": "Some extra value",
"example": "abc123",
},
},
"params": map[string]interface{}{
"another": "https://example.com/extras/{extra}",
},
})
}
func TestCustomRequestSchema(t *testing.T) {
app := newTestRouter()
op := app.Resource("/foo").Post("id", "doc", NewResponse(http.StatusOK, "ok"))
op.RequestSchema(&schema.Schema{
Type: schema.TypeInteger,
Minimum: schema.F(0),
Maximum: schema.F(100),
})
op.Run(func(ctx Context, input struct {
Body int
}) {
// Custom schema should be used so we never get here.
ctx.WriteHeader(http.StatusOK)
})
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodPost, "/foo", strings.NewReader("1234"))
app.ServeHTTP(w, req)
assert.Equal(t, http.StatusUnprocessableEntity, w.Result().StatusCode)
}
func TestGetOperationName(t *testing.T) {
app := newTestRouter()
var opInfo *OperationInfo
app.Middleware(func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
next.ServeHTTP(w, r)
opInfo = GetOperationInfo(r.Context())
})
})
app.Resource("/").Get("test-id", "doc",
NewResponse(http.StatusOK, "ok"),
).Run(func(ctx Context) {
// Do nothing!
})
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/", nil)
app.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Result().StatusCode)
assert.Equal(t, "test-id", opInfo.ID)
assert.Equal(t, "/", opInfo.URITemplate)
assert.Equal(t, "doc", opInfo.Summary)
assert.Equal(t, []string{}, opInfo.Tags)
}
func TestGetOperationDoesNotCrash(t *testing.T) {
ctx := context.Background()
assert.NotPanics(t, func() {
info := GetOperationInfo(ctx)
assert.NotNil(t, info)
})
}
func TestSubResource(t *testing.T) {
app := newTestRouter()
// This should not crash.
app.Resource("/").SubResource("/foo").SubResource("/bar").Get("get-bar", "docs", NewResponse(http.StatusOK, "ok")).Run(func(ctx Context) {
// Do nothing
})
}
func TestDefaultResponse(t *testing.T) {
app := newTestRouter()
// This should not crash.
app.Resource("/").Get("get-root", "docs", NewResponse(0, "Default repsonse")).Run(func(ctx Context) {
ctx.WriteHeader(http.StatusOK)
})
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/", nil)
app.ServeHTTP(w, req)
// This should not panic and should return the 200 OK
assert.Equal(t, http.StatusOK, w.Result().StatusCode)
}