Merge pull request #4 from weiser/add-configurable-cors-middleware

Add configurable CORS middleware
This commit is contained in:
Daniel G. Taylor 2020-06-15 15:16:02 -07:00 committed by GitHub
commit bf8461db67
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 63 additions and 1 deletions

View file

@ -615,6 +615,21 @@ g.NoRoute(huma.Handler404())
r := huma.NewRouter("My API", "1.0.0", huma.WithGin(g))
```
## Custom CORS Handler
If you would like CORS preflight requests to allow specific headers, do the following:
```go
// CORS: Allow non-standard headers "Authorization" and "X-My-Header" in preflight requests
cfg := cors.DefaultConfig()
cfg.AllowAllOrigins = true
cfg.AllowHeaders = append(cfg.AllowHeaders, "Authorization", "X-My-Header")
// And manual settings:
r := huma.NewRouter("My API", "1.0.0", huma.CORSHandler(cors.New(cfg)))
```
## Custom HTTP Server
You can have full control over the `http.Server` that is created.

View file

@ -375,6 +375,15 @@ func DocsHandler(f Handler) RouterOption {
}}
}
// CORSHandler sets the CORS handler function. This can be used to set custom
// domains, headers, auth, etc. If not given, then a default CORS handler is
// used instead.
func CORSHandler(f Handler) RouterOption {
return &routerOption{func(r *Router) {
r.corsHandler = f
}}
}
// OpenAPIHook registers a function to be called after the OpenAPI spec is
// generated but before being sent to the client.
func OpenAPIHook(f func(*gabs.Container)) RouterOption {

View file

@ -318,6 +318,7 @@ type Router struct {
root *cobra.Command
prestart []func()
docsHandler Handler
corsHandler Handler
// Tracks the currently running server for graceful shutdown.
server *http.Server
@ -335,7 +336,6 @@ func NewRouter(docs, version string, options ...RouterOption) *Router {
g := gin.New()
g.Use(Recovery())
g.Use(LogMiddleware())
g.Use(cors.Default())
g.Use(PreferMinimalMiddleware())
g.Use(ServiceLinkMiddleware())
g.NoRoute(Handler404())
@ -357,6 +357,7 @@ func NewRouter(docs, version string, options ...RouterOption) *Router {
engine: g,
prestart: []func(){},
docsHandler: RapiDocHandler(title),
corsHandler: cors.Default(),
}
r.setupCLI()
@ -366,6 +367,11 @@ func NewRouter(docs, version string, options ...RouterOption) *Router {
option.ApplyRouter(r)
}
// Apply CORS handler *after* options in case a custom Gin Engine is passed.
r.GinEngine().Use(func(c *gin.Context) {
r.corsHandler(c)
})
// Validate the router/API setup.
if err := r.api.validate(); err != nil {
panic(err)

View file

@ -12,6 +12,7 @@ import (
"time"
"github.com/danielgtaylor/huma/schema"
"github.com/gin-contrib/cors"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"go.uber.org/zap"
@ -178,6 +179,37 @@ func TestRouterDefault(t *testing.T) {
_ = NewTestRouter(t)
}
func TestRouterConfigurableCors(t *testing.T) {
cfg := cors.DefaultConfig()
cfg.AllowAllOrigins = true
cfg.AllowHeaders = append(cfg.AllowHeaders, "Authorization", "X-My-Header")
r := NewTestRouter(t, CORSHandler(cors.New(cfg)))
type PongResponse struct {
Value string `json:"value" description:"The echoed back word"`
}
r.Resource("/ping",
ResponseJSON(http.StatusOK, "Successful echo response"),
ResponseError(http.StatusBadRequest, "Invalid input"),
).Get("ping", func() (*PongResponse, *ErrorModel) {
return &PongResponse{Value: "pong"}, nil
})
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodOptions, "/ping", nil)
req.Header.Add("Origin", "blah")
r.ServeHTTP(w, req)
assert.Equal(t, "*", w.Header().Get("Access-Control-Allow-Origin"))
allowedHeaders := w.Header().Get("Access-Control-Allow-Headers")
assert.Equal(t, true, strings.Contains(allowedHeaders, "Authorization"))
assert.Equal(t, true, strings.Contains(allowedHeaders, "X-My-Header"))
}
func TestRouter(t *testing.T) {
type EchoResponse struct {
Value string `json:"value" description:"The echoed back word"`