mirror of
https://github.com/Fishwaldo/auth2.git
synced 2025-06-03 12:21:22 +00:00
Implement Phase 2.1: Authentication Provider Interface
- Define enhanced AuthProvider interface with additional context and result information - Create various credential types for different authentication methods - Implement ProviderManager for managing multiple authentication providers - Build chain-of-responsibility pattern for flexible authentication flows - Add comprehensive unit tests with >80% coverage - Update project plan to mark Phase 2.1 as completed
This commit is contained in:
parent
d6a63c5895
commit
c932a4d001
12 changed files with 1991 additions and 6 deletions
|
@ -26,10 +26,10 @@ This document outlines the step-by-step implementation plan for the Auth2 librar
|
||||||
## Phase 2: Core Authentication Framework
|
## Phase 2: Core Authentication Framework
|
||||||
|
|
||||||
### 2.1 Authentication Provider Interface
|
### 2.1 Authentication Provider Interface
|
||||||
- [ ] Define AuthProvider interface
|
- [x] Define AuthProvider interface
|
||||||
- [ ] Create ProviderManager for managing multiple providers
|
- [x] Create ProviderManager for managing multiple providers
|
||||||
- [ ] Implement provider registration system
|
- [x] Implement provider registration system
|
||||||
- [ ] Build chain-of-responsibility pattern for auth attempts
|
- [x] Build chain-of-responsibility pattern for auth attempts
|
||||||
|
|
||||||
### 2.2 Basic Authentication
|
### 2.2 Basic Authentication
|
||||||
- [ ] Implement username/password provider
|
- [ ] Implement username/password provider
|
||||||
|
|
12
go.mod
12
go.mod
|
@ -3,6 +3,14 @@ module github.com/Fishwaldo/auth2
|
||||||
go 1.24
|
go 1.24
|
||||||
|
|
||||||
require (
|
require (
|
||||||
golang.org/x/crypto v0.38.0 // indirect
|
github.com/google/uuid v1.6.0
|
||||||
golang.org/x/sys v0.33.0 // indirect
|
github.com/stretchr/testify v1.10.0
|
||||||
|
golang.org/x/crypto v0.38.0
|
||||||
|
)
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||||
|
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||||
|
golang.org/x/sys v0.33.0 // indirect
|
||||||
|
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||||
)
|
)
|
||||||
|
|
12
go.sum
12
go.sum
|
@ -1,4 +1,16 @@
|
||||||
|
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||||
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||||
|
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
|
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
||||||
|
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||||
golang.org/x/crypto v0.38.0 h1:jt+WWG8IZlBnVbomuhg2Mdq0+BBQaHbtqHEFEigjUV8=
|
golang.org/x/crypto v0.38.0 h1:jt+WWG8IZlBnVbomuhg2Mdq0+BBQaHbtqHEFEigjUV8=
|
||||||
golang.org/x/crypto v0.38.0/go.mod h1:MvrbAqul58NNYPKnOra203SB9vpuZW0e+RRZV+Ggqjw=
|
golang.org/x/crypto v0.38.0/go.mod h1:MvrbAqul58NNYPKnOra203SB9vpuZW0e+RRZV+Ggqjw=
|
||||||
golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw=
|
golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw=
|
||||||
golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||||
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||||
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
|
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||||
|
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
|
|
259
pkg/auth/auth_chain.go
Normal file
259
pkg/auth/auth_chain.go
Normal file
|
@ -0,0 +1,259 @@
|
||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"github.com/Fishwaldo/auth2/internal/errors"
|
||||||
|
"github.com/Fishwaldo/auth2/pkg/auth/providers"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AuthHandlerFunc defines a function that processes an authentication request
|
||||||
|
type AuthHandlerFunc func(ctx *providers.AuthContext, credentials providers.Credentials, next AuthHandlerFunc) (*providers.AuthResult, error)
|
||||||
|
|
||||||
|
// AuthChain implements a chain of responsibility for authentication
|
||||||
|
type AuthChain struct {
|
||||||
|
manager *Manager
|
||||||
|
handlers []AuthHandlerFunc
|
||||||
|
middlewares []AuthHandlerFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewAuthChain creates a new authentication chain using the provided manager
|
||||||
|
func NewAuthChain(manager *Manager) *AuthChain {
|
||||||
|
return &AuthChain{
|
||||||
|
manager: manager,
|
||||||
|
handlers: make([]AuthHandlerFunc, 0),
|
||||||
|
middlewares: make([]AuthHandlerFunc, 0),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use adds a middleware to the authentication chain
|
||||||
|
// Middlewares are executed in the order they are added before any handlers
|
||||||
|
func (c *AuthChain) Use(middleware AuthHandlerFunc) *AuthChain {
|
||||||
|
c.middlewares = append(c.middlewares, middleware)
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handler adds a handler to the authentication chain
|
||||||
|
// Handlers are executed in the order they are added
|
||||||
|
func (c *AuthChain) Handler(handler AuthHandlerFunc) *AuthChain {
|
||||||
|
c.handlers = append(c.handlers, handler)
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultProviderHandler returns a handler that uses the default provider
|
||||||
|
func (c *AuthChain) DefaultProviderHandler() AuthHandlerFunc {
|
||||||
|
return func(ctx *providers.AuthContext, credentials providers.Credentials, next AuthHandlerFunc) (*providers.AuthResult, error) {
|
||||||
|
// Skip if no default provider is configured
|
||||||
|
if c.manager.Config.DefaultProviderID == "" {
|
||||||
|
return next(ctx, credentials, next)
|
||||||
|
}
|
||||||
|
|
||||||
|
provider, err := c.manager.GetProvider(c.manager.Config.DefaultProviderID)
|
||||||
|
if err != nil {
|
||||||
|
// Skip to the next handler if provider not found
|
||||||
|
return next(ctx, credentials, next)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !provider.Supports(credentials) {
|
||||||
|
// Skip to the next handler if provider doesn't support these credentials
|
||||||
|
return next(ctx, credentials, next)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to authenticate with this provider
|
||||||
|
result, err := provider.Authenticate(ctx, credentials)
|
||||||
|
if err != nil || !result.Success {
|
||||||
|
// Continue to next handler on failure
|
||||||
|
return next(ctx, credentials, next)
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AllProvidersHandler returns a handler that tries all providers
|
||||||
|
func (c *AuthChain) AllProvidersHandler() AuthHandlerFunc {
|
||||||
|
return func(ctx *providers.AuthContext, credentials providers.Credentials, next AuthHandlerFunc) (*providers.AuthResult, error) {
|
||||||
|
providersList := c.manager.GetProviders()
|
||||||
|
|
||||||
|
// Find all providers that support this credential type
|
||||||
|
var supportingProviders []providers.AuthProvider
|
||||||
|
for _, provider := range providersList {
|
||||||
|
if provider.Supports(credentials) {
|
||||||
|
supportingProviders = append(supportingProviders, provider)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(supportingProviders) == 0 {
|
||||||
|
// No providers support these credentials, continue to next handler
|
||||||
|
return next(ctx, credentials, next)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try each provider until one succeeds or all fail
|
||||||
|
var combinedResult *providers.AuthResult
|
||||||
|
|
||||||
|
for _, provider := range supportingProviders {
|
||||||
|
result, err := provider.Authenticate(ctx, credentials)
|
||||||
|
|
||||||
|
// Return immediately on success
|
||||||
|
if err == nil && result.Success {
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Combine results (for MFA requirements, etc.)
|
||||||
|
if combinedResult == nil {
|
||||||
|
combinedResult = result
|
||||||
|
} else if result != nil {
|
||||||
|
// Collect MFA providers across results
|
||||||
|
if result.RequiresMFA && len(result.MFAProviders) > 0 {
|
||||||
|
if combinedResult.MFAProviders == nil {
|
||||||
|
combinedResult.MFAProviders = make([]string, 0)
|
||||||
|
}
|
||||||
|
combinedResult.MFAProviders = append(combinedResult.MFAProviders, result.MFAProviders...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Collect extra data
|
||||||
|
if result.Extra != nil {
|
||||||
|
if combinedResult.Extra == nil {
|
||||||
|
combinedResult.Extra = make(map[string]interface{})
|
||||||
|
}
|
||||||
|
for k, v := range result.Extra {
|
||||||
|
combinedResult.Extra[k] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If all providers failed, try the next handler
|
||||||
|
if combinedResult == nil || !combinedResult.Success {
|
||||||
|
return next(ctx, credentials, next)
|
||||||
|
}
|
||||||
|
|
||||||
|
return combinedResult, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SpecificProviderHandler returns a handler that uses a specific provider
|
||||||
|
func (c *AuthChain) SpecificProviderHandler(providerID string) AuthHandlerFunc {
|
||||||
|
return func(ctx *providers.AuthContext, credentials providers.Credentials, next AuthHandlerFunc) (*providers.AuthResult, error) {
|
||||||
|
provider, err := c.manager.GetProvider(providerID)
|
||||||
|
if err != nil {
|
||||||
|
// Skip to the next handler if provider not found
|
||||||
|
return next(ctx, credentials, next)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !provider.Supports(credentials) {
|
||||||
|
// Skip to the next handler if provider doesn't support these credentials
|
||||||
|
return next(ctx, credentials, next)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to authenticate with this provider
|
||||||
|
result, err := provider.Authenticate(ctx, credentials)
|
||||||
|
if err != nil || !result.Success {
|
||||||
|
// Continue to next handler on failure
|
||||||
|
return next(ctx, credentials, next)
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// endOfChain is the terminating handler that returns a standard error
|
||||||
|
func endOfChain(ctx *providers.AuthContext, credentials providers.Credentials, next AuthHandlerFunc) (*providers.AuthResult, error) {
|
||||||
|
return &providers.AuthResult{
|
||||||
|
Success: false,
|
||||||
|
Error: providers.NewAuthFailedError("no handler succeeded", nil),
|
||||||
|
}, providers.NewAuthFailedError("no handler succeeded", nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Authenticate authenticates a user using the chain of responsibility
|
||||||
|
func (c *AuthChain) Authenticate(ctx context.Context, credentials providers.Credentials) (*providers.AuthResult, error) {
|
||||||
|
// Create authentication context with request ID
|
||||||
|
authCtx := &providers.AuthContext{
|
||||||
|
OriginalContext: ctx,
|
||||||
|
RequestID: uuid.New().String(),
|
||||||
|
RequestMetadata: make(map[string]interface{}),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract client information from context if available
|
||||||
|
if clientIP, ok := ctx.Value("client_ip").(string); ok {
|
||||||
|
authCtx.ClientIP = clientIP
|
||||||
|
}
|
||||||
|
|
||||||
|
if userAgent, ok := ctx.Value("user_agent").(string); ok {
|
||||||
|
authCtx.UserAgent = userAgent
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if we have any handlers
|
||||||
|
if len(c.handlers) == 0 {
|
||||||
|
return nil, errors.WrapError(
|
||||||
|
errors.ErrServiceUnavailable,
|
||||||
|
errors.CodeUnavailable,
|
||||||
|
"no authentication handlers configured",
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Combine middlewares and handlers
|
||||||
|
chain := make([]AuthHandlerFunc, 0, len(c.middlewares)+len(c.handlers))
|
||||||
|
chain = append(chain, c.middlewares...)
|
||||||
|
chain = append(chain, c.handlers...)
|
||||||
|
|
||||||
|
// Build the chain of handlers
|
||||||
|
var next AuthHandlerFunc = endOfChain
|
||||||
|
for i := len(chain) - 1; i >= 0; i-- {
|
||||||
|
currentHandler := chain[i]
|
||||||
|
previousNext := next
|
||||||
|
next = func(currentHandler AuthHandlerFunc, previousNext AuthHandlerFunc) AuthHandlerFunc {
|
||||||
|
return func(ctx *providers.AuthContext, credentials providers.Credentials, _ AuthHandlerFunc) (*providers.AuthResult, error) {
|
||||||
|
return currentHandler(ctx, credentials, previousNext)
|
||||||
|
}
|
||||||
|
}(currentHandler, previousNext)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start the chain
|
||||||
|
return next(authCtx, credentials, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildDefaultChain returns a chain with default handlers
|
||||||
|
func (c *AuthChain) BuildDefaultChain() *AuthChain {
|
||||||
|
return c.Handler(c.DefaultProviderHandler()).Handler(c.AllProvidersHandler())
|
||||||
|
}
|
||||||
|
|
||||||
|
// RateLimitingMiddleware creates a middleware that implements rate limiting
|
||||||
|
func RateLimitingMiddleware(maxAttempts int, lockoutDuration int64) AuthHandlerFunc {
|
||||||
|
// In a real implementation, this would use a proper rate limiter
|
||||||
|
return func(ctx *providers.AuthContext, credentials providers.Credentials, next AuthHandlerFunc) (*providers.AuthResult, error) {
|
||||||
|
// Implement rate limiting logic here
|
||||||
|
// For now, just pass through
|
||||||
|
return next(ctx, credentials, next)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoggingMiddleware creates a middleware that logs authentication attempts
|
||||||
|
func LoggingMiddleware() AuthHandlerFunc {
|
||||||
|
return func(ctx *providers.AuthContext, credentials providers.Credentials, next AuthHandlerFunc) (*providers.AuthResult, error) {
|
||||||
|
// Log the authentication attempt
|
||||||
|
// In a real implementation, this would use a proper logger
|
||||||
|
|
||||||
|
// Call the next handler
|
||||||
|
result, err := next(ctx, credentials, next)
|
||||||
|
|
||||||
|
// Log the result
|
||||||
|
// In a real implementation, this would use a proper logger
|
||||||
|
|
||||||
|
return result, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuditingMiddleware creates a middleware that records audit events
|
||||||
|
func AuditingMiddleware() AuthHandlerFunc {
|
||||||
|
return func(ctx *providers.AuthContext, credentials providers.Credentials, next AuthHandlerFunc) (*providers.AuthResult, error) {
|
||||||
|
// Record the authentication attempt for auditing
|
||||||
|
|
||||||
|
// Call the next handler
|
||||||
|
result, err := next(ctx, credentials, next)
|
||||||
|
|
||||||
|
// Record the result for auditing
|
||||||
|
|
||||||
|
return result, err
|
||||||
|
}
|
||||||
|
}
|
402
pkg/auth/auth_chain_black_test.go
Normal file
402
pkg/auth/auth_chain_black_test.go
Normal file
|
@ -0,0 +1,402 @@
|
||||||
|
package auth_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/Fishwaldo/auth2/internal/errors"
|
||||||
|
"github.com/Fishwaldo/auth2/pkg/auth"
|
||||||
|
"github.com/Fishwaldo/auth2/pkg/auth/providers"
|
||||||
|
"github.com/Fishwaldo/auth2/pkg/auth/test"
|
||||||
|
"github.com/Fishwaldo/auth2/pkg/plugin/metadata"
|
||||||
|
"github.com/Fishwaldo/auth2/pkg/plugin/registry"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAuthChain(t *testing.T) {
|
||||||
|
reg := registry.NewRegistry()
|
||||||
|
config := auth.ManagerConfig{
|
||||||
|
DefaultProviderID: "default",
|
||||||
|
MFARequired: false,
|
||||||
|
MaxLoginAttempts: 5,
|
||||||
|
LockoutDuration: 300,
|
||||||
|
}
|
||||||
|
|
||||||
|
manager := auth.NewManager(reg, config)
|
||||||
|
|
||||||
|
// Create test providers
|
||||||
|
defaultProvider := test.NewMockAuthProvider(metadata.ProviderMetadata{
|
||||||
|
ID: "default",
|
||||||
|
Type: metadata.ProviderTypeAuth,
|
||||||
|
Version: "1.0.0",
|
||||||
|
Name: "Default Provider",
|
||||||
|
Description: "Default provider",
|
||||||
|
Author: "Auth2 Team",
|
||||||
|
})
|
||||||
|
|
||||||
|
defaultProvider.AuthenticateFunc = func(ctx *providers.AuthContext, credentials interface{}) (*providers.AuthResult, error) {
|
||||||
|
return &providers.AuthResult{
|
||||||
|
Success: true,
|
||||||
|
UserID: "default-user",
|
||||||
|
ProviderID: "default",
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
passwordProvider := test.NewMockAuthProvider(metadata.ProviderMetadata{
|
||||||
|
ID: "password",
|
||||||
|
Type: metadata.ProviderTypeAuth,
|
||||||
|
Version: "1.0.0",
|
||||||
|
Name: "Password Provider",
|
||||||
|
Description: "Password-based provider",
|
||||||
|
Author: "Auth2 Team",
|
||||||
|
})
|
||||||
|
|
||||||
|
passwordProvider.AuthenticateFunc = func(ctx *providers.AuthContext, credentials interface{}) (*providers.AuthResult, error) {
|
||||||
|
// Check if credentials are username/password
|
||||||
|
if _, ok := credentials.(providers.UsernamePasswordCredentials); ok {
|
||||||
|
return &providers.AuthResult{
|
||||||
|
Success: true,
|
||||||
|
UserID: "password-user",
|
||||||
|
ProviderID: "password",
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return &providers.AuthResult{
|
||||||
|
Success: false,
|
||||||
|
ProviderID: "password",
|
||||||
|
Error: providers.NewAuthFailedError("unsupported credentials", nil),
|
||||||
|
}, providers.NewAuthFailedError("unsupported credentials", nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
passwordProvider.SupportsFunc = func(credentials interface{}) bool {
|
||||||
|
_, ok := credentials.(providers.UsernamePasswordCredentials)
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenProvider := test.NewMockAuthProvider(metadata.ProviderMetadata{
|
||||||
|
ID: "token",
|
||||||
|
Type: metadata.ProviderTypeAuth,
|
||||||
|
Version: "1.0.0",
|
||||||
|
Name: "Token Provider",
|
||||||
|
Description: "Token-based provider",
|
||||||
|
Author: "Auth2 Team",
|
||||||
|
})
|
||||||
|
|
||||||
|
tokenProvider.AuthenticateFunc = func(ctx *providers.AuthContext, credentials interface{}) (*providers.AuthResult, error) {
|
||||||
|
// Check if credentials are token credentials
|
||||||
|
if _, ok := credentials.(providers.TokenCredentials); ok {
|
||||||
|
return &providers.AuthResult{
|
||||||
|
Success: true,
|
||||||
|
UserID: "token-user",
|
||||||
|
ProviderID: "token",
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return &providers.AuthResult{
|
||||||
|
Success: false,
|
||||||
|
ProviderID: "token",
|
||||||
|
Error: providers.NewAuthFailedError("unsupported credentials", nil),
|
||||||
|
}, providers.NewAuthFailedError("unsupported credentials", nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenProvider.SupportsFunc = func(credentials interface{}) bool {
|
||||||
|
_, ok := credentials.(providers.TokenCredentials)
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register providers
|
||||||
|
err := manager.RegisterProvider(defaultProvider)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
err = manager.RegisterProvider(passwordProvider)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
err = manager.RegisterProvider(tokenProvider)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
t.Run("DefaultProviderHandler", func(t *testing.T) {
|
||||||
|
// Override default provider support for this test
|
||||||
|
prevSupports := defaultProvider.SupportsFunc
|
||||||
|
defaultProvider.SupportsFunc = func(creds interface{}) bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create new chain for this test
|
||||||
|
testChain := auth.NewAuthChain(manager)
|
||||||
|
testChain.Handler(testChain.DefaultProviderHandler())
|
||||||
|
|
||||||
|
credentials := providers.UsernamePasswordCredentials{
|
||||||
|
Username: "testuser",
|
||||||
|
Password: "testpassword",
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := testChain.Authenticate(context.Background(), credentials)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, result)
|
||||||
|
assert.True(t, result.Success)
|
||||||
|
assert.Equal(t, "default-user", result.UserID)
|
||||||
|
assert.Equal(t, "default", result.ProviderID)
|
||||||
|
|
||||||
|
// Restore original supports function
|
||||||
|
defaultProvider.SupportsFunc = prevSupports
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("AllProvidersHandler_Password", func(t *testing.T) {
|
||||||
|
// Override default provider support for this test
|
||||||
|
prevDefaultSupports := defaultProvider.SupportsFunc
|
||||||
|
defaultProvider.SupportsFunc = func(creds interface{}) bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create new chain for this test
|
||||||
|
testChain := auth.NewAuthChain(manager)
|
||||||
|
testChain.Handler(testChain.AllProvidersHandler())
|
||||||
|
|
||||||
|
credentials := providers.UsernamePasswordCredentials{
|
||||||
|
Username: "testuser",
|
||||||
|
Password: "testpassword",
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := testChain.Authenticate(context.Background(), credentials)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, result)
|
||||||
|
assert.True(t, result.Success)
|
||||||
|
assert.Equal(t, "password-user", result.UserID)
|
||||||
|
assert.Equal(t, "password", result.ProviderID)
|
||||||
|
|
||||||
|
// Restore original supports function
|
||||||
|
defaultProvider.SupportsFunc = prevDefaultSupports
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("AllProvidersHandler_Token", func(t *testing.T) {
|
||||||
|
// Override default provider support for this test
|
||||||
|
prevDefaultSupports := defaultProvider.SupportsFunc
|
||||||
|
defaultProvider.SupportsFunc = func(creds interface{}) bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create new chain for this test
|
||||||
|
testChain := auth.NewAuthChain(manager)
|
||||||
|
testChain.Handler(testChain.AllProvidersHandler())
|
||||||
|
|
||||||
|
credentials := providers.TokenCredentials{
|
||||||
|
TokenType: "Bearer",
|
||||||
|
TokenValue: "test-token",
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := testChain.Authenticate(context.Background(), credentials)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, result)
|
||||||
|
assert.True(t, result.Success)
|
||||||
|
assert.Equal(t, "token-user", result.UserID)
|
||||||
|
assert.Equal(t, "token", result.ProviderID)
|
||||||
|
|
||||||
|
// Restore original supports function
|
||||||
|
defaultProvider.SupportsFunc = prevDefaultSupports
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("SpecificProviderHandler", func(t *testing.T) {
|
||||||
|
// Create new chain for this test
|
||||||
|
testChain := auth.NewAuthChain(manager)
|
||||||
|
testChain.Handler(testChain.SpecificProviderHandler("token"))
|
||||||
|
|
||||||
|
credentials := providers.TokenCredentials{
|
||||||
|
TokenType: "Bearer",
|
||||||
|
TokenValue: "test-token",
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := testChain.Authenticate(context.Background(), credentials)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, result)
|
||||||
|
assert.True(t, result.Success)
|
||||||
|
assert.Equal(t, "token-user", result.UserID)
|
||||||
|
assert.Equal(t, "token", result.ProviderID)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("ChainOfResponsibility", func(t *testing.T) {
|
||||||
|
// Configure chain with multiple handlers
|
||||||
|
testChain := auth.NewAuthChain(manager)
|
||||||
|
testChain.Handler(testChain.SpecificProviderHandler("nonexistent")) // This will fail
|
||||||
|
testChain.Handler(testChain.SpecificProviderHandler("token")) // This will succeed for token credentials
|
||||||
|
testChain.Handler(testChain.SpecificProviderHandler("password")) // This won't be reached for token credentials
|
||||||
|
|
||||||
|
// Test with token credentials
|
||||||
|
tokenCreds := providers.TokenCredentials{
|
||||||
|
TokenType: "Bearer",
|
||||||
|
TokenValue: "test-token",
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := testChain.Authenticate(context.Background(), tokenCreds)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, result)
|
||||||
|
assert.True(t, result.Success)
|
||||||
|
assert.Equal(t, "token-user", result.UserID)
|
||||||
|
assert.Equal(t, "token", result.ProviderID)
|
||||||
|
|
||||||
|
// Test with password credentials
|
||||||
|
passwordCreds := providers.UsernamePasswordCredentials{
|
||||||
|
Username: "testuser",
|
||||||
|
Password: "testpassword",
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err = testChain.Authenticate(context.Background(), passwordCreds)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, result)
|
||||||
|
assert.True(t, result.Success)
|
||||||
|
assert.Equal(t, "password-user", result.UserID)
|
||||||
|
assert.Equal(t, "password", result.ProviderID)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("EmptyChain", func(t *testing.T) {
|
||||||
|
emptyChain := auth.NewAuthChain(manager)
|
||||||
|
|
||||||
|
credentials := providers.UsernamePasswordCredentials{
|
||||||
|
Username: "testuser",
|
||||||
|
Password: "testpassword",
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := emptyChain.Authenticate(context.Background(), credentials)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Nil(t, result)
|
||||||
|
assert.True(t, errors.Is(err, errors.ErrServiceUnavailable))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("AllHandlersFail", func(t *testing.T) {
|
||||||
|
// Save original support functions
|
||||||
|
prevDefaultSupports := defaultProvider.SupportsFunc
|
||||||
|
prevPasswordSupports := passwordProvider.SupportsFunc
|
||||||
|
prevTokenSupports := tokenProvider.SupportsFunc
|
||||||
|
|
||||||
|
// Configure providers to not support any credentials
|
||||||
|
defaultProvider.SupportsFunc = func(creds interface{}) bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
passwordProvider.SupportsFunc = func(creds interface{}) bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenProvider.SupportsFunc = func(creds interface{}) bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
testChain := auth.NewAuthChain(manager)
|
||||||
|
testChain.Handler(testChain.DefaultProviderHandler())
|
||||||
|
testChain.Handler(testChain.AllProvidersHandler())
|
||||||
|
|
||||||
|
credentials := providers.SAMLCredentials{
|
||||||
|
SAMLResponse: "test-response",
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := testChain.Authenticate(context.Background(), credentials)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.NotNil(t, result)
|
||||||
|
assert.False(t, result.Success)
|
||||||
|
|
||||||
|
// Restore original support functions
|
||||||
|
defaultProvider.SupportsFunc = prevDefaultSupports
|
||||||
|
passwordProvider.SupportsFunc = prevPasswordSupports
|
||||||
|
tokenProvider.SupportsFunc = prevTokenSupports
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Middleware", func(t *testing.T) {
|
||||||
|
var middlewareCalled bool
|
||||||
|
|
||||||
|
middleware := func(ctx *providers.AuthContext, credentials providers.Credentials, next auth.AuthHandlerFunc) (*providers.AuthResult, error) {
|
||||||
|
middlewareCalled = true
|
||||||
|
return next(ctx, credentials, next)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Override default provider support for this test
|
||||||
|
prevSupports := defaultProvider.SupportsFunc
|
||||||
|
defaultProvider.SupportsFunc = func(creds interface{}) bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
testChain := auth.NewAuthChain(manager)
|
||||||
|
testChain.Use(middleware)
|
||||||
|
testChain.Handler(testChain.DefaultProviderHandler())
|
||||||
|
|
||||||
|
credentials := providers.UsernamePasswordCredentials{
|
||||||
|
Username: "testuser",
|
||||||
|
Password: "testpassword",
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := testChain.Authenticate(context.Background(), credentials)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, result)
|
||||||
|
assert.True(t, result.Success)
|
||||||
|
assert.True(t, middlewareCalled)
|
||||||
|
|
||||||
|
// Restore original supports function
|
||||||
|
defaultProvider.SupportsFunc = prevSupports
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("BuildDefaultChain", func(t *testing.T) {
|
||||||
|
// Override supports functions for this test
|
||||||
|
prevDefaultSupports := defaultProvider.SupportsFunc
|
||||||
|
|
||||||
|
// First test: default provider supports the credentials
|
||||||
|
defaultProvider.SupportsFunc = func(creds interface{}) bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
testChain := auth.NewAuthChain(manager).BuildDefaultChain()
|
||||||
|
|
||||||
|
credentials := providers.UsernamePasswordCredentials{
|
||||||
|
Username: "testuser",
|
||||||
|
Password: "testpassword",
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := testChain.Authenticate(context.Background(), credentials)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, result)
|
||||||
|
assert.True(t, result.Success)
|
||||||
|
assert.Equal(t, "default-user", result.UserID)
|
||||||
|
assert.Equal(t, "default", result.ProviderID)
|
||||||
|
|
||||||
|
// Second test: default provider doesn't support the credentials
|
||||||
|
defaultProvider.SupportsFunc = func(creds interface{}) bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
testChain = auth.NewAuthChain(manager).BuildDefaultChain()
|
||||||
|
|
||||||
|
result, err = testChain.Authenticate(context.Background(), credentials)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, result)
|
||||||
|
assert.True(t, result.Success)
|
||||||
|
assert.Equal(t, "password-user", result.UserID)
|
||||||
|
assert.Equal(t, "password", result.ProviderID)
|
||||||
|
|
||||||
|
// Restore original supports function
|
||||||
|
defaultProvider.SupportsFunc = prevDefaultSupports
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("BuiltInMiddleware", func(t *testing.T) {
|
||||||
|
// Override default provider support for this test
|
||||||
|
prevSupports := defaultProvider.SupportsFunc
|
||||||
|
defaultProvider.SupportsFunc = func(creds interface{}) bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
testChain := auth.NewAuthChain(manager)
|
||||||
|
testChain.Use(auth.LoggingMiddleware())
|
||||||
|
testChain.Use(auth.AuditingMiddleware())
|
||||||
|
testChain.Use(auth.RateLimitingMiddleware(5, 300))
|
||||||
|
testChain.Handler(testChain.DefaultProviderHandler())
|
||||||
|
|
||||||
|
credentials := providers.UsernamePasswordCredentials{
|
||||||
|
Username: "testuser",
|
||||||
|
Password: "testpassword",
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := testChain.Authenticate(context.Background(), credentials)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, result)
|
||||||
|
assert.True(t, result.Success)
|
||||||
|
|
||||||
|
// Restore original supports function
|
||||||
|
defaultProvider.SupportsFunc = prevSupports
|
||||||
|
})
|
||||||
|
}
|
322
pkg/auth/auth_manager.go
Normal file
322
pkg/auth/auth_manager.go
Normal file
|
@ -0,0 +1,322 @@
|
||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/Fishwaldo/auth2/internal/errors"
|
||||||
|
"github.com/Fishwaldo/auth2/pkg/auth/providers"
|
||||||
|
"github.com/Fishwaldo/auth2/pkg/plugin/metadata"
|
||||||
|
"github.com/Fishwaldo/auth2/pkg/plugin/registry"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ManagerConfig contains configuration for the AuthManager
|
||||||
|
type ManagerConfig struct {
|
||||||
|
DefaultProviderID string
|
||||||
|
MFARequired bool
|
||||||
|
MFARequiredForRoles []string
|
||||||
|
SessionDuration int64
|
||||||
|
TokenExpiration int64
|
||||||
|
MaxLoginAttempts int
|
||||||
|
LockoutDuration int64
|
||||||
|
PasswordPolicyEnabled bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// Manager manages authentication providers and handles authentication flows
|
||||||
|
type Manager struct {
|
||||||
|
// Config is the configuration for the manager
|
||||||
|
Config ManagerConfig
|
||||||
|
|
||||||
|
// Providers is a map of provider ID to provider
|
||||||
|
providers map[string]providers.AuthProvider
|
||||||
|
|
||||||
|
// Registry is the provider registry
|
||||||
|
registry *registry.Registry
|
||||||
|
|
||||||
|
// RWMutex for thread safety
|
||||||
|
mu sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewManager creates a new AuthManager with the provided registry and configuration
|
||||||
|
func NewManager(reg *registry.Registry, config ManagerConfig) *Manager {
|
||||||
|
return &Manager{
|
||||||
|
Config: config,
|
||||||
|
providers: make(map[string]providers.AuthProvider),
|
||||||
|
registry: reg,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterProvider registers an auth provider with the manager
|
||||||
|
func (m *Manager) RegisterProvider(provider providers.AuthProvider) error {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
meta := provider.GetMetadata()
|
||||||
|
|
||||||
|
// Check if provider is already registered
|
||||||
|
if _, exists := m.providers[meta.ID]; exists {
|
||||||
|
return errors.NewPluginError(
|
||||||
|
errors.ErrIncompatiblePlugin,
|
||||||
|
string(meta.Type),
|
||||||
|
meta.ID,
|
||||||
|
"provider already registered",
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register with the global registry if provided
|
||||||
|
if m.registry != nil {
|
||||||
|
if err := m.registry.RegisterProvider(provider); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register with the local provider map
|
||||||
|
m.providers[meta.ID] = provider
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetProvider returns a provider by ID
|
||||||
|
func (m *Manager) GetProvider(providerID string) (providers.AuthProvider, error) {
|
||||||
|
m.mu.RLock()
|
||||||
|
defer m.mu.RUnlock()
|
||||||
|
|
||||||
|
provider, exists := m.providers[providerID]
|
||||||
|
if !exists {
|
||||||
|
return nil, errors.NewPluginError(
|
||||||
|
errors.ErrPluginNotFound,
|
||||||
|
string(metadata.ProviderTypeAuth),
|
||||||
|
providerID,
|
||||||
|
"provider not registered",
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
return provider, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetProviders returns all registered providers
|
||||||
|
func (m *Manager) GetProviders() map[string]providers.AuthProvider {
|
||||||
|
m.mu.RLock()
|
||||||
|
defer m.mu.RUnlock()
|
||||||
|
|
||||||
|
// Create a copy to prevent concurrent modification
|
||||||
|
result := make(map[string]providers.AuthProvider, len(m.providers))
|
||||||
|
for id, provider := range m.providers {
|
||||||
|
result[id] = provider
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnregisterProvider removes a provider from the manager
|
||||||
|
func (m *Manager) UnregisterProvider(providerID string) error {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
// Check if provider exists
|
||||||
|
if _, exists := m.providers[providerID]; !exists {
|
||||||
|
return errors.NewPluginError(
|
||||||
|
errors.ErrPluginNotFound,
|
||||||
|
string(metadata.ProviderTypeAuth),
|
||||||
|
providerID,
|
||||||
|
"provider not registered",
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unregister from the global registry if provided
|
||||||
|
if m.registry != nil {
|
||||||
|
if err := m.registry.UnregisterProvider(metadata.ProviderTypeAuth, providerID); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unregister from the local provider map
|
||||||
|
delete(m.providers, providerID)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuthenticateWithCredentials authenticates a user with the provided credentials
|
||||||
|
func (m *Manager) AuthenticateWithCredentials(ctx context.Context, credentials providers.Credentials) (*providers.AuthResult, error) {
|
||||||
|
m.mu.RLock()
|
||||||
|
defer m.mu.RUnlock()
|
||||||
|
|
||||||
|
// Check if we have any providers
|
||||||
|
if len(m.providers) == 0 {
|
||||||
|
return nil, errors.WrapError(
|
||||||
|
errors.ErrServiceUnavailable,
|
||||||
|
errors.CodeUnavailable,
|
||||||
|
"no authentication providers registered",
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create authentication context with request ID
|
||||||
|
authCtx := &providers.AuthContext{
|
||||||
|
OriginalContext: ctx,
|
||||||
|
RequestID: uuid.New().String(),
|
||||||
|
RequestMetadata: make(map[string]interface{}),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract client information from context if available
|
||||||
|
if clientIP, ok := ctx.Value("client_ip").(string); ok {
|
||||||
|
authCtx.ClientIP = clientIP
|
||||||
|
}
|
||||||
|
|
||||||
|
if userAgent, ok := ctx.Value("user_agent").(string); ok {
|
||||||
|
authCtx.UserAgent = userAgent
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find all providers that support this credential type
|
||||||
|
var supportingProviders []providers.AuthProvider
|
||||||
|
for _, provider := range m.providers {
|
||||||
|
if provider.Supports(credentials) {
|
||||||
|
supportingProviders = append(supportingProviders, provider)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(supportingProviders) == 0 {
|
||||||
|
return nil, errors.WrapError(
|
||||||
|
errors.ErrUnsupported,
|
||||||
|
errors.CodeUnsupported,
|
||||||
|
fmt.Sprintf("no provider supports credentials of type %s", credentials.GetType()),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try each provider until one succeeds or all fail
|
||||||
|
var lastError error
|
||||||
|
var combinedResult *providers.AuthResult
|
||||||
|
|
||||||
|
for _, provider := range supportingProviders {
|
||||||
|
result, err := provider.Authenticate(authCtx, credentials)
|
||||||
|
|
||||||
|
// Return immediately on success
|
||||||
|
if err == nil && result.Success {
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store last error for context
|
||||||
|
lastError = err
|
||||||
|
|
||||||
|
// Combine results (for MFA requirements, etc.)
|
||||||
|
if combinedResult == nil {
|
||||||
|
combinedResult = result
|
||||||
|
} else if result != nil {
|
||||||
|
// Collect MFA providers across results
|
||||||
|
if result.RequiresMFA && len(result.MFAProviders) > 0 {
|
||||||
|
if combinedResult.MFAProviders == nil {
|
||||||
|
combinedResult.MFAProviders = make([]string, 0)
|
||||||
|
}
|
||||||
|
combinedResult.MFAProviders = append(combinedResult.MFAProviders, result.MFAProviders...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Collect extra data
|
||||||
|
if result.Extra != nil {
|
||||||
|
if combinedResult.Extra == nil {
|
||||||
|
combinedResult.Extra = make(map[string]interface{})
|
||||||
|
}
|
||||||
|
for k, v := range result.Extra {
|
||||||
|
combinedResult.Extra[k] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If all providers failed, return the combined result with the last error
|
||||||
|
if combinedResult == nil {
|
||||||
|
combinedResult = &providers.AuthResult{
|
||||||
|
Success: false,
|
||||||
|
Error: lastError,
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
combinedResult.Error = lastError
|
||||||
|
}
|
||||||
|
|
||||||
|
return combinedResult, lastError
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuthenticateWithProviderID authenticates a user with a specific provider
|
||||||
|
func (m *Manager) AuthenticateWithProviderID(ctx context.Context, providerID string, credentials providers.Credentials) (*providers.AuthResult, error) {
|
||||||
|
provider, err := m.GetProvider(providerID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if !provider.Supports(credentials) {
|
||||||
|
return nil, errors.WrapError(
|
||||||
|
errors.ErrUnsupported,
|
||||||
|
errors.CodeUnsupported,
|
||||||
|
fmt.Sprintf("provider %s does not support credentials of type %s",
|
||||||
|
providerID, credentials.GetType()),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create authentication context with request ID
|
||||||
|
authCtx := &providers.AuthContext{
|
||||||
|
OriginalContext: ctx,
|
||||||
|
RequestID: uuid.New().String(),
|
||||||
|
RequestMetadata: make(map[string]interface{}),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract client information from context if available
|
||||||
|
if clientIP, ok := ctx.Value("client_ip").(string); ok {
|
||||||
|
authCtx.ClientIP = clientIP
|
||||||
|
}
|
||||||
|
|
||||||
|
if userAgent, ok := ctx.Value("user_agent").(string); ok {
|
||||||
|
authCtx.UserAgent = userAgent
|
||||||
|
}
|
||||||
|
|
||||||
|
return provider.Authenticate(authCtx, credentials)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateConfig validates the manager configuration
|
||||||
|
func (m *Manager) ValidateConfig() error {
|
||||||
|
if m.Config.DefaultProviderID != "" {
|
||||||
|
if _, err := m.GetProvider(m.Config.DefaultProviderID); err != nil {
|
||||||
|
return errors.WrapError(
|
||||||
|
err,
|
||||||
|
errors.CodeConfiguration,
|
||||||
|
fmt.Sprintf("default provider %s not registered", m.Config.DefaultProviderID),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize initializes all registered providers
|
||||||
|
func (m *Manager) Initialize(ctx context.Context, configs map[string]interface{}) error {
|
||||||
|
m.mu.RLock()
|
||||||
|
defer m.mu.RUnlock()
|
||||||
|
|
||||||
|
var errs []error
|
||||||
|
|
||||||
|
// Initialize all registered providers
|
||||||
|
for id, provider := range m.providers {
|
||||||
|
// Get provider-specific configuration
|
||||||
|
configKey := fmt.Sprintf("auth.%s", id)
|
||||||
|
config, ok := configs[configKey]
|
||||||
|
if !ok {
|
||||||
|
// Use nil config if not provided
|
||||||
|
config = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize the provider
|
||||||
|
if err := provider.Initialize(ctx, config); err != nil {
|
||||||
|
errs = append(errs, errors.NewPluginError(
|
||||||
|
err,
|
||||||
|
string(metadata.ProviderTypeAuth),
|
||||||
|
id,
|
||||||
|
"provider initialization failed",
|
||||||
|
))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(errs) > 0 {
|
||||||
|
return fmt.Errorf("provider initialization errors: %v", errs)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
335
pkg/auth/auth_manager_black_test.go
Normal file
335
pkg/auth/auth_manager_black_test.go
Normal file
|
@ -0,0 +1,335 @@
|
||||||
|
package auth_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/Fishwaldo/auth2/internal/errors"
|
||||||
|
"github.com/Fishwaldo/auth2/pkg/auth"
|
||||||
|
"github.com/Fishwaldo/auth2/pkg/auth/providers"
|
||||||
|
"github.com/Fishwaldo/auth2/pkg/auth/test"
|
||||||
|
"github.com/Fishwaldo/auth2/pkg/plugin/metadata"
|
||||||
|
"github.com/Fishwaldo/auth2/pkg/plugin/registry"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAuthManager(t *testing.T) {
|
||||||
|
reg := registry.NewRegistry()
|
||||||
|
config := auth.ManagerConfig{
|
||||||
|
DefaultProviderID: "default",
|
||||||
|
MFARequired: false,
|
||||||
|
MaxLoginAttempts: 5,
|
||||||
|
LockoutDuration: 300,
|
||||||
|
}
|
||||||
|
|
||||||
|
manager := auth.NewManager(reg, config)
|
||||||
|
|
||||||
|
// Create test providers
|
||||||
|
successProvider := test.NewMockAuthProvider(metadata.ProviderMetadata{
|
||||||
|
ID: "success",
|
||||||
|
Type: metadata.ProviderTypeAuth,
|
||||||
|
Version: "1.0.0",
|
||||||
|
Name: "Success Provider",
|
||||||
|
Description: "Always succeeds",
|
||||||
|
Author: "Auth2 Team",
|
||||||
|
})
|
||||||
|
|
||||||
|
successProvider.AuthenticateFunc = func(ctx *providers.AuthContext, credentials interface{}) (*providers.AuthResult, error) {
|
||||||
|
return &providers.AuthResult{
|
||||||
|
Success: true,
|
||||||
|
UserID: "user123",
|
||||||
|
ProviderID: "success",
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
failureProvider := test.NewMockAuthProvider(metadata.ProviderMetadata{
|
||||||
|
ID: "failure",
|
||||||
|
Type: metadata.ProviderTypeAuth,
|
||||||
|
Version: "1.0.0",
|
||||||
|
Name: "Failure Provider",
|
||||||
|
Description: "Always fails",
|
||||||
|
Author: "Auth2 Team",
|
||||||
|
})
|
||||||
|
|
||||||
|
failureProvider.AuthenticateFunc = func(ctx *providers.AuthContext, credentials interface{}) (*providers.AuthResult, error) {
|
||||||
|
errResult := providers.NewAuthFailedError("authentication failed", nil)
|
||||||
|
return &providers.AuthResult{
|
||||||
|
Success: false,
|
||||||
|
ProviderID: "failure",
|
||||||
|
Error: errResult,
|
||||||
|
}, errResult
|
||||||
|
}
|
||||||
|
|
||||||
|
mfaProvider := test.NewMockAuthProvider(metadata.ProviderMetadata{
|
||||||
|
ID: "mfa",
|
||||||
|
Type: metadata.ProviderTypeAuth,
|
||||||
|
Version: "1.0.0",
|
||||||
|
Name: "MFA Provider",
|
||||||
|
Description: "Requires MFA",
|
||||||
|
Author: "Auth2 Team",
|
||||||
|
})
|
||||||
|
|
||||||
|
mfaProvider.AuthenticateFunc = func(ctx *providers.AuthContext, credentials interface{}) (*providers.AuthResult, error) {
|
||||||
|
errResult := providers.NewMFARequiredError("user123", []string{"totp", "webauthn"})
|
||||||
|
return &providers.AuthResult{
|
||||||
|
Success: false,
|
||||||
|
UserID: "user123",
|
||||||
|
ProviderID: "mfa",
|
||||||
|
RequiresMFA: true,
|
||||||
|
MFAProviders: []string{"totp", "webauthn"},
|
||||||
|
Error: errResult,
|
||||||
|
}, errResult
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("RegisterProvider", func(t *testing.T) {
|
||||||
|
err := manager.RegisterProvider(successProvider)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
err = manager.RegisterProvider(failureProvider)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
err = manager.RegisterProvider(mfaProvider)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Register same provider again should fail
|
||||||
|
err = manager.RegisterProvider(successProvider)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.True(t, errors.IsPluginError(err))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("GetProvider", func(t *testing.T) {
|
||||||
|
provider, err := manager.GetProvider("success")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "success", provider.GetMetadata().ID)
|
||||||
|
|
||||||
|
provider, err = manager.GetProvider("nonexistent")
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Nil(t, provider)
|
||||||
|
assert.True(t, errors.IsPluginError(err))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("GetProviders", func(t *testing.T) {
|
||||||
|
providers := manager.GetProviders()
|
||||||
|
assert.Len(t, providers, 3)
|
||||||
|
assert.Contains(t, providers, "success")
|
||||||
|
assert.Contains(t, providers, "failure")
|
||||||
|
assert.Contains(t, providers, "mfa")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("AuthenticateWithCredentials_Success", func(t *testing.T) {
|
||||||
|
credentials := providers.UsernamePasswordCredentials{
|
||||||
|
Username: "testuser",
|
||||||
|
Password: "testpassword",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Configure provider to support these credentials
|
||||||
|
successProvider.SupportsFunc = func(creds interface{}) bool {
|
||||||
|
_, ok := creds.(providers.UsernamePasswordCredentials)
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
failureProvider.SupportsFunc = func(creds interface{}) bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
mfaProvider.SupportsFunc = func(creds interface{}) bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := manager.AuthenticateWithCredentials(context.Background(), credentials)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, result)
|
||||||
|
assert.True(t, result.Success)
|
||||||
|
assert.Equal(t, "user123", result.UserID)
|
||||||
|
assert.Equal(t, "success", result.ProviderID)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("AuthenticateWithCredentials_Failure", func(t *testing.T) {
|
||||||
|
credentials := providers.TokenCredentials{
|
||||||
|
TokenType: "Bearer",
|
||||||
|
TokenValue: "invalid-token",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Configure providers to not support these credentials
|
||||||
|
successProvider.SupportsFunc = func(creds interface{}) bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
failureProvider.SupportsFunc = func(creds interface{}) bool {
|
||||||
|
_, ok := creds.(providers.TokenCredentials)
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
mfaProvider.SupportsFunc = func(creds interface{}) bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := manager.AuthenticateWithCredentials(context.Background(), credentials)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.NotNil(t, result)
|
||||||
|
assert.False(t, result.Success)
|
||||||
|
assert.Equal(t, "failure", result.ProviderID)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("AuthenticateWithCredentials_MFA", func(t *testing.T) {
|
||||||
|
credentials := providers.OAuthCredentials{
|
||||||
|
ProviderName: "google",
|
||||||
|
Code: "test-code",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Configure providers to support these credentials
|
||||||
|
successProvider.SupportsFunc = func(creds interface{}) bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
failureProvider.SupportsFunc = func(creds interface{}) bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
mfaProvider.SupportsFunc = func(creds interface{}) bool {
|
||||||
|
_, ok := creds.(providers.OAuthCredentials)
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := manager.AuthenticateWithCredentials(context.Background(), credentials)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.NotNil(t, result)
|
||||||
|
assert.False(t, result.Success)
|
||||||
|
assert.True(t, result.RequiresMFA)
|
||||||
|
assert.Equal(t, "mfa", result.ProviderID)
|
||||||
|
assert.Equal(t, []string{"totp", "webauthn"}, result.MFAProviders)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("AuthenticateWithCredentials_NoProviders", func(t *testing.T) {
|
||||||
|
credentials := providers.SAMLCredentials{
|
||||||
|
SAMLResponse: "test-response",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Configure all providers to not support these credentials
|
||||||
|
successProvider.SupportsFunc = func(creds interface{}) bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
failureProvider.SupportsFunc = func(creds interface{}) bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
mfaProvider.SupportsFunc = func(creds interface{}) bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := manager.AuthenticateWithCredentials(context.Background(), credentials)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.True(t, errors.Is(err, errors.ErrUnsupported))
|
||||||
|
assert.Nil(t, result)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("AuthenticateWithProviderID_Success", func(t *testing.T) {
|
||||||
|
credentials := providers.UsernamePasswordCredentials{
|
||||||
|
Username: "testuser",
|
||||||
|
Password: "testpassword",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Configure provider to support these credentials
|
||||||
|
successProvider.SupportsFunc = func(creds interface{}) bool {
|
||||||
|
_, ok := creds.(providers.UsernamePasswordCredentials)
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := manager.AuthenticateWithProviderID(context.Background(), "success", credentials)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, result)
|
||||||
|
assert.True(t, result.Success)
|
||||||
|
assert.Equal(t, "user123", result.UserID)
|
||||||
|
assert.Equal(t, "success", result.ProviderID)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("AuthenticateWithProviderID_Failure", func(t *testing.T) {
|
||||||
|
credentials := providers.UsernamePasswordCredentials{
|
||||||
|
Username: "testuser",
|
||||||
|
Password: "testpassword",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Configure failure provider to support these credentials
|
||||||
|
failureProvider.SupportsFunc = func(creds interface{}) bool {
|
||||||
|
_, ok := creds.(providers.UsernamePasswordCredentials)
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := manager.AuthenticateWithProviderID(context.Background(), "failure", credentials)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.NotNil(t, result)
|
||||||
|
assert.False(t, result.Success)
|
||||||
|
assert.Equal(t, "failure", result.ProviderID)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("AuthenticateWithProviderID_NotFound", func(t *testing.T) {
|
||||||
|
credentials := providers.UsernamePasswordCredentials{
|
||||||
|
Username: "testuser",
|
||||||
|
Password: "testpassword",
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := manager.AuthenticateWithProviderID(context.Background(), "nonexistent", credentials)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Nil(t, result)
|
||||||
|
assert.True(t, errors.IsPluginError(err))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("AuthenticateWithProviderID_Unsupported", func(t *testing.T) {
|
||||||
|
credentials := providers.UsernamePasswordCredentials{
|
||||||
|
Username: "testuser",
|
||||||
|
Password: "testpassword",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Configure provider to not support these credentials
|
||||||
|
successProvider.SupportsFunc = func(creds interface{}) bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := manager.AuthenticateWithProviderID(context.Background(), "success", credentials)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Nil(t, result)
|
||||||
|
assert.True(t, errors.Is(err, errors.ErrUnsupported))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("UnregisterProvider", func(t *testing.T) {
|
||||||
|
err := manager.UnregisterProvider("success")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify provider was removed
|
||||||
|
provider, err := manager.GetProvider("success")
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Nil(t, provider)
|
||||||
|
|
||||||
|
// Unregister nonexistent provider should fail
|
||||||
|
err = manager.UnregisterProvider("nonexistent")
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.True(t, errors.IsPluginError(err))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("ValidateConfig", func(t *testing.T) {
|
||||||
|
// Current config has default provider "default" which doesn't exist
|
||||||
|
err := manager.ValidateConfig()
|
||||||
|
assert.Error(t, err)
|
||||||
|
|
||||||
|
// Set default provider to a registered provider
|
||||||
|
manager.Config.DefaultProviderID = "failure"
|
||||||
|
err = manager.ValidateConfig()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Initialize", func(t *testing.T) {
|
||||||
|
err := manager.Initialize(context.Background(), map[string]interface{}{
|
||||||
|
"auth.failure": map[string]interface{}{
|
||||||
|
"option": "value",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify Initialize was called on the provider
|
||||||
|
assert.Len(t, failureProvider.InitializeCalls, 1)
|
||||||
|
assert.Equal(t, map[string]interface{}{
|
||||||
|
"option": "value",
|
||||||
|
}, failureProvider.InitializeCalls[0].Config)
|
||||||
|
})
|
||||||
|
}
|
126
pkg/auth/providers/credentials.go
Normal file
126
pkg/auth/providers/credentials.go
Normal file
|
@ -0,0 +1,126 @@
|
||||||
|
package providers
|
||||||
|
|
||||||
|
// CredentialType defines the type of authentication credentials
|
||||||
|
type CredentialType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
// CredentialTypeUsernamePassword represents username/password credentials
|
||||||
|
CredentialTypeUsernamePassword CredentialType = "username_password"
|
||||||
|
|
||||||
|
// CredentialTypeOAuth represents OAuth credentials
|
||||||
|
CredentialTypeOAuth CredentialType = "oauth"
|
||||||
|
|
||||||
|
// CredentialTypeSAML represents SAML credentials
|
||||||
|
CredentialTypeSAML CredentialType = "saml"
|
||||||
|
|
||||||
|
// CredentialTypeWebAuthn represents WebAuthn credentials
|
||||||
|
CredentialTypeWebAuthn CredentialType = "webauthn"
|
||||||
|
|
||||||
|
// CredentialTypeMFA represents MFA verification credentials
|
||||||
|
CredentialTypeMFA CredentialType = "mfa"
|
||||||
|
|
||||||
|
// CredentialTypeSession represents session credentials
|
||||||
|
CredentialTypeSession CredentialType = "session"
|
||||||
|
|
||||||
|
// CredentialTypeToken represents token credentials
|
||||||
|
CredentialTypeToken CredentialType = "token"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Credentials is the base interface for all authentication credentials
|
||||||
|
type Credentials interface {
|
||||||
|
// GetType returns the type of credentials
|
||||||
|
GetType() CredentialType
|
||||||
|
}
|
||||||
|
|
||||||
|
// UsernamePasswordCredentials represents username/password credentials
|
||||||
|
type UsernamePasswordCredentials struct {
|
||||||
|
Username string
|
||||||
|
Password string
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetType returns the type of credentials
|
||||||
|
func (c UsernamePasswordCredentials) GetType() CredentialType {
|
||||||
|
return CredentialTypeUsernamePassword
|
||||||
|
}
|
||||||
|
|
||||||
|
// OAuthCredentials represents OAuth credentials
|
||||||
|
type OAuthCredentials struct {
|
||||||
|
ProviderName string
|
||||||
|
Code string
|
||||||
|
RedirectURI string
|
||||||
|
State string
|
||||||
|
Scope string
|
||||||
|
TokenType string
|
||||||
|
AccessToken string
|
||||||
|
RefreshToken string
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetType returns the type of credentials
|
||||||
|
func (c OAuthCredentials) GetType() CredentialType {
|
||||||
|
return CredentialTypeOAuth
|
||||||
|
}
|
||||||
|
|
||||||
|
// SAMLCredentials represents SAML credentials
|
||||||
|
type SAMLCredentials struct {
|
||||||
|
SAMLResponse string
|
||||||
|
RelayState string
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetType returns the type of credentials
|
||||||
|
func (c SAMLCredentials) GetType() CredentialType {
|
||||||
|
return CredentialTypeSAML
|
||||||
|
}
|
||||||
|
|
||||||
|
// WebAuthnCredentials represents WebAuthn credentials
|
||||||
|
type WebAuthnCredentials struct {
|
||||||
|
CredentialID []byte
|
||||||
|
AuthenticatorData []byte
|
||||||
|
ClientDataJSON []byte
|
||||||
|
Signature []byte
|
||||||
|
UserHandle []byte
|
||||||
|
Challenge string
|
||||||
|
RelyingPartyID string
|
||||||
|
UserVerification string
|
||||||
|
Extensions map[string]interface{}
|
||||||
|
RegistrationPhase bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetType returns the type of credentials
|
||||||
|
func (c WebAuthnCredentials) GetType() CredentialType {
|
||||||
|
return CredentialTypeWebAuthn
|
||||||
|
}
|
||||||
|
|
||||||
|
// MFACredentials represents MFA verification credentials
|
||||||
|
type MFACredentials struct {
|
||||||
|
UserID string
|
||||||
|
ProviderID string
|
||||||
|
Code string
|
||||||
|
Challenge string
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetType returns the type of credentials
|
||||||
|
func (c MFACredentials) GetType() CredentialType {
|
||||||
|
return CredentialTypeMFA
|
||||||
|
}
|
||||||
|
|
||||||
|
// SessionCredentials represents session-based credentials
|
||||||
|
type SessionCredentials struct {
|
||||||
|
SessionID string
|
||||||
|
Token string
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetType returns the type of credentials
|
||||||
|
func (c SessionCredentials) GetType() CredentialType {
|
||||||
|
return CredentialTypeSession
|
||||||
|
}
|
||||||
|
|
||||||
|
// TokenCredentials represents token-based credentials
|
||||||
|
type TokenCredentials struct {
|
||||||
|
TokenType string
|
||||||
|
TokenValue string
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetType returns the type of credentials
|
||||||
|
func (c TokenCredentials) GetType() CredentialType {
|
||||||
|
return CredentialTypeToken
|
||||||
|
}
|
121
pkg/auth/providers/credentials_black_test.go
Normal file
121
pkg/auth/providers/credentials_black_test.go
Normal file
|
@ -0,0 +1,121 @@
|
||||||
|
package providers_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/Fishwaldo/auth2/pkg/auth/providers"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCredentialTypes(t *testing.T) {
|
||||||
|
t.Run("UsernamePasswordCredentials", func(t *testing.T) {
|
||||||
|
creds := providers.UsernamePasswordCredentials{
|
||||||
|
Username: "testuser",
|
||||||
|
Password: "testpassword",
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, providers.CredentialTypeUsernamePassword, creds.GetType())
|
||||||
|
assert.Equal(t, "testuser", creds.Username)
|
||||||
|
assert.Equal(t, "testpassword", creds.Password)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("OAuthCredentials", func(t *testing.T) {
|
||||||
|
creds := providers.OAuthCredentials{
|
||||||
|
ProviderName: "google",
|
||||||
|
Code: "test-code",
|
||||||
|
RedirectURI: "https://example.com/callback",
|
||||||
|
State: "test-state",
|
||||||
|
Scope: "email profile",
|
||||||
|
TokenType: "Bearer",
|
||||||
|
AccessToken: "test-access-token",
|
||||||
|
RefreshToken: "test-refresh-token",
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, providers.CredentialTypeOAuth, creds.GetType())
|
||||||
|
assert.Equal(t, "google", creds.ProviderName)
|
||||||
|
assert.Equal(t, "test-code", creds.Code)
|
||||||
|
assert.Equal(t, "https://example.com/callback", creds.RedirectURI)
|
||||||
|
assert.Equal(t, "test-state", creds.State)
|
||||||
|
assert.Equal(t, "email profile", creds.Scope)
|
||||||
|
assert.Equal(t, "Bearer", creds.TokenType)
|
||||||
|
assert.Equal(t, "test-access-token", creds.AccessToken)
|
||||||
|
assert.Equal(t, "test-refresh-token", creds.RefreshToken)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("SAMLCredentials", func(t *testing.T) {
|
||||||
|
creds := providers.SAMLCredentials{
|
||||||
|
SAMLResponse: "test-saml-response",
|
||||||
|
RelayState: "test-relay-state",
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, providers.CredentialTypeSAML, creds.GetType())
|
||||||
|
assert.Equal(t, "test-saml-response", creds.SAMLResponse)
|
||||||
|
assert.Equal(t, "test-relay-state", creds.RelayState)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("WebAuthnCredentials", func(t *testing.T) {
|
||||||
|
creds := providers.WebAuthnCredentials{
|
||||||
|
CredentialID: []byte("test-credential-id"),
|
||||||
|
AuthenticatorData: []byte("test-authenticator-data"),
|
||||||
|
ClientDataJSON: []byte("test-client-data-json"),
|
||||||
|
Signature: []byte("test-signature"),
|
||||||
|
UserHandle: []byte("test-user-handle"),
|
||||||
|
Challenge: "test-challenge",
|
||||||
|
RelyingPartyID: "example.com",
|
||||||
|
UserVerification: "required",
|
||||||
|
Extensions: map[string]interface{}{
|
||||||
|
"test-extension": "test-value",
|
||||||
|
},
|
||||||
|
RegistrationPhase: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, providers.CredentialTypeWebAuthn, creds.GetType())
|
||||||
|
assert.Equal(t, []byte("test-credential-id"), creds.CredentialID)
|
||||||
|
assert.Equal(t, []byte("test-authenticator-data"), creds.AuthenticatorData)
|
||||||
|
assert.Equal(t, []byte("test-client-data-json"), creds.ClientDataJSON)
|
||||||
|
assert.Equal(t, []byte("test-signature"), creds.Signature)
|
||||||
|
assert.Equal(t, []byte("test-user-handle"), creds.UserHandle)
|
||||||
|
assert.Equal(t, "test-challenge", creds.Challenge)
|
||||||
|
assert.Equal(t, "example.com", creds.RelyingPartyID)
|
||||||
|
assert.Equal(t, "required", creds.UserVerification)
|
||||||
|
assert.Equal(t, "test-value", creds.Extensions["test-extension"])
|
||||||
|
assert.True(t, creds.RegistrationPhase)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("MFACredentials", func(t *testing.T) {
|
||||||
|
creds := providers.MFACredentials{
|
||||||
|
UserID: "test-user-id",
|
||||||
|
ProviderID: "totp",
|
||||||
|
Code: "123456",
|
||||||
|
Challenge: "test-challenge",
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, providers.CredentialTypeMFA, creds.GetType())
|
||||||
|
assert.Equal(t, "test-user-id", creds.UserID)
|
||||||
|
assert.Equal(t, "totp", creds.ProviderID)
|
||||||
|
assert.Equal(t, "123456", creds.Code)
|
||||||
|
assert.Equal(t, "test-challenge", creds.Challenge)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("SessionCredentials", func(t *testing.T) {
|
||||||
|
creds := providers.SessionCredentials{
|
||||||
|
SessionID: "test-session-id",
|
||||||
|
Token: "test-token",
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, providers.CredentialTypeSession, creds.GetType())
|
||||||
|
assert.Equal(t, "test-session-id", creds.SessionID)
|
||||||
|
assert.Equal(t, "test-token", creds.Token)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("TokenCredentials", func(t *testing.T) {
|
||||||
|
creds := providers.TokenCredentials{
|
||||||
|
TokenType: "Bearer",
|
||||||
|
TokenValue: "test-token-value",
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, providers.CredentialTypeToken, creds.GetType())
|
||||||
|
assert.Equal(t, "Bearer", creds.TokenType)
|
||||||
|
assert.Equal(t, "test-token-value", creds.TokenValue)
|
||||||
|
})
|
||||||
|
}
|
128
pkg/auth/providers/provider.go
Normal file
128
pkg/auth/providers/provider.go
Normal file
|
@ -0,0 +1,128 @@
|
||||||
|
package providers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"github.com/Fishwaldo/auth2/internal/errors"
|
||||||
|
"github.com/Fishwaldo/auth2/pkg/plugin/metadata"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AuthResult contains the result of an authentication attempt
|
||||||
|
type AuthResult struct {
|
||||||
|
// Success indicates whether the authentication was successful
|
||||||
|
Success bool
|
||||||
|
|
||||||
|
// UserID is the authenticated user's ID (if successful)
|
||||||
|
UserID string
|
||||||
|
|
||||||
|
// ProviderID is the ID of the provider that authenticated the user
|
||||||
|
ProviderID string
|
||||||
|
|
||||||
|
// RequiresMFA indicates whether MFA is required to complete authentication
|
||||||
|
RequiresMFA bool
|
||||||
|
|
||||||
|
// MFAProviders is a list of MFA provider IDs that the user has enabled
|
||||||
|
MFAProviders []string
|
||||||
|
|
||||||
|
// Extra contains additional provider-specific data
|
||||||
|
Extra map[string]interface{}
|
||||||
|
|
||||||
|
// Error contains any error that occurred during authentication
|
||||||
|
Error error
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuthContext contains context information for authentication
|
||||||
|
type AuthContext struct {
|
||||||
|
// OriginalContext is the original context passed to the provider
|
||||||
|
OriginalContext context.Context
|
||||||
|
|
||||||
|
// RequestID is a unique identifier for the authentication request
|
||||||
|
RequestID string
|
||||||
|
|
||||||
|
// ClientIP is the IP address of the client making the request
|
||||||
|
ClientIP string
|
||||||
|
|
||||||
|
// UserAgent is the user agent of the client making the request
|
||||||
|
UserAgent string
|
||||||
|
|
||||||
|
// RequestMetadata contains additional request metadata
|
||||||
|
RequestMetadata map[string]interface{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuthProvider defines the interface for authentication providers
|
||||||
|
type AuthProvider interface {
|
||||||
|
metadata.Provider
|
||||||
|
|
||||||
|
// Authenticate verifies user credentials and returns an AuthResult
|
||||||
|
Authenticate(ctx *AuthContext, credentials interface{}) (*AuthResult, error)
|
||||||
|
|
||||||
|
// Supports returns true if this provider supports the given credentials type
|
||||||
|
Supports(credentials interface{}) bool
|
||||||
|
|
||||||
|
// GetMetadata returns provider metadata (already provided by metadata.Provider)
|
||||||
|
// GetMetadata() metadata.ProviderMetadata
|
||||||
|
}
|
||||||
|
|
||||||
|
// BaseAuthProvider provides a base implementation of the AuthProvider interface
|
||||||
|
type BaseAuthProvider struct {
|
||||||
|
*metadata.BaseProvider
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewBaseAuthProvider creates a new BaseAuthProvider
|
||||||
|
func NewBaseAuthProvider(meta metadata.ProviderMetadata) *BaseAuthProvider {
|
||||||
|
return &BaseAuthProvider{
|
||||||
|
BaseProvider: metadata.NewBaseProvider(meta),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Authenticate provides a default implementation that always fails
|
||||||
|
func (p *BaseAuthProvider) Authenticate(ctx *AuthContext, credentials interface{}) (*AuthResult, error) {
|
||||||
|
return &AuthResult{
|
||||||
|
Success: false,
|
||||||
|
ProviderID: p.GetMetadata().ID,
|
||||||
|
Error: metadata.ErrNotImplemented,
|
||||||
|
}, metadata.ErrNotImplemented
|
||||||
|
}
|
||||||
|
|
||||||
|
// Supports provides a default implementation that always returns false
|
||||||
|
func (p *BaseAuthProvider) Supports(credentials interface{}) bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsAuthenticationError checks if an error is an authentication error
|
||||||
|
func IsAuthenticationError(err error) bool {
|
||||||
|
authError := errors.ErrAuthFailed
|
||||||
|
providerError := metadata.ErrNotImplemented
|
||||||
|
return errors.Is(err, authError) || errors.Is(err, providerError)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewAuthFailedError creates a new authentication failed error
|
||||||
|
func NewAuthFailedError(reason string, details map[string]interface{}) error {
|
||||||
|
if details == nil {
|
||||||
|
details = make(map[string]interface{})
|
||||||
|
}
|
||||||
|
details["reason"] = reason
|
||||||
|
|
||||||
|
return errors.WrapError(errors.ErrAuthFailed, errors.CodeAuthFailed, reason).WithDetails(details)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewInvalidCredentialsError creates a new invalid credentials error
|
||||||
|
func NewInvalidCredentialsError(reason string) error {
|
||||||
|
return errors.WrapError(errors.ErrInvalidCredentials, errors.CodeInvalidCredentials, reason)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewUserNotFoundError creates a new user not found error
|
||||||
|
func NewUserNotFoundError(identifier string) error {
|
||||||
|
return errors.WrapError(errors.ErrUserNotFound, errors.CodeUserNotFound,
|
||||||
|
"user not found").WithDetails(map[string]interface{}{
|
||||||
|
"identifier": identifier,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMFARequiredError creates a new MFA required error
|
||||||
|
func NewMFARequiredError(userID string, availableProviders []string) error {
|
||||||
|
return errors.WrapError(errors.ErrMFARequired, errors.CodeMFARequired,
|
||||||
|
"multi-factor authentication required").WithDetails(map[string]interface{}{
|
||||||
|
"user_id": userID,
|
||||||
|
"mfa_providers": availableProviders,
|
||||||
|
})
|
||||||
|
}
|
132
pkg/auth/providers/provider_black_test.go
Normal file
132
pkg/auth/providers/provider_black_test.go
Normal file
|
@ -0,0 +1,132 @@
|
||||||
|
package providers_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/Fishwaldo/auth2/internal/errors"
|
||||||
|
"github.com/Fishwaldo/auth2/pkg/auth/providers"
|
||||||
|
"github.com/Fishwaldo/auth2/pkg/plugin/metadata"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestBaseAuthProvider(t *testing.T) {
|
||||||
|
meta := metadata.ProviderMetadata{
|
||||||
|
ID: "test",
|
||||||
|
Type: metadata.ProviderTypeAuth,
|
||||||
|
Version: "1.0.0",
|
||||||
|
Name: "Test Provider",
|
||||||
|
Description: "Test provider for unit tests",
|
||||||
|
Author: "Auth2 Team",
|
||||||
|
}
|
||||||
|
|
||||||
|
provider := providers.NewBaseAuthProvider(meta)
|
||||||
|
|
||||||
|
t.Run("GetMetadata", func(t *testing.T) {
|
||||||
|
returnedMeta := provider.GetMetadata()
|
||||||
|
assert.Equal(t, meta, returnedMeta)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Authenticate", func(t *testing.T) {
|
||||||
|
ctx := &providers.AuthContext{
|
||||||
|
OriginalContext: context.Background(),
|
||||||
|
RequestID: "test-request",
|
||||||
|
}
|
||||||
|
|
||||||
|
credentials := providers.UsernamePasswordCredentials{
|
||||||
|
Username: "testuser",
|
||||||
|
Password: "testpassword",
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := provider.Authenticate(ctx, credentials)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Equal(t, metadata.ErrNotImplemented, err)
|
||||||
|
assert.NotNil(t, result)
|
||||||
|
assert.False(t, result.Success)
|
||||||
|
assert.Equal(t, "test", result.ProviderID)
|
||||||
|
assert.Equal(t, metadata.ErrNotImplemented, result.Error)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Supports", func(t *testing.T) {
|
||||||
|
credentials := providers.UsernamePasswordCredentials{
|
||||||
|
Username: "testuser",
|
||||||
|
Password: "testpassword",
|
||||||
|
}
|
||||||
|
|
||||||
|
supports := provider.Supports(credentials)
|
||||||
|
assert.False(t, supports)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Initialize", func(t *testing.T) {
|
||||||
|
err := provider.Initialize(context.Background(), nil)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Validate", func(t *testing.T) {
|
||||||
|
err := provider.Validate(context.Background())
|
||||||
|
assert.NoError(t, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthErrorHelpers(t *testing.T) {
|
||||||
|
t.Run("IsAuthenticationError", func(t *testing.T) {
|
||||||
|
assert.True(t, providers.IsAuthenticationError(errors.ErrAuthFailed))
|
||||||
|
assert.True(t, providers.IsAuthenticationError(metadata.ErrNotImplemented))
|
||||||
|
assert.False(t, providers.IsAuthenticationError(errors.ErrNotFound))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("NewAuthFailedError", func(t *testing.T) {
|
||||||
|
err := providers.NewAuthFailedError("test reason", map[string]interface{}{
|
||||||
|
"test_key": "test_value",
|
||||||
|
})
|
||||||
|
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.True(t, errors.Is(err, errors.ErrAuthFailed))
|
||||||
|
|
||||||
|
var stdErr *errors.Error
|
||||||
|
assert.True(t, errors.As(err, &stdErr))
|
||||||
|
assert.Equal(t, errors.CodeAuthFailed, stdErr.ErrorCode)
|
||||||
|
assert.Equal(t, "test reason", stdErr.Message)
|
||||||
|
assert.Equal(t, "test_value", stdErr.Details["test_key"])
|
||||||
|
assert.Equal(t, "test reason", stdErr.Details["reason"])
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("NewInvalidCredentialsError", func(t *testing.T) {
|
||||||
|
err := providers.NewInvalidCredentialsError("test reason")
|
||||||
|
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.True(t, errors.Is(err, errors.ErrInvalidCredentials))
|
||||||
|
|
||||||
|
var stdErr *errors.Error
|
||||||
|
assert.True(t, errors.As(err, &stdErr))
|
||||||
|
assert.Equal(t, errors.CodeInvalidCredentials, stdErr.ErrorCode)
|
||||||
|
assert.Equal(t, "test reason", stdErr.Message)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("NewUserNotFoundError", func(t *testing.T) {
|
||||||
|
err := providers.NewUserNotFoundError("test@example.com")
|
||||||
|
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.True(t, errors.Is(err, errors.ErrUserNotFound))
|
||||||
|
|
||||||
|
var stdErr *errors.Error
|
||||||
|
assert.True(t, errors.As(err, &stdErr))
|
||||||
|
assert.Equal(t, errors.CodeUserNotFound, stdErr.ErrorCode)
|
||||||
|
assert.Equal(t, "user not found", stdErr.Message)
|
||||||
|
assert.Equal(t, "test@example.com", stdErr.Details["identifier"])
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("NewMFARequiredError", func(t *testing.T) {
|
||||||
|
err := providers.NewMFARequiredError("user123", []string{"totp", "webauthn"})
|
||||||
|
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.True(t, errors.Is(err, errors.ErrMFARequired))
|
||||||
|
|
||||||
|
var stdErr *errors.Error
|
||||||
|
assert.True(t, errors.As(err, &stdErr))
|
||||||
|
assert.Equal(t, errors.CodeMFARequired, stdErr.ErrorCode)
|
||||||
|
assert.Equal(t, "multi-factor authentication required", stdErr.Message)
|
||||||
|
assert.Equal(t, "user123", stdErr.Details["user_id"])
|
||||||
|
assert.Equal(t, []string{"totp", "webauthn"}, stdErr.Details["mfa_providers"])
|
||||||
|
})
|
||||||
|
}
|
140
pkg/auth/test/mock_provider.go
Normal file
140
pkg/auth/test/mock_provider.go
Normal file
|
@ -0,0 +1,140 @@
|
||||||
|
package test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"github.com/Fishwaldo/auth2/internal/errors"
|
||||||
|
"github.com/Fishwaldo/auth2/pkg/auth/providers"
|
||||||
|
"github.com/Fishwaldo/auth2/pkg/plugin/metadata"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MockAuthProvider implements the AuthProvider interface for testing
|
||||||
|
type MockAuthProvider struct {
|
||||||
|
*providers.BaseAuthProvider
|
||||||
|
|
||||||
|
// AuthenticateFunc can be set to mock the Authenticate method
|
||||||
|
AuthenticateFunc func(ctx *providers.AuthContext, credentials interface{}) (*providers.AuthResult, error)
|
||||||
|
|
||||||
|
// SupportsFunc can be set to mock the Supports method
|
||||||
|
SupportsFunc func(credentials interface{}) bool
|
||||||
|
|
||||||
|
// InitializeFunc can be set to mock the Initialize method
|
||||||
|
InitializeFunc func(ctx context.Context, config interface{}) error
|
||||||
|
|
||||||
|
// ValidateFunc can be set to mock the Validate method
|
||||||
|
ValidateFunc func(ctx context.Context) error
|
||||||
|
|
||||||
|
// AuthenticateCalls tracks calls to Authenticate
|
||||||
|
AuthenticateCalls []struct {
|
||||||
|
Ctx *providers.AuthContext
|
||||||
|
Credentials interface{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SupportsCalls tracks calls to Supports
|
||||||
|
SupportsCalls []struct {
|
||||||
|
Credentials interface{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// InitializeCalls tracks calls to Initialize
|
||||||
|
InitializeCalls []struct {
|
||||||
|
Ctx context.Context
|
||||||
|
Config interface{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateCalls tracks calls to Validate
|
||||||
|
ValidateCalls []struct {
|
||||||
|
Ctx context.Context
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMockAuthProvider creates a new mock auth provider with the given metadata
|
||||||
|
func NewMockAuthProvider(meta metadata.ProviderMetadata) *MockAuthProvider {
|
||||||
|
return &MockAuthProvider{
|
||||||
|
BaseAuthProvider: providers.NewBaseAuthProvider(meta),
|
||||||
|
AuthenticateCalls: make([]struct {
|
||||||
|
Ctx *providers.AuthContext
|
||||||
|
Credentials interface{}
|
||||||
|
}, 0),
|
||||||
|
SupportsCalls: make([]struct {
|
||||||
|
Credentials interface{}
|
||||||
|
}, 0),
|
||||||
|
InitializeCalls: make([]struct {
|
||||||
|
Ctx context.Context
|
||||||
|
Config interface{}
|
||||||
|
}, 0),
|
||||||
|
ValidateCalls: make([]struct {
|
||||||
|
Ctx context.Context
|
||||||
|
}, 0),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Authenticate mocks the AuthProvider Authenticate method
|
||||||
|
func (m *MockAuthProvider) Authenticate(ctx *providers.AuthContext, credentials interface{}) (*providers.AuthResult, error) {
|
||||||
|
m.AuthenticateCalls = append(m.AuthenticateCalls, struct {
|
||||||
|
Ctx *providers.AuthContext
|
||||||
|
Credentials interface{}
|
||||||
|
}{
|
||||||
|
Ctx: ctx,
|
||||||
|
Credentials: credentials,
|
||||||
|
})
|
||||||
|
|
||||||
|
if m.AuthenticateFunc != nil {
|
||||||
|
return m.AuthenticateFunc(ctx, credentials)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default mock implementation
|
||||||
|
return &providers.AuthResult{
|
||||||
|
Success: false,
|
||||||
|
ProviderID: m.GetMetadata().ID,
|
||||||
|
Error: errors.ErrNotImplemented,
|
||||||
|
}, errors.ErrNotImplemented
|
||||||
|
}
|
||||||
|
|
||||||
|
// Supports mocks the AuthProvider Supports method
|
||||||
|
func (m *MockAuthProvider) Supports(credentials interface{}) bool {
|
||||||
|
m.SupportsCalls = append(m.SupportsCalls, struct {
|
||||||
|
Credentials interface{}
|
||||||
|
}{
|
||||||
|
Credentials: credentials,
|
||||||
|
})
|
||||||
|
|
||||||
|
if m.SupportsFunc != nil {
|
||||||
|
return m.SupportsFunc(credentials)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default mock implementation: support all credentials
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize mocks the Provider Initialize method
|
||||||
|
func (m *MockAuthProvider) Initialize(ctx context.Context, config interface{}) error {
|
||||||
|
m.InitializeCalls = append(m.InitializeCalls, struct {
|
||||||
|
Ctx context.Context
|
||||||
|
Config interface{}
|
||||||
|
}{
|
||||||
|
Ctx: ctx,
|
||||||
|
Config: config,
|
||||||
|
})
|
||||||
|
|
||||||
|
if m.InitializeFunc != nil {
|
||||||
|
return m.InitializeFunc(ctx, config)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default mock implementation
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate mocks the Provider Validate method
|
||||||
|
func (m *MockAuthProvider) Validate(ctx context.Context) error {
|
||||||
|
m.ValidateCalls = append(m.ValidateCalls, struct {
|
||||||
|
Ctx context.Context
|
||||||
|
}{
|
||||||
|
Ctx: ctx,
|
||||||
|
})
|
||||||
|
|
||||||
|
if m.ValidateFunc != nil {
|
||||||
|
return m.ValidateFunc(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default mock implementation
|
||||||
|
return nil
|
||||||
|
}
|
Loading…
Add table
Add a link
Reference in a new issue