mirror of
https://github.com/Fishwaldo/huma.git
synced 2025-03-15 19:31:27 +00:00
Merge pull request #4 from weiser/add-configurable-cors-middleware
Add configurable CORS middleware
This commit is contained in:
commit
bf8461db67
4 changed files with 63 additions and 1 deletions
15
README.md
15
README.md
|
@ -615,6 +615,21 @@ g.NoRoute(huma.Handler404())
|
|||
r := huma.NewRouter("My API", "1.0.0", huma.WithGin(g))
|
||||
```
|
||||
|
||||
|
||||
## Custom CORS Handler
|
||||
|
||||
If you would like CORS preflight requests to allow specific headers, do the following:
|
||||
|
||||
```go
|
||||
// CORS: Allow non-standard headers "Authorization" and "X-My-Header" in preflight requests
|
||||
cfg := cors.DefaultConfig()
|
||||
cfg.AllowAllOrigins = true
|
||||
cfg.AllowHeaders = append(cfg.AllowHeaders, "Authorization", "X-My-Header")
|
||||
|
||||
// And manual settings:
|
||||
r := huma.NewRouter("My API", "1.0.0", huma.CORSHandler(cors.New(cfg)))
|
||||
```
|
||||
|
||||
## Custom HTTP Server
|
||||
|
||||
You can have full control over the `http.Server` that is created.
|
||||
|
|
|
@ -375,6 +375,15 @@ func DocsHandler(f Handler) RouterOption {
|
|||
}}
|
||||
}
|
||||
|
||||
// CORSHandler sets the CORS handler function. This can be used to set custom
|
||||
// domains, headers, auth, etc. If not given, then a default CORS handler is
|
||||
// used instead.
|
||||
func CORSHandler(f Handler) RouterOption {
|
||||
return &routerOption{func(r *Router) {
|
||||
r.corsHandler = f
|
||||
}}
|
||||
}
|
||||
|
||||
// OpenAPIHook registers a function to be called after the OpenAPI spec is
|
||||
// generated but before being sent to the client.
|
||||
func OpenAPIHook(f func(*gabs.Container)) RouterOption {
|
||||
|
|
|
@ -318,6 +318,7 @@ type Router struct {
|
|||
root *cobra.Command
|
||||
prestart []func()
|
||||
docsHandler Handler
|
||||
corsHandler Handler
|
||||
|
||||
// Tracks the currently running server for graceful shutdown.
|
||||
server *http.Server
|
||||
|
@ -335,7 +336,6 @@ func NewRouter(docs, version string, options ...RouterOption) *Router {
|
|||
g := gin.New()
|
||||
g.Use(Recovery())
|
||||
g.Use(LogMiddleware())
|
||||
g.Use(cors.Default())
|
||||
g.Use(PreferMinimalMiddleware())
|
||||
g.Use(ServiceLinkMiddleware())
|
||||
g.NoRoute(Handler404())
|
||||
|
@ -357,6 +357,7 @@ func NewRouter(docs, version string, options ...RouterOption) *Router {
|
|||
engine: g,
|
||||
prestart: []func(){},
|
||||
docsHandler: RapiDocHandler(title),
|
||||
corsHandler: cors.Default(),
|
||||
}
|
||||
|
||||
r.setupCLI()
|
||||
|
@ -366,6 +367,11 @@ func NewRouter(docs, version string, options ...RouterOption) *Router {
|
|||
option.ApplyRouter(r)
|
||||
}
|
||||
|
||||
// Apply CORS handler *after* options in case a custom Gin Engine is passed.
|
||||
r.GinEngine().Use(func(c *gin.Context) {
|
||||
r.corsHandler(c)
|
||||
})
|
||||
|
||||
// Validate the router/API setup.
|
||||
if err := r.api.validate(); err != nil {
|
||||
panic(err)
|
||||
|
|
|
@ -12,6 +12,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/danielgtaylor/huma/schema"
|
||||
"github.com/gin-contrib/cors"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.uber.org/zap"
|
||||
|
@ -178,6 +179,37 @@ func TestRouterDefault(t *testing.T) {
|
|||
_ = NewTestRouter(t)
|
||||
}
|
||||
|
||||
func TestRouterConfigurableCors(t *testing.T) {
|
||||
cfg := cors.DefaultConfig()
|
||||
cfg.AllowAllOrigins = true
|
||||
cfg.AllowHeaders = append(cfg.AllowHeaders, "Authorization", "X-My-Header")
|
||||
|
||||
r := NewTestRouter(t, CORSHandler(cors.New(cfg)))
|
||||
|
||||
type PongResponse struct {
|
||||
Value string `json:"value" description:"The echoed back word"`
|
||||
}
|
||||
|
||||
r.Resource("/ping",
|
||||
ResponseJSON(http.StatusOK, "Successful echo response"),
|
||||
ResponseError(http.StatusBadRequest, "Invalid input"),
|
||||
).Get("ping", func() (*PongResponse, *ErrorModel) {
|
||||
|
||||
return &PongResponse{Value: "pong"}, nil
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest(http.MethodOptions, "/ping", nil)
|
||||
req.Header.Add("Origin", "blah")
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, "*", w.Header().Get("Access-Control-Allow-Origin"))
|
||||
allowedHeaders := w.Header().Get("Access-Control-Allow-Headers")
|
||||
assert.Equal(t, true, strings.Contains(allowedHeaders, "Authorization"))
|
||||
assert.Equal(t, true, strings.Contains(allowedHeaders, "X-My-Header"))
|
||||
|
||||
}
|
||||
|
||||
func TestRouter(t *testing.T) {
|
||||
type EchoResponse struct {
|
||||
Value string `json:"value" description:"The echoed back word"`
|
||||
|
|
Loading…
Add table
Reference in a new issue