huma/context_test.go

369 lines
10 KiB
Go

package huma
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/fxamacker/cbor/v2"
"github.com/goccy/go-yaml"
"github.com/stretchr/testify/assert"
)
func TestGetContextFromRequest(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
r = r.WithContext(context.WithValue(r.Context(), contextKey("foo"), "bar"))
ctx := ContextFromRequest(w, r)
assert.Equal(t, "bar", ctx.Value(contextKey("foo")))
})
w := httptest.NewRecorder()
r, _ := http.NewRequest(http.MethodGet, "/", nil)
handler(w, r)
}
func TestContentNegotiation(t *testing.T) {
type Response struct {
Value string `json:"value"`
}
app := newTestRouter()
app.Resource("/negotiated").Get("test", "Test",
NewResponse(200, "desc").Model(Response{}),
).Run(func(ctx Context) {
ctx.WriteModel(http.StatusOK, Response{
Value: "Hello, world!",
})
})
var parsed Response
expected := Response{
Value: "Hello, world!",
}
// No preference
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/negotiated", nil)
app.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, "application/json", w.Header().Get("Content-Type"))
err := json.Unmarshal(w.Body.Bytes(), &parsed)
assert.NoError(t, err)
assert.Equal(t, expected, parsed)
// Prefer JSON
w = httptest.NewRecorder()
req, _ = http.NewRequest(http.MethodGet, "/negotiated", nil)
req.Header.Set("Accept", "application/json")
app.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, "application/json", w.Header().Get("Content-Type"))
err = json.Unmarshal(w.Body.Bytes(), &parsed)
assert.NoError(t, err)
assert.Equal(t, expected, parsed)
// Prefer YAML
w = httptest.NewRecorder()
req, _ = http.NewRequest(http.MethodGet, "/negotiated", nil)
req.Header.Set("Accept", "application/yaml")
app.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, "application/yaml", w.Header().Get("Content-Type"))
err = yaml.Unmarshal(w.Body.Bytes(), &parsed)
assert.NoError(t, err)
assert.Equal(t, expected, parsed)
// Prefer CBOR
w = httptest.NewRecorder()
req, _ = http.NewRequest(http.MethodGet, "/negotiated", nil)
req.Header.Set("Accept", "application/cbor")
app.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, "application/cbor", w.Header().Get("Content-Type"))
err = cbor.Unmarshal(w.Body.Bytes(), &parsed)
assert.NoError(t, err)
assert.EqualValues(t, expected, parsed)
}
func TestErrorNegotiation(t *testing.T) {
app := newTestRouter()
app.Resource("/error").Get("test", "Test",
NewResponse(400, "desc").Model(&ErrorModel{}),
).Run(func(ctx Context) {
ctx.AddError(fmt.Errorf("some error"))
ctx.AddError(&ErrorDetail{
Message: "Invalid value",
Location: "body.field",
Value: "test",
})
ctx.WriteError(http.StatusBadRequest, "test error")
})
var parsed ErrorModel
expected := ErrorModel{
Status: http.StatusBadRequest,
Title: http.StatusText(http.StatusBadRequest),
Detail: "test error",
Errors: []*ErrorDetail{
{Message: "some error"},
{
Message: "Invalid value",
Location: "body.field",
Value: "test",
},
},
}
// No preference
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/error", nil)
app.ServeHTTP(w, req)
assert.Equal(t, http.StatusBadRequest, w.Code)
assert.Equal(t, "application/problem+json", w.Header().Get("Content-Type"))
err := json.Unmarshal(w.Body.Bytes(), &parsed)
assert.NoError(t, err)
assert.Equal(t, expected, parsed)
// Prefer JSON
w = httptest.NewRecorder()
req, _ = http.NewRequest(http.MethodGet, "/error", nil)
req.Header.Set("Accept", "application/json")
app.ServeHTTP(w, req)
assert.Equal(t, http.StatusBadRequest, w.Code)
assert.Equal(t, "application/problem+json", w.Header().Get("Content-Type"))
err = json.Unmarshal(w.Body.Bytes(), &parsed)
assert.NoError(t, err)
assert.Equal(t, expected, parsed)
// Prefer YAML
w = httptest.NewRecorder()
req, _ = http.NewRequest(http.MethodGet, "/error", nil)
req.Header.Set("Accept", "application/yaml")
app.ServeHTTP(w, req)
assert.Equal(t, http.StatusBadRequest, w.Code)
assert.Equal(t, "application/problem+yaml", w.Header().Get("Content-Type"))
err = yaml.Unmarshal(w.Body.Bytes(), &parsed)
assert.NoError(t, err)
assert.EqualValues(t, expected, parsed)
// Prefer CBOR
w = httptest.NewRecorder()
req, _ = http.NewRequest(http.MethodGet, "/error", nil)
req.Header.Set("Accept", "application/cbor")
app.ServeHTTP(w, req)
assert.Equal(t, http.StatusBadRequest, w.Code)
assert.Equal(t, "application/problem+cbor", w.Header().Get("Content-Type"))
err = cbor.Unmarshal(w.Body.Bytes(), &parsed)
assert.NoError(t, err)
assert.Equal(t, expected, parsed)
}
func TestInvalidModel(t *testing.T) {
type R1 struct {
Foo string `json:"foo"`
}
type R2 struct {
Bar string `json:"bar"`
}
app := newTestRouter()
app.Resource("/bad-status").Get("test", "Test",
NewResponse(http.StatusOK, "desc").Model(R1{}),
).Run(func(ctx Context) {
ctx.WriteModel(http.StatusNoContent, R2{Bar: "blah"})
})
app.Resource("/bad-model").Get("test", "Test",
NewResponse(http.StatusOK, "desc").Model(R1{}),
).Run(func(ctx Context) {
ctx.WriteModel(http.StatusOK, R2{Bar: "blah"})
})
assert.Panics(t, func() {
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/bad-status", nil)
app.ServeHTTP(w, req)
})
assert.Panics(t, func() {
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/bad-model", nil)
app.ServeHTTP(w, req)
})
}
func TestInvalidHeader(t *testing.T) {
app := newTestRouter()
app.Resource("/").Get("test", "Test",
NewResponse(http.StatusNoContent, "desc").Headers("Extra"),
).Run(func(ctx Context) {
// Typo in the header should not be allowed
ctx.Header().Set("Extra2", "some-value")
ctx.WriteHeader(http.StatusNoContent)
})
assert.Panics(t, func() {
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/", nil)
app.ServeHTTP(w, req)
})
}
func TestWriteAfterClose(t *testing.T) {
app := newTestRouter()
app.Resource("/").Get("test", "Test",
NewResponse(http.StatusBadRequest, "desc").Model(&ErrorModel{}),
).Run(func(ctx Context) {
ctx.WriteError(http.StatusBadRequest, "some error")
// Second write should fail
ctx.WriteError(http.StatusBadRequest, "some error")
})
assert.Panics(t, func() {
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/", nil)
app.ServeHTTP(w, req)
})
}
func TestValue(t *testing.T) {
key := contextKey("foo")
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := ContextFromRequest(w, r)
// Create a copy with a new value attached.
ctx2 := ctx.WithValue(key, "bar")
assert.Equal(t, nil, ctx.Value(key))
assert.Equal(t, "bar", ctx2.Value(key))
// Set a value in-place in the original
ctx.SetValue(key, "baz")
// Make sure it was set on ctx, and that ctx2 wasn't modified since it's
// a copy created before the call to `SetValue(...)`.
assert.Equal(t, "baz", ctx.Value(key))
assert.Equal(t, "bar", ctx2.Value(key))
})
w := httptest.NewRecorder()
r, _ := http.NewRequest(http.MethodGet, "/", nil)
handler(w, r)
}
func TestWriteContent(t *testing.T) {
app := newTestRouter()
b := []byte("Test Byte Data")
app.Resource("/content").Get("test", "Test",
NewResponse(200, "desc").Model(Response{}),
).Run(func(ctx Context) {
ctx.Header().Set("Content-Type", "application/octet-stream")
content := bytes.NewReader(b)
ctx.WriteContent("", content, time.Time{})
})
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/content", nil)
app.ServeHTTP(w, req)
assert.Equal(t, 200, w.Result().StatusCode)
assert.Equal(t, "application/octet-stream", w.Header().Get("Content-Type"))
assert.Equal(t, b, w.Body.Bytes())
}
func TestWriteContentRespectsRange(t *testing.T) {
app := newTestRouter()
b := []byte("Test Byte Data")
app.Resource("/content").Get("test", "Test",
NewResponse(206, "desc").Model(Response{}),
).Run(func(ctx Context) {
ctx.Header().Set("Content-Type", "application/octet-stream")
content := bytes.NewReader(b)
ctx.WriteContent("", content, time.Time{})
})
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/content", nil)
req.Header.Set("Range", "bytes=0-5")
app.ServeHTTP(w, req)
// confirms that Range is properly being forwarded to ServeContent.
// we'll assume more advanced range use cases are properly tested
// in the http library.
assert.Equal(t, w.Result().StatusCode, 206)
assert.Equal(t, []byte("Test B"), w.Body.Bytes())
}
func TestWriteContentLastModified(t *testing.T) {
app := newTestRouter()
b := []byte("Test Byte Data")
modTime := time.Now()
app.Resource("/content").Get("test", "Test",
NewResponse(206, "desc").Model(Response{}),
).Run(func(ctx Context) {
ctx.Header().Set("Content-Type", "application/octet-stream")
content := bytes.NewReader(b)
ctx.WriteContent("", content, modTime)
})
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/content", nil)
app.ServeHTTP(w, req)
// confirms that modTime is properly being forwarded to ServeContent.
// We'll assume the more advanced modTime use cases are properly tested
// in http library.
strTime := modTime.UTC().Format(http.TimeFormat)
assert.Equal(t, strTime, w.Header().Get("Last-Modified"))
}
func TestWriteContentName(t *testing.T) {
app := newTestRouter()
b := []byte("Test Byte Data")
app.Resource("/content").Get("test", "Test",
NewResponse(206, "desc").Model(Response{}),
).Run(func(ctx Context) {
content := bytes.NewReader(b)
ctx.WriteContent("/path/with/content.mp4", content, time.Time{})
})
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/content", nil)
app.ServeHTTP(w, req)
// confirms that name is properly being forwarded to ServeContent.
// We'll assume the more advanced modTime use cases are properly tested
// in http library.
assert.Equal(t, "video/mp4", w.Header().Get("Content-Type"))
}