mirror of
https://github.com/Fishwaldo/huma.git
synced 2025-03-15 19:31:27 +00:00
369 lines
10 KiB
Go
369 lines
10 KiB
Go
package huma
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/fxamacker/cbor/v2"
|
|
"github.com/goccy/go-yaml"
|
|
"github.com/stretchr/testify/assert"
|
|
)
|
|
|
|
func TestGetContextFromRequest(t *testing.T) {
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
r = r.WithContext(context.WithValue(r.Context(), contextKey("foo"), "bar"))
|
|
ctx := ContextFromRequest(w, r)
|
|
assert.Equal(t, "bar", ctx.Value(contextKey("foo")))
|
|
})
|
|
|
|
w := httptest.NewRecorder()
|
|
r, _ := http.NewRequest(http.MethodGet, "/", nil)
|
|
handler(w, r)
|
|
}
|
|
|
|
func TestContentNegotiation(t *testing.T) {
|
|
type Response struct {
|
|
Value string `json:"value"`
|
|
}
|
|
|
|
app := newTestRouter()
|
|
|
|
app.Resource("/negotiated").Get("test", "Test",
|
|
NewResponse(200, "desc").Model(Response{}),
|
|
).Run(func(ctx Context) {
|
|
ctx.WriteModel(http.StatusOK, Response{
|
|
Value: "Hello, world!",
|
|
})
|
|
})
|
|
|
|
var parsed Response
|
|
expected := Response{
|
|
Value: "Hello, world!",
|
|
}
|
|
|
|
// No preference
|
|
w := httptest.NewRecorder()
|
|
req, _ := http.NewRequest(http.MethodGet, "/negotiated", nil)
|
|
app.ServeHTTP(w, req)
|
|
|
|
assert.Equal(t, http.StatusOK, w.Code)
|
|
assert.Equal(t, "application/json", w.Header().Get("Content-Type"))
|
|
err := json.Unmarshal(w.Body.Bytes(), &parsed)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, expected, parsed)
|
|
|
|
// Prefer JSON
|
|
w = httptest.NewRecorder()
|
|
req, _ = http.NewRequest(http.MethodGet, "/negotiated", nil)
|
|
req.Header.Set("Accept", "application/json")
|
|
app.ServeHTTP(w, req)
|
|
|
|
assert.Equal(t, http.StatusOK, w.Code)
|
|
assert.Equal(t, "application/json", w.Header().Get("Content-Type"))
|
|
err = json.Unmarshal(w.Body.Bytes(), &parsed)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, expected, parsed)
|
|
|
|
// Prefer YAML
|
|
w = httptest.NewRecorder()
|
|
req, _ = http.NewRequest(http.MethodGet, "/negotiated", nil)
|
|
req.Header.Set("Accept", "application/yaml")
|
|
app.ServeHTTP(w, req)
|
|
|
|
assert.Equal(t, http.StatusOK, w.Code)
|
|
assert.Equal(t, "application/yaml", w.Header().Get("Content-Type"))
|
|
err = yaml.Unmarshal(w.Body.Bytes(), &parsed)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, expected, parsed)
|
|
|
|
// Prefer CBOR
|
|
w = httptest.NewRecorder()
|
|
req, _ = http.NewRequest(http.MethodGet, "/negotiated", nil)
|
|
req.Header.Set("Accept", "application/cbor")
|
|
app.ServeHTTP(w, req)
|
|
|
|
assert.Equal(t, http.StatusOK, w.Code)
|
|
assert.Equal(t, "application/cbor", w.Header().Get("Content-Type"))
|
|
err = cbor.Unmarshal(w.Body.Bytes(), &parsed)
|
|
assert.NoError(t, err)
|
|
assert.EqualValues(t, expected, parsed)
|
|
}
|
|
|
|
func TestErrorNegotiation(t *testing.T) {
|
|
app := newTestRouter()
|
|
|
|
app.Resource("/error").Get("test", "Test",
|
|
NewResponse(400, "desc").Model(&ErrorModel{}),
|
|
).Run(func(ctx Context) {
|
|
ctx.AddError(fmt.Errorf("some error"))
|
|
ctx.AddError(&ErrorDetail{
|
|
Message: "Invalid value",
|
|
Location: "body.field",
|
|
Value: "test",
|
|
})
|
|
ctx.WriteError(http.StatusBadRequest, "test error")
|
|
})
|
|
|
|
var parsed ErrorModel
|
|
expected := ErrorModel{
|
|
Status: http.StatusBadRequest,
|
|
Title: http.StatusText(http.StatusBadRequest),
|
|
Detail: "test error",
|
|
Errors: []*ErrorDetail{
|
|
{Message: "some error"},
|
|
{
|
|
Message: "Invalid value",
|
|
Location: "body.field",
|
|
Value: "test",
|
|
},
|
|
},
|
|
}
|
|
|
|
// No preference
|
|
w := httptest.NewRecorder()
|
|
req, _ := http.NewRequest(http.MethodGet, "/error", nil)
|
|
app.ServeHTTP(w, req)
|
|
|
|
assert.Equal(t, http.StatusBadRequest, w.Code)
|
|
assert.Equal(t, "application/problem+json", w.Header().Get("Content-Type"))
|
|
err := json.Unmarshal(w.Body.Bytes(), &parsed)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, expected, parsed)
|
|
|
|
// Prefer JSON
|
|
w = httptest.NewRecorder()
|
|
req, _ = http.NewRequest(http.MethodGet, "/error", nil)
|
|
req.Header.Set("Accept", "application/json")
|
|
app.ServeHTTP(w, req)
|
|
|
|
assert.Equal(t, http.StatusBadRequest, w.Code)
|
|
assert.Equal(t, "application/problem+json", w.Header().Get("Content-Type"))
|
|
err = json.Unmarshal(w.Body.Bytes(), &parsed)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, expected, parsed)
|
|
|
|
// Prefer YAML
|
|
w = httptest.NewRecorder()
|
|
req, _ = http.NewRequest(http.MethodGet, "/error", nil)
|
|
req.Header.Set("Accept", "application/yaml")
|
|
app.ServeHTTP(w, req)
|
|
|
|
assert.Equal(t, http.StatusBadRequest, w.Code)
|
|
assert.Equal(t, "application/problem+yaml", w.Header().Get("Content-Type"))
|
|
err = yaml.Unmarshal(w.Body.Bytes(), &parsed)
|
|
assert.NoError(t, err)
|
|
assert.EqualValues(t, expected, parsed)
|
|
|
|
// Prefer CBOR
|
|
w = httptest.NewRecorder()
|
|
req, _ = http.NewRequest(http.MethodGet, "/error", nil)
|
|
req.Header.Set("Accept", "application/cbor")
|
|
app.ServeHTTP(w, req)
|
|
|
|
assert.Equal(t, http.StatusBadRequest, w.Code)
|
|
assert.Equal(t, "application/problem+cbor", w.Header().Get("Content-Type"))
|
|
err = cbor.Unmarshal(w.Body.Bytes(), &parsed)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, expected, parsed)
|
|
}
|
|
|
|
func TestInvalidModel(t *testing.T) {
|
|
type R1 struct {
|
|
Foo string `json:"foo"`
|
|
}
|
|
|
|
type R2 struct {
|
|
Bar string `json:"bar"`
|
|
}
|
|
|
|
app := newTestRouter()
|
|
|
|
app.Resource("/bad-status").Get("test", "Test",
|
|
NewResponse(http.StatusOK, "desc").Model(R1{}),
|
|
).Run(func(ctx Context) {
|
|
ctx.WriteModel(http.StatusNoContent, R2{Bar: "blah"})
|
|
})
|
|
|
|
app.Resource("/bad-model").Get("test", "Test",
|
|
NewResponse(http.StatusOK, "desc").Model(R1{}),
|
|
).Run(func(ctx Context) {
|
|
ctx.WriteModel(http.StatusOK, R2{Bar: "blah"})
|
|
})
|
|
|
|
assert.Panics(t, func() {
|
|
w := httptest.NewRecorder()
|
|
req, _ := http.NewRequest(http.MethodGet, "/bad-status", nil)
|
|
app.ServeHTTP(w, req)
|
|
})
|
|
|
|
assert.Panics(t, func() {
|
|
w := httptest.NewRecorder()
|
|
req, _ := http.NewRequest(http.MethodGet, "/bad-model", nil)
|
|
app.ServeHTTP(w, req)
|
|
})
|
|
}
|
|
|
|
func TestInvalidHeader(t *testing.T) {
|
|
app := newTestRouter()
|
|
|
|
app.Resource("/").Get("test", "Test",
|
|
NewResponse(http.StatusNoContent, "desc").Headers("Extra"),
|
|
).Run(func(ctx Context) {
|
|
// Typo in the header should not be allowed
|
|
ctx.Header().Set("Extra2", "some-value")
|
|
ctx.WriteHeader(http.StatusNoContent)
|
|
})
|
|
|
|
assert.Panics(t, func() {
|
|
w := httptest.NewRecorder()
|
|
req, _ := http.NewRequest(http.MethodGet, "/", nil)
|
|
app.ServeHTTP(w, req)
|
|
})
|
|
}
|
|
|
|
func TestWriteAfterClose(t *testing.T) {
|
|
app := newTestRouter()
|
|
|
|
app.Resource("/").Get("test", "Test",
|
|
NewResponse(http.StatusBadRequest, "desc").Model(&ErrorModel{}),
|
|
).Run(func(ctx Context) {
|
|
ctx.WriteError(http.StatusBadRequest, "some error")
|
|
// Second write should fail
|
|
ctx.WriteError(http.StatusBadRequest, "some error")
|
|
})
|
|
|
|
assert.Panics(t, func() {
|
|
w := httptest.NewRecorder()
|
|
req, _ := http.NewRequest(http.MethodGet, "/", nil)
|
|
app.ServeHTTP(w, req)
|
|
})
|
|
}
|
|
|
|
func TestValue(t *testing.T) {
|
|
key := contextKey("foo")
|
|
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
ctx := ContextFromRequest(w, r)
|
|
|
|
// Create a copy with a new value attached.
|
|
ctx2 := ctx.WithValue(key, "bar")
|
|
|
|
assert.Equal(t, nil, ctx.Value(key))
|
|
assert.Equal(t, "bar", ctx2.Value(key))
|
|
|
|
// Set a value in-place in the original
|
|
ctx.SetValue(key, "baz")
|
|
|
|
// Make sure it was set on ctx, and that ctx2 wasn't modified since it's
|
|
// a copy created before the call to `SetValue(...)`.
|
|
assert.Equal(t, "baz", ctx.Value(key))
|
|
assert.Equal(t, "bar", ctx2.Value(key))
|
|
})
|
|
|
|
w := httptest.NewRecorder()
|
|
r, _ := http.NewRequest(http.MethodGet, "/", nil)
|
|
handler(w, r)
|
|
}
|
|
|
|
func TestWriteContent(t *testing.T) {
|
|
app := newTestRouter()
|
|
|
|
b := []byte("Test Byte Data")
|
|
|
|
app.Resource("/content").Get("test", "Test",
|
|
NewResponse(200, "desc").Model(Response{}),
|
|
).Run(func(ctx Context) {
|
|
ctx.Header().Set("Content-Type", "application/octet-stream")
|
|
content := bytes.NewReader(b)
|
|
ctx.WriteContent("", content, time.Time{})
|
|
})
|
|
|
|
w := httptest.NewRecorder()
|
|
req, _ := http.NewRequest(http.MethodGet, "/content", nil)
|
|
app.ServeHTTP(w, req)
|
|
|
|
assert.Equal(t, 200, w.Result().StatusCode)
|
|
assert.Equal(t, "application/octet-stream", w.Header().Get("Content-Type"))
|
|
assert.Equal(t, b, w.Body.Bytes())
|
|
}
|
|
|
|
func TestWriteContentRespectsRange(t *testing.T) {
|
|
app := newTestRouter()
|
|
|
|
b := []byte("Test Byte Data")
|
|
|
|
app.Resource("/content").Get("test", "Test",
|
|
NewResponse(206, "desc").Model(Response{}),
|
|
).Run(func(ctx Context) {
|
|
ctx.Header().Set("Content-Type", "application/octet-stream")
|
|
content := bytes.NewReader(b)
|
|
ctx.WriteContent("", content, time.Time{})
|
|
})
|
|
|
|
w := httptest.NewRecorder()
|
|
req, _ := http.NewRequest(http.MethodGet, "/content", nil)
|
|
req.Header.Set("Range", "bytes=0-5")
|
|
app.ServeHTTP(w, req)
|
|
|
|
// confirms that Range is properly being forwarded to ServeContent.
|
|
// we'll assume more advanced range use cases are properly tested
|
|
// in the http library.
|
|
assert.Equal(t, w.Result().StatusCode, 206)
|
|
assert.Equal(t, []byte("Test B"), w.Body.Bytes())
|
|
}
|
|
|
|
func TestWriteContentLastModified(t *testing.T) {
|
|
app := newTestRouter()
|
|
|
|
b := []byte("Test Byte Data")
|
|
modTime := time.Now()
|
|
|
|
app.Resource("/content").Get("test", "Test",
|
|
NewResponse(206, "desc").Model(Response{}),
|
|
).Run(func(ctx Context) {
|
|
ctx.Header().Set("Content-Type", "application/octet-stream")
|
|
content := bytes.NewReader(b)
|
|
ctx.WriteContent("", content, modTime)
|
|
})
|
|
|
|
w := httptest.NewRecorder()
|
|
req, _ := http.NewRequest(http.MethodGet, "/content", nil)
|
|
app.ServeHTTP(w, req)
|
|
|
|
// confirms that modTime is properly being forwarded to ServeContent.
|
|
// We'll assume the more advanced modTime use cases are properly tested
|
|
// in http library.
|
|
strTime := modTime.UTC().Format(http.TimeFormat)
|
|
|
|
assert.Equal(t, strTime, w.Header().Get("Last-Modified"))
|
|
|
|
}
|
|
|
|
func TestWriteContentName(t *testing.T) {
|
|
app := newTestRouter()
|
|
|
|
b := []byte("Test Byte Data")
|
|
|
|
app.Resource("/content").Get("test", "Test",
|
|
NewResponse(206, "desc").Model(Response{}),
|
|
).Run(func(ctx Context) {
|
|
|
|
content := bytes.NewReader(b)
|
|
ctx.WriteContent("/path/with/content.mp4", content, time.Time{})
|
|
})
|
|
|
|
w := httptest.NewRecorder()
|
|
req, _ := http.NewRequest(http.MethodGet, "/content", nil)
|
|
app.ServeHTTP(w, req)
|
|
|
|
// confirms that name is properly being forwarded to ServeContent.
|
|
// We'll assume the more advanced modTime use cases are properly tested
|
|
// in http library.
|
|
assert.Equal(t, "video/mp4", w.Header().Get("Content-Type"))
|
|
}
|