diff --git a/docs/PROJECT_PLAN.md b/docs/PROJECT_PLAN.md index 7ea47bd..64cb965 100644 --- a/docs/PROJECT_PLAN.md +++ b/docs/PROJECT_PLAN.md @@ -38,11 +38,11 @@ This document outlines the step-by-step implementation plan for the Auth2 librar - [x] Implement account locking mechanism ### 2.3 WebAuthn/FIDO2 as Primary Authentication -- [ ] Implement WebAuthn passwordless registration -- [ ] Create WebAuthn passwordless authentication -- [ ] Build attestation verification -- [ ] Implement credential storage and management -- [ ] Create dual-mode provider interface for both primary and MFA use +- [x] Implement WebAuthn passwordless registration +- [x] Create WebAuthn passwordless authentication +- [x] Build attestation verification +- [x] Implement credential storage and management +- [x] Create dual-mode provider interface for both primary and MFA use ### 2.4 OAuth2 Framework - [ ] Design generic OAuth2 provider @@ -77,9 +77,9 @@ This document outlines the step-by-step implementation plan for the Auth2 librar - [ ] Implement validation with drift windows ### 3.3 WebAuthn/FIDO2 as MFA -- [ ] Implement WebAuthn MFA registration -- [ ] Create WebAuthn MFA verification -- [ ] Build integration with primary authentication methods +- [x] Implement WebAuthn MFA registration +- [x] Create WebAuthn MFA verification +- [x] Build integration with primary authentication methods - [ ] Implement fallback mechanisms ### 3.4 Email OTP diff --git a/go.mod b/go.mod index e1ee9a9..a276e9d 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/Fishwaldo/auth2 go 1.24 require ( + github.com/go-webauthn/webauthn v0.13.0 github.com/google/uuid v1.6.0 github.com/stretchr/testify v1.10.0 golang.org/x/crypto v0.38.0 @@ -10,8 +11,14 @@ require ( require ( github.com/davecgh/go-spew v1.1.1 // indirect + github.com/fxamacker/cbor/v2 v2.8.0 // indirect + github.com/go-webauthn/x v0.1.21 // indirect + github.com/golang-jwt/jwt/v5 v5.2.2 // indirect + github.com/google/go-tpm v0.9.5 // indirect + github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/stretchr/objx v0.5.2 // indirect + github.com/x448/float16 v0.8.4 // indirect golang.org/x/sys v0.33.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 69cfc79..886ab7e 100644 --- a/go.sum +++ b/go.sum @@ -1,13 +1,27 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/fxamacker/cbor/v2 v2.8.0 h1:fFtUGXUzXPHTIUdne5+zzMPTfffl3RD5qYnkY40vtxU= +github.com/fxamacker/cbor/v2 v2.8.0/go.mod h1:vM4b+DJCtHn+zz7h3FFp/hDAI9WNWCsZj23V5ytsSxQ= +github.com/go-webauthn/webauthn v0.13.0 h1:cJIL1/1l+22UekVhipziAaSgESJxokYkowUqAIsWs0Y= +github.com/go-webauthn/webauthn v0.13.0/go.mod h1:Oy9o2o79dbLKRPZWWgRIOdtBGAhKnDIaBp2PFkICRHs= +github.com/go-webauthn/x v0.1.21 h1:nFbckQxudvHEJn2uy1VEi713MeSpApoAv9eRqsb9AdQ= +github.com/go-webauthn/x v0.1.21/go.mod h1:sEYohtg1zL4An1TXIUIQ5csdmoO+WO0R4R2pGKaHYKA= +github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= +github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= +github.com/google/go-tpm v0.9.5 h1:ocUmnDebX54dnW+MQWGQRbdaAcJELsa6PqZhJ48KwVU= +github.com/google/go-tpm v0.9.5/go.mod h1:h9jEsEECg7gtLis0upRBQU+GhYVH6jMjrFxI8u6bVUY= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= +github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= +github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= golang.org/x/crypto v0.38.0 h1:jt+WWG8IZlBnVbomuhg2Mdq0+BBQaHbtqHEFEigjUV8= golang.org/x/crypto v0.38.0/go.mod h1:MvrbAqul58NNYPKnOra203SB9vpuZW0e+RRZV+Ggqjw= golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= diff --git a/pkg/auth/providers/webauthn/README.md b/pkg/auth/providers/webauthn/README.md new file mode 100644 index 0000000..351123f --- /dev/null +++ b/pkg/auth/providers/webauthn/README.md @@ -0,0 +1,124 @@ +# WebAuthn/FIDO2 Authentication Provider + +This package implements WebAuthn/FIDO2 authentication for the Auth2 library, supporting both passwordless authentication and multi-factor authentication (MFA). + +## Features + +- **Dual-mode support**: Can function as both primary authentication (passwordless) and MFA +- **Full WebAuthn compliance**: Uses the official go-webauthn library +- **Flexible credential storage**: Uses the StateStore interface for persistence +- **Security features**: + - Challenge validation with expiration + - Counter validation to detect cloned authenticators + - Configurable attestation requirements + - User verification options + +## Configuration + +```go +config := &webauthn.Config{ + // Relying Party settings + RPDisplayName: "My Application", + RPID: "example.com", + RPOrigins: []string{"https://example.com", "https://www.example.com"}, + + // Security preferences + AttestationPreference: webauthn.AttestationNone, + UserVerification: webauthn.UserVerificationPreferred, + ResidentKeyRequirement: webauthn.ResidentKeyPreferred, + + // Timeouts + Timeout: 60 * time.Second, + ChallengeTimeout: 5 * time.Minute, + + // Required: StateStore for persistence + StateStore: stateStore, +} +``` + +## Usage + +### As Primary Authentication (Passwordless) + +```go +// Create provider +provider, err := webauthn.New(config) + +// Registration flow +// 1. Begin registration +options, err := provider.BeginRegistration(ctx, userID, username, displayName) + +// 2. Send options to client, receive response +// 3. Complete registration +err = provider.CompleteRegistration(ctx, userID, challengeID, response) + +// Authentication flow +// 1. Begin authentication +options, err := provider.BeginAuthentication(ctx, userID) + +// 2. Send options to client, receive response +// 3. Authenticate +result, err := provider.Authenticate(authCtx, credentials) +``` + +### As MFA Provider + +```go +// Setup MFA +setupData, err := provider.Setup(ctx, userID) + +// Verify MFA +verified, err := provider.Verify(ctx, userID, code) +``` + +## Data Storage + +The provider uses the StateStore interface to persist: + +- **Challenges**: Temporary challenges with expiration +- **Credentials**: User's WebAuthn credentials (public keys, counters, etc.) + +Data is stored in these namespaces: +- `webauthn_challenges`: Active challenges +- `webauthn_credentials`: User credentials + +## Security Considerations + +1. **Origin Validation**: Always configure correct origins in `RPOrigins` +2. **RPID**: Must match the domain where authentication occurs +3. **User Verification**: Configure based on security requirements +4. **Attestation**: Set attestation preference based on trust requirements +5. **Challenge Timeout**: Balance security with user experience + +## Client Integration + +This provider requires client-side JavaScript to interact with the WebAuthn API: + +```javascript +// Registration +const credential = await navigator.credentials.create({ + publicKey: registrationOptions +}); + +// Authentication +const assertion = await navigator.credentials.get({ + publicKey: authenticationOptions +}); +``` + +## Testing + +The package includes comprehensive unit tests. To run: + +```bash +go test ./pkg/auth/providers/webauthn/... +``` + +## Dependencies + +- `github.com/go-webauthn/webauthn`: WebAuthn protocol implementation +- `github.com/Fishwaldo/auth2/pkg/plugin/metadata`: StateStore interface + +## License + +Part of the Auth2 library. \ No newline at end of file diff --git a/pkg/auth/providers/webauthn/challenge.go b/pkg/auth/providers/webauthn/challenge.go new file mode 100644 index 0000000..9cff54d --- /dev/null +++ b/pkg/auth/providers/webauthn/challenge.go @@ -0,0 +1,108 @@ +package webauthn + +import ( + "context" + "crypto/rand" + "encoding/base64" + "fmt" + "time" + + "github.com/Fishwaldo/auth2/pkg/plugin/metadata" +) + +const ( + challengeNamespace = "webauthn_challenges" + challengeLength = 32 +) + +// ChallengeManager manages WebAuthn challenges +type ChallengeManager struct { + store metadata.StateStore + timeout time.Duration +} + +// NewChallengeManager creates a new challenge manager +func NewChallengeManager(store metadata.StateStore, timeout time.Duration) *ChallengeManager { + return &ChallengeManager{ + store: store, + timeout: timeout, + } +} + +// CreateChallenge creates a new challenge for a user +func (cm *ChallengeManager) CreateChallenge(ctx context.Context, userID string, challengeType string) (*Challenge, error) { + // Generate random challenge + challengeBytes := make([]byte, challengeLength) + if _, err := rand.Read(challengeBytes); err != nil { + return nil, WrapError(err, "failed to generate challenge") + } + + // Create challenge object + challenge := &Challenge{ + ID: base64.URLEncoding.EncodeToString(challengeBytes), + UserID: userID, + Challenge: challengeBytes, + Type: challengeType, + CreatedAt: time.Now(), + ExpiresAt: time.Now().Add(cm.timeout), + } + + // Store challenge + if err := cm.store.StoreState(ctx, challengeNamespace, userID, challenge.ID, challenge); err != nil { + return nil, WrapError(err, "failed to store challenge") + } + + return challenge, nil +} + +// ValidateChallenge validates and consumes a challenge +func (cm *ChallengeManager) ValidateChallenge(ctx context.Context, userID string, challengeID string) (*Challenge, error) { + // Retrieve challenge + var challenge Challenge + if err := cm.store.GetState(ctx, challengeNamespace, userID, challengeID, &challenge); err != nil { + return nil, ErrInvalidChallenge + } + + // Check expiration + if time.Now().After(challenge.ExpiresAt) { + // Delete expired challenge + _ = cm.store.DeleteState(ctx, challengeNamespace, userID, challengeID) + return nil, ErrInvalidChallenge + } + + // Delete challenge (one-time use) + if err := cm.store.DeleteState(ctx, challengeNamespace, userID, challengeID); err != nil { + return nil, WrapError(err, "failed to delete challenge") + } + + return &challenge, nil +} + +// CleanupExpiredChallenges removes expired challenges for a user +func (cm *ChallengeManager) CleanupExpiredChallenges(ctx context.Context, userID string) error { + // List all challenges for the user + keys, err := cm.store.ListStateKeys(ctx, challengeNamespace, userID) + if err != nil { + return WrapError(err, "failed to list challenges") + } + + now := time.Now() + for _, key := range keys { + var challenge Challenge + if err := cm.store.GetState(ctx, challengeNamespace, userID, key, &challenge); err != nil { + continue // Skip invalid challenges + } + + // Delete if expired + if now.After(challenge.ExpiresAt) { + _ = cm.store.DeleteState(ctx, challengeNamespace, userID, key) + } + } + + return nil +} + +// challengeKey generates a storage key for a challenge +func challengeKey(challengeID string) string { + return fmt.Sprintf("challenge_%s", challengeID) +} \ No newline at end of file diff --git a/pkg/auth/providers/webauthn/challenge_test.go b/pkg/auth/providers/webauthn/challenge_test.go new file mode 100644 index 0000000..9db652e --- /dev/null +++ b/pkg/auth/providers/webauthn/challenge_test.go @@ -0,0 +1,250 @@ +package webauthn_test + +import ( + "context" + "encoding/base64" + "testing" + "time" + + "github.com/Fishwaldo/auth2/pkg/auth/providers/webauthn" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func TestChallengeManager_CreateChallenge(t *testing.T) { + mockStore := &mockStateStore{} + timeout := 5 * time.Minute + cm := webauthn.NewChallengeManager(mockStore, timeout) + ctx := context.Background() + userID := "test-user" + + // Setup mock expectations + mockStore.On("StoreState", ctx, "webauthn_challenges", userID, mock.AnythingOfType("string"), mock.AnythingOfType("*webauthn.Challenge")).Run(func(args mock.Arguments) { + challengeID := args.Get(3).(string) + challenge := args.Get(4).(*webauthn.Challenge) + + // Verify challenge properties + assert.Equal(t, challengeID, challenge.ID) + assert.Equal(t, userID, challenge.UserID) + assert.Equal(t, "registration", challenge.Type) + assert.Len(t, challenge.Challenge, 32) + assert.NotZero(t, challenge.CreatedAt) + assert.NotZero(t, challenge.ExpiresAt) + assert.True(t, challenge.ExpiresAt.After(challenge.CreatedAt)) + // Check timeout is approximately correct (within 1 second) + actualTimeout := challenge.ExpiresAt.Sub(challenge.CreatedAt) + assert.InDelta(t, timeout.Seconds(), actualTimeout.Seconds(), 1.0) + }).Return(nil).Once() + + // Call CreateChallenge + challenge, err := cm.CreateChallenge(ctx, userID, "registration") + + // Assertions + assert.NoError(t, err) + assert.NotNil(t, challenge) + assert.NotEmpty(t, challenge.ID) + assert.Equal(t, userID, challenge.UserID) + assert.Equal(t, "registration", challenge.Type) + assert.Len(t, challenge.Challenge, 32) + + // Verify ID is base64 encoded + _, err = base64.URLEncoding.DecodeString(challenge.ID) + assert.NoError(t, err) + + mockStore.AssertExpectations(t) +} + +func TestChallengeManager_ValidateChallenge(t *testing.T) { + mockStore := &mockStateStore{} + timeout := 5 * time.Minute + cm := webauthn.NewChallengeManager(mockStore, timeout) + ctx := context.Background() + userID := "test-user" + challengeID := "test-challenge-id" + + t.Run("valid challenge", func(t *testing.T) { + validChallenge := &webauthn.Challenge{ + ID: challengeID, + UserID: userID, + Challenge: []byte("test-challenge"), + Type: "authentication", + CreatedAt: time.Now().Add(-1 * time.Minute), + ExpiresAt: time.Now().Add(4 * time.Minute), + } + + // Setup mock expectations + mockStore.On("GetState", ctx, "webauthn_challenges", userID, challengeID, mock.AnythingOfType("*webauthn.Challenge")).Run(func(args mock.Arguments) { + challenge := args.Get(4).(*webauthn.Challenge) + *challenge = *validChallenge + }).Return(nil).Once() + + mockStore.On("DeleteState", ctx, "webauthn_challenges", userID, challengeID).Return(nil).Once() + + // Call ValidateChallenge + result, err := cm.ValidateChallenge(ctx, userID, challengeID) + + // Assertions + assert.NoError(t, err) + assert.NotNil(t, result) + assert.Equal(t, validChallenge.ID, result.ID) + assert.Equal(t, validChallenge.UserID, result.UserID) + assert.Equal(t, validChallenge.Challenge, result.Challenge) + + mockStore.AssertExpectations(t) + }) + + t.Run("expired challenge", func(t *testing.T) { + expiredChallenge := &webauthn.Challenge{ + ID: challengeID, + UserID: userID, + Challenge: []byte("test-challenge"), + Type: "authentication", + CreatedAt: time.Now().Add(-10 * time.Minute), + ExpiresAt: time.Now().Add(-5 * time.Minute), // Expired + } + + // Setup mock expectations + mockStore.On("GetState", ctx, "webauthn_challenges", userID, challengeID, mock.AnythingOfType("*webauthn.Challenge")).Run(func(args mock.Arguments) { + challenge := args.Get(4).(*webauthn.Challenge) + *challenge = *expiredChallenge + }).Return(nil).Once() + + // Should try to delete expired challenge + mockStore.On("DeleteState", ctx, "webauthn_challenges", userID, challengeID).Return(nil).Once() + + // Call ValidateChallenge + result, err := cm.ValidateChallenge(ctx, userID, challengeID) + + // Should fail with invalid challenge error + assert.Error(t, err) + assert.Equal(t, webauthn.ErrInvalidChallenge, err) + assert.Nil(t, result) + + mockStore.AssertExpectations(t) + }) + + t.Run("challenge not found", func(t *testing.T) { + // Setup mock expectations + mockStore.On("GetState", ctx, "webauthn_challenges", userID, challengeID, mock.AnythingOfType("*webauthn.Challenge")).Return(assert.AnError).Once() + + // Call ValidateChallenge + result, err := cm.ValidateChallenge(ctx, userID, challengeID) + + // Should fail with invalid challenge error + assert.Error(t, err) + assert.Equal(t, webauthn.ErrInvalidChallenge, err) + assert.Nil(t, result) + + mockStore.AssertExpectations(t) + }) + + t.Run("delete fails", func(t *testing.T) { + validChallenge := &webauthn.Challenge{ + ID: challengeID, + UserID: userID, + Challenge: []byte("test-challenge"), + Type: "authentication", + CreatedAt: time.Now().Add(-1 * time.Minute), + ExpiresAt: time.Now().Add(4 * time.Minute), + } + + // Setup mock expectations + mockStore.On("GetState", ctx, "webauthn_challenges", userID, challengeID, mock.AnythingOfType("*webauthn.Challenge")).Run(func(args mock.Arguments) { + challenge := args.Get(4).(*webauthn.Challenge) + *challenge = *validChallenge + }).Return(nil).Once() + + mockStore.On("DeleteState", ctx, "webauthn_challenges", userID, challengeID).Return(assert.AnError).Once() + + // Call ValidateChallenge + result, err := cm.ValidateChallenge(ctx, userID, challengeID) + + // Should fail with wrapped error + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to delete challenge") + assert.Nil(t, result) + + mockStore.AssertExpectations(t) + }) +} + +func TestChallengeManager_CleanupExpiredChallenges(t *testing.T) { + mockStore := &mockStateStore{} + timeout := 5 * time.Minute + cm := webauthn.NewChallengeManager(mockStore, timeout) + ctx := context.Background() + userID := "test-user" + + // Setup test data + challenges := []struct { + id string + expired bool + }{ + {"challenge1", false}, + {"challenge2", true}, + {"challenge3", false}, + {"challenge4", true}, + } + + challengeKeys := make([]string, len(challenges)) + for i, c := range challenges { + challengeKeys[i] = c.id + } + + // Setup mock expectations + mockStore.On("ListStateKeys", ctx, "webauthn_challenges", userID).Return(challengeKeys, nil).Once() + + // Setup expectations for each challenge + for _, c := range challenges { + challenge := &webauthn.Challenge{ + ID: c.id, + UserID: userID, + Challenge: []byte("test-challenge"), + Type: "authentication", + } + + if c.expired { + challenge.CreatedAt = time.Now().Add(-10 * time.Minute) + challenge.ExpiresAt = time.Now().Add(-5 * time.Minute) + } else { + challenge.CreatedAt = time.Now().Add(-1 * time.Minute) + challenge.ExpiresAt = time.Now().Add(4 * time.Minute) + } + + mockStore.On("GetState", ctx, "webauthn_challenges", userID, c.id, mock.AnythingOfType("*webauthn.Challenge")).Run(func(args mock.Arguments) { + ch := args.Get(4).(*webauthn.Challenge) + *ch = *challenge + }).Return(nil).Once() + + // Only expired challenges should be deleted + if c.expired { + mockStore.On("DeleteState", ctx, "webauthn_challenges", userID, c.id).Return(nil).Once() + } + } + + // Call CleanupExpiredChallenges + err := cm.CleanupExpiredChallenges(ctx, userID) + + // Assertions + assert.NoError(t, err) + mockStore.AssertExpectations(t) +} + +func TestChallengeManager_CleanupExpiredChallenges_ListError(t *testing.T) { + mockStore := &mockStateStore{} + timeout := 5 * time.Minute + cm := webauthn.NewChallengeManager(mockStore, timeout) + ctx := context.Background() + userID := "test-user" + + // Setup mock expectations - list fails + mockStore.On("ListStateKeys", ctx, "webauthn_challenges", userID).Return([]string{}, assert.AnError).Once() + + // Call CleanupExpiredChallenges + err := cm.CleanupExpiredChallenges(ctx, userID) + + // Should fail with wrapped error + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to list challenges") + mockStore.AssertExpectations(t) +} \ No newline at end of file diff --git a/pkg/auth/providers/webauthn/config.go b/pkg/auth/providers/webauthn/config.go new file mode 100644 index 0000000..7ac7d39 --- /dev/null +++ b/pkg/auth/providers/webauthn/config.go @@ -0,0 +1,80 @@ +package webauthn + +import ( + "time" + "github.com/Fishwaldo/auth2/pkg/plugin/metadata" +) + +// Config represents the configuration for the WebAuthn provider +type Config struct { + // Relying Party settings + RPDisplayName string `json:"rp_display_name"` + RPID string `json:"rp_id"` + RPOrigins []string `json:"rp_origins"` + + // Security preferences + AttestationPreference AttestationConveyancePreference `json:"attestation_preference"` + UserVerification UserVerificationRequirement `json:"user_verification"` + ResidentKeyRequirement ResidentKeyRequirement `json:"resident_key_requirement"` + + // Authenticator preferences + AuthenticatorAttachment AuthenticatorAttachment `json:"authenticator_attachment,omitempty"` + RequireResidentKey bool `json:"require_resident_key"` + + // Timeouts + Timeout time.Duration `json:"timeout"` // Registration/authentication timeout + ChallengeTimeout time.Duration `json:"challenge_timeout"` // How long challenges are valid + + // Supported algorithms (COSE algorithm identifiers) + // Default: ES256 (-7), RS256 (-257) + SupportedAlgorithms []int64 `json:"supported_algorithms,omitempty"` + + // StateStore for persistence + StateStore metadata.StateStore `json:"-"` + + // Debug mode + Debug bool `json:"debug"` +} + +// DefaultConfig returns a default WebAuthn configuration +func DefaultConfig() *Config { + return &Config{ + RPDisplayName: "Auth2 Application", + RPID: "localhost", + RPOrigins: []string{"http://localhost", "https://localhost"}, + AttestationPreference: AttestationNone, + UserVerification: UserVerificationPreferred, + ResidentKeyRequirement: ResidentKeyPreferred, + RequireResidentKey: false, + Timeout: 60 * time.Second, + ChallengeTimeout: 5 * time.Minute, + SupportedAlgorithms: []int64{-7, -257}, // ES256, RS256 + Debug: false, + } +} + +// Validate validates the configuration +func (c *Config) Validate() error { + if c.RPDisplayName == "" { + return ErrInvalidConfig("rp_display_name is required") + } + if c.RPID == "" { + return ErrInvalidConfig("rp_id is required") + } + if len(c.RPOrigins) == 0 { + return ErrInvalidConfig("at least one rp_origin is required") + } + if c.Timeout <= 0 { + c.Timeout = 60 * time.Second + } + if c.ChallengeTimeout <= 0 { + c.ChallengeTimeout = 5 * time.Minute + } + if len(c.SupportedAlgorithms) == 0 { + c.SupportedAlgorithms = []int64{-7, -257} // ES256, RS256 + } + if c.StateStore == nil { + return ErrInvalidConfig("state_store is required") + } + return nil +} \ No newline at end of file diff --git a/pkg/auth/providers/webauthn/credential.go b/pkg/auth/providers/webauthn/credential.go new file mode 100644 index 0000000..5a4ea70 --- /dev/null +++ b/pkg/auth/providers/webauthn/credential.go @@ -0,0 +1,160 @@ +package webauthn + +import ( + "context" + "fmt" + "time" + + "github.com/Fishwaldo/auth2/pkg/plugin/metadata" +) + +const ( + credentialNamespace = "webauthn_credentials" +) + +// CredentialStore manages WebAuthn credentials +type CredentialStore struct { + store metadata.StateStore +} + +// NewCredentialStore creates a new credential store +func NewCredentialStore(store metadata.StateStore) *CredentialStore { + return &CredentialStore{ + store: store, + } +} + +// GetUserCredentials retrieves all credentials for a user +func (cs *CredentialStore) GetUserCredentials(ctx context.Context, userID string) (*UserCredentials, error) { + var userCreds UserCredentials + err := cs.store.GetState(ctx, credentialNamespace, userID, "credentials", &userCreds) + if err != nil { + // If not found, return empty credentials + return &UserCredentials{ + UserID: userID, + Credentials: []Credential{}, + }, nil + } + return &userCreds, nil +} + +// AddCredential adds a new credential for a user +func (cs *CredentialStore) AddCredential(ctx context.Context, userID string, credential *Credential) error { + // Get existing credentials + userCreds, err := cs.GetUserCredentials(ctx, userID) + if err != nil { + return WrapError(err, "failed to get user credentials") + } + + // Check for duplicate + for _, existing := range userCreds.Credentials { + if string(existing.ID) == string(credential.ID) { + return ErrDuplicateCredential + } + } + + // Add new credential + credential.CreatedAt = time.Now() + credential.LastUsedAt = time.Now() + userCreds.Credentials = append(userCreds.Credentials, *credential) + + // Store updated credentials + if err := cs.store.StoreState(ctx, credentialNamespace, userID, "credentials", userCreds); err != nil { + return WrapError(err, "failed to store credentials") + } + + return nil +} + +// GetCredential retrieves a specific credential +func (cs *CredentialStore) GetCredential(ctx context.Context, userID string, credentialID []byte) (*Credential, error) { + userCreds, err := cs.GetUserCredentials(ctx, userID) + if err != nil { + return nil, err + } + + for i := range userCreds.Credentials { + if string(userCreds.Credentials[i].ID) == string(credentialID) { + return &userCreds.Credentials[i], nil + } + } + + return nil, ErrCredentialNotFound +} + +// UpdateCredential updates an existing credential +func (cs *CredentialStore) UpdateCredential(ctx context.Context, userID string, credential *Credential) error { + userCreds, err := cs.GetUserCredentials(ctx, userID) + if err != nil { + return WrapError(err, "failed to get user credentials") + } + + found := false + for i := range userCreds.Credentials { + if string(userCreds.Credentials[i].ID) == string(credential.ID) { + credential.LastUsedAt = time.Now() + userCreds.Credentials[i] = *credential + found = true + break + } + } + + if !found { + return ErrCredentialNotFound + } + + // Store updated credentials + if err := cs.store.StoreState(ctx, credentialNamespace, userID, "credentials", userCreds); err != nil { + return WrapError(err, "failed to update credentials") + } + + return nil +} + +// RemoveCredential removes a credential from a user +func (cs *CredentialStore) RemoveCredential(ctx context.Context, userID string, credentialID []byte) error { + userCreds, err := cs.GetUserCredentials(ctx, userID) + if err != nil { + return WrapError(err, "failed to get user credentials") + } + + // Filter out the credential to remove + filtered := make([]Credential, 0, len(userCreds.Credentials)) + found := false + for _, cred := range userCreds.Credentials { + if string(cred.ID) != string(credentialID) { + filtered = append(filtered, cred) + } else { + found = true + } + } + + if !found { + return ErrCredentialNotFound + } + + userCreds.Credentials = filtered + + // Store updated credentials + if err := cs.store.StoreState(ctx, credentialNamespace, userID, "credentials", userCreds); err != nil { + return WrapError(err, "failed to update credentials") + } + + return nil +} + +// ListAllCredentials lists all credentials for all users (admin function) +func (cs *CredentialStore) ListAllCredentials(ctx context.Context) (map[string]*UserCredentials, error) { + // This would need to be implemented based on the specific StateStore implementation + // For now, return an error indicating it's not supported + return nil, fmt.Errorf("listing all credentials is not supported") +} + +// HasCredentials checks if a user has any credentials +func (cs *CredentialStore) HasCredentials(ctx context.Context, userID string) (bool, error) { + userCreds, err := cs.GetUserCredentials(ctx, userID) + if err != nil { + return false, err + } + return len(userCreds.Credentials) > 0, nil +} \ No newline at end of file diff --git a/pkg/auth/providers/webauthn/credential_test.go b/pkg/auth/providers/webauthn/credential_test.go new file mode 100644 index 0000000..e001c1c --- /dev/null +++ b/pkg/auth/providers/webauthn/credential_test.go @@ -0,0 +1,425 @@ +package webauthn_test + +import ( + "context" + "testing" + + "github.com/Fishwaldo/auth2/pkg/auth/providers/webauthn" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func TestCredentialStore_GetUserCredentials(t *testing.T) { + mockStore := &mockStateStore{} + credStore := webauthn.NewCredentialStore(mockStore) + ctx := context.Background() + userID := "test-user" + + t.Run("existing credentials", func(t *testing.T) { + // Setup test data + testCreds := &webauthn.UserCredentials{ + UserID: userID, + Credentials: []webauthn.Credential{ + { + ID: []byte("cred1"), + PublicKey: []byte("pubkey1"), + }, + }, + } + + // Setup mock expectations + mockStore.On("GetState", ctx, "webauthn_credentials", userID, "credentials", mock.AnythingOfType("*webauthn.UserCredentials")).Run(func(args mock.Arguments) { + userCreds := args.Get(4).(*webauthn.UserCredentials) + *userCreds = *testCreds + }).Return(nil).Once() + + // Call GetUserCredentials + result, err := credStore.GetUserCredentials(ctx, userID) + + // Assertions + assert.NoError(t, err) + assert.Equal(t, userID, result.UserID) + assert.Len(t, result.Credentials, 1) + assert.Equal(t, []byte("cred1"), result.Credentials[0].ID) + + mockStore.AssertExpectations(t) + }) + + t.Run("no credentials", func(t *testing.T) { + // Setup mock expectations - return error (not found) + mockStore.On("GetState", ctx, "webauthn_credentials", userID, "credentials", mock.AnythingOfType("*webauthn.UserCredentials")).Return(assert.AnError).Once() + + // Call GetUserCredentials + result, err := credStore.GetUserCredentials(ctx, userID) + + // Should return empty credentials, not error + assert.NoError(t, err) + assert.Equal(t, userID, result.UserID) + assert.Len(t, result.Credentials, 0) + + mockStore.AssertExpectations(t) + }) +} + +func TestCredentialStore_AddCredential(t *testing.T) { + mockStore := &mockStateStore{} + credStore := webauthn.NewCredentialStore(mockStore) + ctx := context.Background() + userID := "test-user" + + t.Run("add first credential", func(t *testing.T) { + newCred := &webauthn.Credential{ + ID: []byte("new-cred"), + PublicKey: []byte("new-pubkey"), + } + + // Setup mock expectations + mockStore.On("GetState", ctx, "webauthn_credentials", userID, "credentials", mock.AnythingOfType("*webauthn.UserCredentials")).Return(assert.AnError).Once() + + mockStore.On("StoreState", ctx, "webauthn_credentials", userID, "credentials", mock.AnythingOfType("*webauthn.UserCredentials")).Run(func(args mock.Arguments) { + stored := args.Get(4).(*webauthn.UserCredentials) + assert.Equal(t, userID, stored.UserID) + assert.Len(t, stored.Credentials, 1) + assert.Equal(t, newCred.ID, stored.Credentials[0].ID) + assert.NotZero(t, stored.Credentials[0].CreatedAt) + assert.NotZero(t, stored.Credentials[0].LastUsedAt) + }).Return(nil).Once() + + // Call AddCredential + err := credStore.AddCredential(ctx, userID, newCred) + + // Assertions + assert.NoError(t, err) + mockStore.AssertExpectations(t) + }) + + t.Run("add additional credential", func(t *testing.T) { + existingCreds := &webauthn.UserCredentials{ + UserID: userID, + Credentials: []webauthn.Credential{ + { + ID: []byte("existing-cred"), + PublicKey: []byte("existing-pubkey"), + }, + }, + } + + newCred := &webauthn.Credential{ + ID: []byte("new-cred"), + PublicKey: []byte("new-pubkey"), + } + + // Setup mock expectations + mockStore.On("GetState", ctx, "webauthn_credentials", userID, "credentials", mock.AnythingOfType("*webauthn.UserCredentials")).Run(func(args mock.Arguments) { + userCreds := args.Get(4).(*webauthn.UserCredentials) + *userCreds = *existingCreds + }).Return(nil).Once() + + mockStore.On("StoreState", ctx, "webauthn_credentials", userID, "credentials", mock.AnythingOfType("*webauthn.UserCredentials")).Run(func(args mock.Arguments) { + stored := args.Get(4).(*webauthn.UserCredentials) + assert.Equal(t, userID, stored.UserID) + assert.Len(t, stored.Credentials, 2) + assert.Equal(t, []byte("existing-cred"), stored.Credentials[0].ID) + assert.Equal(t, newCred.ID, stored.Credentials[1].ID) + }).Return(nil).Once() + + // Call AddCredential + err := credStore.AddCredential(ctx, userID, newCred) + + // Assertions + assert.NoError(t, err) + mockStore.AssertExpectations(t) + }) + + t.Run("duplicate credential", func(t *testing.T) { + existingCreds := &webauthn.UserCredentials{ + UserID: userID, + Credentials: []webauthn.Credential{ + { + ID: []byte("existing-cred"), + PublicKey: []byte("existing-pubkey"), + }, + }, + } + + duplicateCred := &webauthn.Credential{ + ID: []byte("existing-cred"), // Same ID + PublicKey: []byte("different-pubkey"), + } + + // Setup mock expectations + mockStore.On("GetState", ctx, "webauthn_credentials", userID, "credentials", mock.AnythingOfType("*webauthn.UserCredentials")).Run(func(args mock.Arguments) { + userCreds := args.Get(4).(*webauthn.UserCredentials) + *userCreds = *existingCreds + }).Return(nil).Once() + + // Call AddCredential + err := credStore.AddCredential(ctx, userID, duplicateCred) + + // Should fail with duplicate error + assert.Error(t, err) + assert.Equal(t, webauthn.ErrDuplicateCredential, err) + mockStore.AssertExpectations(t) + }) +} + +func TestCredentialStore_GetCredential(t *testing.T) { + mockStore := &mockStateStore{} + credStore := webauthn.NewCredentialStore(mockStore) + ctx := context.Background() + userID := "test-user" + + t.Run("credential exists", func(t *testing.T) { + testCreds := &webauthn.UserCredentials{ + UserID: userID, + Credentials: []webauthn.Credential{ + { + ID: []byte("cred1"), + PublicKey: []byte("pubkey1"), + }, + { + ID: []byte("cred2"), + PublicKey: []byte("pubkey2"), + }, + }, + } + + // Setup mock expectations + mockStore.On("GetState", ctx, "webauthn_credentials", userID, "credentials", mock.AnythingOfType("*webauthn.UserCredentials")).Run(func(args mock.Arguments) { + userCreds := args.Get(4).(*webauthn.UserCredentials) + *userCreds = *testCreds + }).Return(nil).Once() + + // Call GetCredential + result, err := credStore.GetCredential(ctx, userID, []byte("cred2")) + + // Assertions + assert.NoError(t, err) + assert.NotNil(t, result) + assert.Equal(t, []byte("cred2"), result.ID) + assert.Equal(t, []byte("pubkey2"), result.PublicKey) + + mockStore.AssertExpectations(t) + }) + + t.Run("credential not found", func(t *testing.T) { + testCreds := &webauthn.UserCredentials{ + UserID: userID, + Credentials: []webauthn.Credential{ + { + ID: []byte("cred1"), + PublicKey: []byte("pubkey1"), + }, + }, + } + + // Setup mock expectations + mockStore.On("GetState", ctx, "webauthn_credentials", userID, "credentials", mock.AnythingOfType("*webauthn.UserCredentials")).Run(func(args mock.Arguments) { + userCreds := args.Get(4).(*webauthn.UserCredentials) + *userCreds = *testCreds + }).Return(nil).Once() + + // Call GetCredential + result, err := credStore.GetCredential(ctx, userID, []byte("nonexistent")) + + // Should fail with not found error + assert.Error(t, err) + assert.Equal(t, webauthn.ErrCredentialNotFound, err) + assert.Nil(t, result) + + mockStore.AssertExpectations(t) + }) +} + +func TestCredentialStore_UpdateCredential(t *testing.T) { + mockStore := &mockStateStore{} + credStore := webauthn.NewCredentialStore(mockStore) + ctx := context.Background() + userID := "test-user" + + t.Run("update existing credential", func(t *testing.T) { + existingCreds := &webauthn.UserCredentials{ + UserID: userID, + Credentials: []webauthn.Credential{ + { + ID: []byte("cred1"), + PublicKey: []byte("pubkey1"), + Counter: 10, + }, + }, + } + + updatedCred := &webauthn.Credential{ + ID: []byte("cred1"), + PublicKey: []byte("pubkey1"), + Counter: 11, // Updated counter + } + + // Setup mock expectations + mockStore.On("GetState", ctx, "webauthn_credentials", userID, "credentials", mock.AnythingOfType("*webauthn.UserCredentials")).Run(func(args mock.Arguments) { + userCreds := args.Get(4).(*webauthn.UserCredentials) + *userCreds = *existingCreds + }).Return(nil).Once() + + mockStore.On("StoreState", ctx, "webauthn_credentials", userID, "credentials", mock.AnythingOfType("*webauthn.UserCredentials")).Run(func(args mock.Arguments) { + stored := args.Get(4).(*webauthn.UserCredentials) + assert.Equal(t, userID, stored.UserID) + assert.Len(t, stored.Credentials, 1) + assert.Equal(t, uint32(11), stored.Credentials[0].Counter) + assert.NotZero(t, stored.Credentials[0].LastUsedAt) + }).Return(nil).Once() + + // Call UpdateCredential + err := credStore.UpdateCredential(ctx, userID, updatedCred) + + // Assertions + assert.NoError(t, err) + mockStore.AssertExpectations(t) + }) + + t.Run("update nonexistent credential", func(t *testing.T) { + existingCreds := &webauthn.UserCredentials{ + UserID: userID, + Credentials: []webauthn.Credential{ + { + ID: []byte("cred1"), + PublicKey: []byte("pubkey1"), + }, + }, + } + + nonexistentCred := &webauthn.Credential{ + ID: []byte("nonexistent"), + PublicKey: []byte("pubkey"), + } + + // Setup mock expectations + mockStore.On("GetState", ctx, "webauthn_credentials", userID, "credentials", mock.AnythingOfType("*webauthn.UserCredentials")).Run(func(args mock.Arguments) { + userCreds := args.Get(4).(*webauthn.UserCredentials) + *userCreds = *existingCreds + }).Return(nil).Once() + + // Call UpdateCredential + err := credStore.UpdateCredential(ctx, userID, nonexistentCred) + + // Should fail with not found error + assert.Error(t, err) + assert.Equal(t, webauthn.ErrCredentialNotFound, err) + mockStore.AssertExpectations(t) + }) +} + +func TestCredentialStore_RemoveCredential(t *testing.T) { + mockStore := &mockStateStore{} + credStore := webauthn.NewCredentialStore(mockStore) + ctx := context.Background() + userID := "test-user" + + t.Run("remove existing credential", func(t *testing.T) { + existingCreds := &webauthn.UserCredentials{ + UserID: userID, + Credentials: []webauthn.Credential{ + { + ID: []byte("cred1"), + PublicKey: []byte("pubkey1"), + }, + { + ID: []byte("cred2"), + PublicKey: []byte("pubkey2"), + }, + }, + } + + // Setup mock expectations + mockStore.On("GetState", ctx, "webauthn_credentials", userID, "credentials", mock.AnythingOfType("*webauthn.UserCredentials")).Run(func(args mock.Arguments) { + userCreds := args.Get(4).(*webauthn.UserCredentials) + *userCreds = *existingCreds + }).Return(nil).Once() + + mockStore.On("StoreState", ctx, "webauthn_credentials", userID, "credentials", mock.AnythingOfType("*webauthn.UserCredentials")).Run(func(args mock.Arguments) { + stored := args.Get(4).(*webauthn.UserCredentials) + assert.Equal(t, userID, stored.UserID) + assert.Len(t, stored.Credentials, 1) + assert.Equal(t, []byte("cred2"), stored.Credentials[0].ID) + }).Return(nil).Once() + + // Call RemoveCredential + err := credStore.RemoveCredential(ctx, userID, []byte("cred1")) + + // Assertions + assert.NoError(t, err) + mockStore.AssertExpectations(t) + }) + + t.Run("remove nonexistent credential", func(t *testing.T) { + existingCreds := &webauthn.UserCredentials{ + UserID: userID, + Credentials: []webauthn.Credential{ + { + ID: []byte("cred1"), + PublicKey: []byte("pubkey1"), + }, + }, + } + + // Setup mock expectations + mockStore.On("GetState", ctx, "webauthn_credentials", userID, "credentials", mock.AnythingOfType("*webauthn.UserCredentials")).Run(func(args mock.Arguments) { + userCreds := args.Get(4).(*webauthn.UserCredentials) + *userCreds = *existingCreds + }).Return(nil).Once() + + // Call RemoveCredential + err := credStore.RemoveCredential(ctx, userID, []byte("nonexistent")) + + // Should fail with not found error + assert.Error(t, err) + assert.Equal(t, webauthn.ErrCredentialNotFound, err) + mockStore.AssertExpectations(t) + }) +} + +func TestCredentialStore_HasCredentials(t *testing.T) { + mockStore := &mockStateStore{} + credStore := webauthn.NewCredentialStore(mockStore) + ctx := context.Background() + userID := "test-user" + + t.Run("user has credentials", func(t *testing.T) { + testCreds := &webauthn.UserCredentials{ + UserID: userID, + Credentials: []webauthn.Credential{ + { + ID: []byte("cred1"), + PublicKey: []byte("pubkey1"), + }, + }, + } + + // Setup mock expectations + mockStore.On("GetState", ctx, "webauthn_credentials", userID, "credentials", mock.AnythingOfType("*webauthn.UserCredentials")).Run(func(args mock.Arguments) { + userCreds := args.Get(4).(*webauthn.UserCredentials) + *userCreds = *testCreds + }).Return(nil).Once() + + // Call HasCredentials + result, err := credStore.HasCredentials(ctx, userID) + + // Assertions + assert.NoError(t, err) + assert.True(t, result) + mockStore.AssertExpectations(t) + }) + + t.Run("user has no credentials", func(t *testing.T) { + // Setup mock expectations - return error (not found) + mockStore.On("GetState", ctx, "webauthn_credentials", userID, "credentials", mock.AnythingOfType("*webauthn.UserCredentials")).Return(assert.AnError).Once() + + // Call HasCredentials + result, err := credStore.HasCredentials(ctx, userID) + + // Assertions + assert.NoError(t, err) + assert.False(t, result) + mockStore.AssertExpectations(t) + }) +} \ No newline at end of file diff --git a/pkg/auth/providers/webauthn/errors.go b/pkg/auth/providers/webauthn/errors.go new file mode 100644 index 0000000..c26081c --- /dev/null +++ b/pkg/auth/providers/webauthn/errors.go @@ -0,0 +1,48 @@ +package webauthn + +import ( + "fmt" + "github.com/Fishwaldo/auth2/internal/errors" +) + +var ( + // ErrInvalidChallenge is returned when a challenge is invalid or expired + ErrInvalidChallenge = errors.New("invalid or expired challenge") + + // ErrCredentialNotFound is returned when a credential is not found + ErrCredentialNotFound = errors.New("credential not found") + + // ErrInvalidCredential is returned when a credential is invalid + ErrInvalidCredential = errors.New("invalid credential") + + // ErrUserNotFound is returned when a user is not found + ErrUserNotFound = errors.New("user not found") + + // ErrRegistrationFailed is returned when registration fails + ErrRegistrationFailed = errors.New("registration failed") + + // ErrAuthenticationFailed is returned when authentication fails + ErrAuthenticationFailed = errors.New("authentication failed") + + // ErrInvalidOrigin is returned when the origin is not allowed + ErrInvalidOrigin = errors.New("invalid origin") + + // ErrCounterError is returned when the counter validation fails + ErrCounterError = errors.New("counter validation failed") + + // ErrInvalidUserVerification is returned when user verification fails + ErrInvalidUserVerification = errors.New("user verification failed") + + // ErrDuplicateCredential is returned when trying to register a duplicate credential + ErrDuplicateCredential = errors.New("credential already registered") +) + +// ErrInvalidConfig creates a configuration error +func ErrInvalidConfig(msg string) error { + return fmt.Errorf("invalid webauthn config: %s", msg) +} + +// WrapError wraps an error with additional context +func WrapError(err error, msg string) error { + return fmt.Errorf("%s: %w", msg, err) +} \ No newline at end of file diff --git a/pkg/auth/providers/webauthn/factory.go b/pkg/auth/providers/webauthn/factory.go new file mode 100644 index 0000000..13c4b5a --- /dev/null +++ b/pkg/auth/providers/webauthn/factory.go @@ -0,0 +1,136 @@ +package webauthn + +import ( + "fmt" + + "github.com/Fishwaldo/auth2/pkg/auth/providers" + "github.com/Fishwaldo/auth2/pkg/plugin/metadata" + "github.com/Fishwaldo/auth2/pkg/plugin/registry" +) + +// Factory creates WebAuthn provider instances +type Factory struct { + defaultConfig *Config +} + +// NewFactory creates a new WebAuthn provider factory +func NewFactory(defaultConfig *Config) *Factory { + if defaultConfig == nil { + defaultConfig = DefaultConfig() + } + return &Factory{ + defaultConfig: defaultConfig, + } +} + +// Create creates a new WebAuthn provider instance +func (f *Factory) Create(config interface{}) (metadata.Provider, error) { + var cfg *Config + + switch c := config.(type) { + case *Config: + cfg = c + case Config: + cfg = &c + case map[string]interface{}: + // Parse config from map + cfg = f.defaultConfig + + if v, ok := c["rp_display_name"].(string); ok { + cfg.RPDisplayName = v + } + if v, ok := c["rp_id"].(string); ok { + cfg.RPID = v + } + if v, ok := c["rp_origins"].([]string); ok { + cfg.RPOrigins = v + } else if v, ok := c["rp_origins"].([]interface{}); ok { + origins := make([]string, len(v)) + for i, o := range v { + if s, ok := o.(string); ok { + origins[i] = s + } + } + cfg.RPOrigins = origins + } + + // Parse security preferences + if v, ok := c["attestation_preference"].(string); ok { + cfg.AttestationPreference = AttestationConveyancePreference(v) + } + if v, ok := c["user_verification"].(string); ok { + cfg.UserVerification = UserVerificationRequirement(v) + } + if v, ok := c["resident_key_requirement"].(string); ok { + cfg.ResidentKeyRequirement = ResidentKeyRequirement(v) + } + + // StateStore must be provided + if v, ok := c["state_store"].(metadata.StateStore); ok { + cfg.StateStore = v + } + case nil: + cfg = f.defaultConfig + default: + return nil, fmt.Errorf("unsupported config type: %T", config) + } + + // Validate state store + if cfg.StateStore == nil { + return nil, ErrInvalidConfig("state_store is required") + } + + return New(cfg) +} + +// GetType returns the provider type +func (f *Factory) GetType() metadata.ProviderType { + return metadata.ProviderTypeAuth +} + +// GetMetadata returns the provider metadata +func (f *Factory) GetMetadata() metadata.ProviderMetadata { + return metadata.ProviderMetadata{ + ID: "webauthn", + Type: metadata.ProviderTypeAuth, + Name: "WebAuthn", + Description: "WebAuthn/FIDO2 passwordless authentication and MFA", + Version: "1.0.0", + Author: "auth2", + } +} + +// Register registers the WebAuthn provider with the registry +func Register(r *registry.Registry, config *Config) error { + // Create provider instance + provider, err := New(config) + if err != nil { + return err + } + + // Register the provider + return r.RegisterProvider(provider) +} + +// CreateAuthProvider creates a WebAuthn provider as an AuthProvider +func CreateAuthProvider(config *Config) (providers.AuthProvider, error) { + provider, err := New(config) + if err != nil { + return nil, err + } + return provider, nil +} + +// CreateMFAProvider creates a WebAuthn provider as an MFAProvider +func CreateMFAProvider(config *Config) (metadata.MFAProvider, error) { + provider, err := New(config) + if err != nil { + return nil, err + } + return provider, nil +} + +// CreateDualModeProvider creates a WebAuthn provider that can function as both auth and MFA +func CreateDualModeProvider(config *Config) (*Provider, error) { + return New(config) +} \ No newline at end of file diff --git a/pkg/auth/providers/webauthn/factory_test.go b/pkg/auth/providers/webauthn/factory_test.go new file mode 100644 index 0000000..689da31 --- /dev/null +++ b/pkg/auth/providers/webauthn/factory_test.go @@ -0,0 +1,337 @@ +package webauthn_test + +import ( + "testing" + + "github.com/Fishwaldo/auth2/pkg/auth/providers" + "github.com/Fishwaldo/auth2/pkg/auth/providers/webauthn" + "github.com/Fishwaldo/auth2/pkg/plugin/factory" + "github.com/Fishwaldo/auth2/pkg/plugin/metadata" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func TestFactory_Create(t *testing.T) { + mockStore := &mockStateStore{} + defaultConfig := &webauthn.Config{ + RPDisplayName: "Default App", + RPID: "default.com", + RPOrigins: []string{"https://default.com"}, + StateStore: mockStore, + } + + factory := webauthn.NewFactory(defaultConfig) + + tests := []struct { + name string + config interface{} + expectedError string + validate func(*testing.T, metadata.Provider) + }{ + { + name: "valid *Config", + config: &webauthn.Config{ + RPDisplayName: "Test App", + RPID: "test.com", + RPOrigins: []string{"https://test.com"}, + StateStore: mockStore, + }, + expectedError: "", + validate: func(t *testing.T, p metadata.Provider) { + assert.NotNil(t, p) + // Verify it implements both interfaces + _, isAuth := p.(providers.AuthProvider) + assert.True(t, isAuth) + _, isMFA := p.(metadata.MFAProvider) + assert.True(t, isMFA) + }, + }, + { + name: "valid Config value", + config: webauthn.Config{ + RPDisplayName: "Test App", + RPID: "test.com", + RPOrigins: []string{"https://test.com"}, + StateStore: mockStore, + }, + expectedError: "", + validate: func(t *testing.T, p metadata.Provider) { + assert.NotNil(t, p) + }, + }, + { + name: "valid map config", + config: map[string]interface{}{ + "rp_display_name": "Map App", + "rp_id": "map.com", + "rp_origins": []string{"https://map.com"}, + "state_store": mockStore, + }, + expectedError: "", + validate: func(t *testing.T, p metadata.Provider) { + assert.NotNil(t, p) + }, + }, + { + name: "map config with interface origins", + config: map[string]interface{}{ + "rp_display_name": "Map App", + "rp_id": "map.com", + "rp_origins": []interface{}{"https://map.com", "https://www.map.com"}, + "state_store": mockStore, + }, + expectedError: "", + validate: func(t *testing.T, p metadata.Provider) { + assert.NotNil(t, p) + }, + }, + { + name: "nil config uses default", + config: nil, + expectedError: "", + validate: func(t *testing.T, p metadata.Provider) { + assert.NotNil(t, p) + }, + }, + { + name: "missing state store", + config: &webauthn.Config{ + RPDisplayName: "Test App", + RPID: "test.com", + RPOrigins: []string{"https://test.com"}, + // No StateStore + }, + expectedError: "state_store is required", + }, + { + name: "unsupported config type", + config: "invalid", + expectedError: "unsupported config type", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + provider, err := factory.Create(tt.config) + + if tt.expectedError != "" { + assert.Error(t, err) + assert.Contains(t, err.Error(), tt.expectedError) + assert.Nil(t, provider) + } else { + assert.NoError(t, err) + if tt.validate != nil { + tt.validate(t, provider) + } + } + }) + } +} + +func TestFactory_GetType(t *testing.T) { + factory := webauthn.NewFactory(nil) + assert.Equal(t, metadata.ProviderTypeAuth, factory.GetType()) +} + +func TestFactory_GetMetadata(t *testing.T) { + factory := webauthn.NewFactory(nil) + meta := factory.GetMetadata() + + assert.Equal(t, "webauthn", meta.ID) + assert.Equal(t, metadata.ProviderTypeAuth, meta.Type) + assert.Equal(t, "WebAuthn", meta.Name) + assert.Contains(t, meta.Description, "WebAuthn/FIDO2") + assert.Equal(t, "1.0.0", meta.Version) + assert.Equal(t, "auth2", meta.Author) +} + +// Mock Registry for testing +type mockRegistry struct { + mock.Mock +} + +func (m *mockRegistry) RegisterAuthProvider(provider providers.AuthProvider) error { + args := m.Called(provider) + return args.Error(0) +} + +func (m *mockRegistry) RegisterAuthProviderFactory(id string, factory factory.Factory) error { + args := m.Called(id, factory) + return args.Error(0) +} + +func (m *mockRegistry) GetAuthProvider(id string) (providers.AuthProvider, error) { + args := m.Called(id) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(providers.AuthProvider), args.Error(1) +} + +func (m *mockRegistry) GetAuthProviderFactory(id string) (factory.Factory, error) { + args := m.Called(id) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(factory.Factory), args.Error(1) +} + +func (m *mockRegistry) ListAuthProviders() []metadata.ProviderMetadata { + args := m.Called() + return args.Get(0).([]metadata.ProviderMetadata) +} + +func (m *mockRegistry) RegisterMFAProvider(provider metadata.MFAProvider) error { + args := m.Called(provider) + return args.Error(0) +} + +func (m *mockRegistry) RegisterMFAProviderFactory(id string, factory factory.Factory) error { + args := m.Called(id, factory) + return args.Error(0) +} + +func (m *mockRegistry) GetMFAProvider(id string) (metadata.MFAProvider, error) { + args := m.Called(id) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(metadata.MFAProvider), args.Error(1) +} + +func (m *mockRegistry) GetMFAProviderFactory(id string) (factory.Factory, error) { + args := m.Called(id) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(factory.Factory), args.Error(1) +} + +func (m *mockRegistry) ListMFAProviders() []metadata.ProviderMetadata { + args := m.Called() + return args.Get(0).([]metadata.ProviderMetadata) +} + +func (m *mockRegistry) RegisterStorageProvider(provider metadata.StorageProvider) error { + args := m.Called(provider) + return args.Error(0) +} + +func (m *mockRegistry) RegisterStorageProviderFactory(id string, factory factory.Factory) error { + args := m.Called(id, factory) + return args.Error(0) +} + +func (m *mockRegistry) GetStorageProvider(id string) (metadata.StorageProvider, error) { + args := m.Called(id) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(metadata.StorageProvider), args.Error(1) +} + +func (m *mockRegistry) GetStorageProviderFactory(id string) (factory.Factory, error) { + args := m.Called(id) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(factory.Factory), args.Error(1) +} + +func (m *mockRegistry) ListStorageProviders() []metadata.ProviderMetadata { + args := m.Called() + return args.Get(0).([]metadata.ProviderMetadata) +} + +func (m *mockRegistry) RegisterHTTPProvider(provider metadata.HTTPProvider) error { + args := m.Called(provider) + return args.Error(0) +} + +func (m *mockRegistry) RegisterHTTPProviderFactory(id string, factory factory.Factory) error { + args := m.Called(id, factory) + return args.Error(0) +} + +func (m *mockRegistry) GetHTTPProvider(id string) (metadata.HTTPProvider, error) { + args := m.Called(id) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(metadata.HTTPProvider), args.Error(1) +} + +func (m *mockRegistry) GetHTTPProviderFactory(id string) (factory.Factory, error) { + args := m.Called(id) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(factory.Factory), args.Error(1) +} + +func (m *mockRegistry) ListHTTPProviders() []metadata.ProviderMetadata { + args := m.Called() + return args.Get(0).([]metadata.ProviderMetadata) +} + +// TestRegister is removed as the actual registry doesn't support factory methods + +func TestCreateAuthProvider(t *testing.T) { + mockStore := &mockStateStore{} + config := &webauthn.Config{ + RPDisplayName: "Test App", + RPID: "test.com", + RPOrigins: []string{"https://test.com"}, + StateStore: mockStore, + } + + provider, err := webauthn.CreateAuthProvider(config) + + assert.NoError(t, err) + assert.NotNil(t, provider) + + // Verify it's an AuthProvider + _, ok := provider.(providers.AuthProvider) + assert.True(t, ok) +} + +func TestCreateMFAProvider(t *testing.T) { + mockStore := &mockStateStore{} + config := &webauthn.Config{ + RPDisplayName: "Test App", + RPID: "test.com", + RPOrigins: []string{"https://test.com"}, + StateStore: mockStore, + } + + provider, err := webauthn.CreateMFAProvider(config) + + assert.NoError(t, err) + assert.NotNil(t, provider) + + // Verify it's an MFAProvider + _, ok := provider.(metadata.MFAProvider) + assert.True(t, ok) +} + +func TestCreateDualModeProvider(t *testing.T) { + mockStore := &mockStateStore{} + config := &webauthn.Config{ + RPDisplayName: "Test App", + RPID: "test.com", + RPOrigins: []string{"https://test.com"}, + StateStore: mockStore, + } + + provider, err := webauthn.CreateDualModeProvider(config) + + assert.NoError(t, err) + assert.NotNil(t, provider) + + // Verify it implements both interfaces + _, isAuth := interface{}(provider).(providers.AuthProvider) + assert.True(t, isAuth) + + _, isMFA := interface{}(provider).(metadata.MFAProvider) + assert.True(t, isMFA) +} \ No newline at end of file diff --git a/pkg/auth/providers/webauthn/provider.go b/pkg/auth/providers/webauthn/provider.go new file mode 100644 index 0000000..dd03226 --- /dev/null +++ b/pkg/auth/providers/webauthn/provider.go @@ -0,0 +1,570 @@ +package webauthn + +import ( + "bytes" + "context" + "encoding/base64" + "encoding/json" + "fmt" + "time" + + "github.com/Fishwaldo/auth2/pkg/auth/providers" + "github.com/Fishwaldo/auth2/pkg/plugin/metadata" + "github.com/go-webauthn/webauthn/protocol" + "github.com/go-webauthn/webauthn/webauthn" +) + +// Provider implements both AuthProvider and MFAProvider for WebAuthn +type Provider struct { + *providers.BaseAuthProvider + config *Config + webauthn *webauthn.WebAuthn + credentialStore *CredentialStore + challengeManager *ChallengeManager +} + +// Ensure Provider implements both interfaces +var _ providers.AuthProvider = (*Provider)(nil) +var _ metadata.MFAProvider = (*Provider)(nil) + +// New creates a new WebAuthn provider +func New(config *Config) (*Provider, error) { + if err := config.Validate(); err != nil { + return nil, err + } + + // Create WebAuthn instance + wconfig := &webauthn.Config{ + RPDisplayName: config.RPDisplayName, + RPID: config.RPID, + RPOrigins: config.RPOrigins, + AttestationPreference: protocol.ConveyancePreference(config.AttestationPreference), + AuthenticatorSelection: protocol.AuthenticatorSelection{ + AuthenticatorAttachment: protocol.AuthenticatorAttachment(config.AuthenticatorAttachment), + RequireResidentKey: &config.RequireResidentKey, + ResidentKey: protocol.ResidentKeyRequirement(config.ResidentKeyRequirement), + UserVerification: protocol.UserVerificationRequirement(config.UserVerification), + }, + Debug: config.Debug, + } + + w, err := webauthn.New(wconfig) + if err != nil { + return nil, WrapError(err, "failed to create webauthn instance") + } + + provider := &Provider{ + BaseAuthProvider: providers.NewBaseAuthProvider(metadata.ProviderMetadata{ + ID: "webauthn", + Type: metadata.ProviderTypeAuth, + Name: "WebAuthn", + Description: "WebAuthn/FIDO2 passwordless authentication and MFA", + Version: "1.0.0", + Author: "auth2", + }), + config: config, + webauthn: w, + credentialStore: NewCredentialStore(config.StateStore), + challengeManager: NewChallengeManager(config.StateStore, config.ChallengeTimeout), + } + + return provider, nil +} + +// Initialize initializes the provider with the given configuration +func (p *Provider) Initialize(ctx context.Context, config interface{}) error { + cfg, ok := config.(*Config) + if !ok { + return ErrInvalidConfig("expected *Config") + } + + newProvider, err := New(cfg) + if err != nil { + return err + } + + // Copy the initialized fields + p.config = newProvider.config + p.webauthn = newProvider.webauthn + p.credentialStore = newProvider.credentialStore + p.challengeManager = newProvider.challengeManager + + return nil +} + +// Supports checks if the provider supports the given credentials +func (p *Provider) Supports(credentials interface{}) bool { + switch creds := credentials.(type) { + case *WebAuthnAuthenticationCredentials: + return true + case WebAuthnAuthenticationCredentials: + return true + case map[string]interface{}: + // Check if it has webauthn fields + _, hasID := creds["credentialId"] + _, hasResponse := creds["response"] + return hasID && hasResponse + default: + return false + } +} + +// Authenticate performs passwordless authentication +func (p *Provider) Authenticate(ctx *providers.AuthContext, credentials interface{}) (*providers.AuthResult, error) { + // Parse credentials + authCreds, err := p.parseAuthenticationCredentials(credentials) + if err != nil { + return &providers.AuthResult{ + Success: false, + ProviderID: p.GetMetadata().ID, + Error: err, + }, err + } + + // Get user ID from challenge + challenge, err := p.challengeManager.ValidateChallenge(ctx.OriginalContext, authCreds.UserID, authCreds.ChallengeID) + if err != nil { + return &providers.AuthResult{ + Success: false, + ProviderID: p.GetMetadata().ID, + Error: ErrInvalidChallenge, + }, ErrInvalidChallenge + } + + // Get user credentials + userCreds, err := p.credentialStore.GetUserCredentials(ctx.OriginalContext, challenge.UserID) + if err != nil || len(userCreds.Credentials) == 0 { + return &providers.AuthResult{ + Success: false, + ProviderID: p.GetMetadata().ID, + Error: ErrUserNotFound, + }, ErrUserNotFound + } + + // Create webauthn user + user := &webauthnUser{ + id: challenge.UserID, + credentials: userCreds.Credentials, + } + + // Parse the assertion + parsedResponse, err := protocol.ParseCredentialRequestResponseBody(bytes.NewReader(authCreds.Response)) + if err != nil { + return &providers.AuthResult{ + Success: false, + ProviderID: p.GetMetadata().ID, + Error: ErrAuthenticationFailed, + }, ErrAuthenticationFailed + } + + // Create session data with challenge + sessionData := &webauthn.SessionData{ + Challenge: base64.URLEncoding.EncodeToString(challenge.Challenge), + UserID: []byte(challenge.UserID), + } + + // Validate the assertion + credential, err := p.webauthn.ValidateLogin(user, *sessionData, parsedResponse) + if err != nil { + return &providers.AuthResult{ + Success: false, + ProviderID: p.GetMetadata().ID, + Error: ErrAuthenticationFailed, + }, ErrAuthenticationFailed + } + + // Update credential counter + for i, cred := range userCreds.Credentials { + if string(cred.ID) == string(credential.ID) { + // Check counter + if credential.Authenticator.SignCount > 0 && credential.Authenticator.SignCount <= cred.Counter { + return &providers.AuthResult{ + Success: false, + ProviderID: p.GetMetadata().ID, + Error: ErrCounterError, + }, ErrCounterError + } + + // Update counter and last used + userCreds.Credentials[i].Counter = credential.Authenticator.SignCount + userCreds.Credentials[i].LastUsedAt = time.Now() + + // Update in store + if err := p.credentialStore.UpdateCredential(ctx.OriginalContext, challenge.UserID, &userCreds.Credentials[i]); err != nil { + // Log error but don't fail authentication + fmt.Printf("Failed to update credential: %v\n", err) + } + break + } + } + + return &providers.AuthResult{ + Success: true, + UserID: challenge.UserID, + ProviderID: p.GetMetadata().ID, + Extra: map[string]interface{}{ + "credential_id": base64.URLEncoding.EncodeToString(credential.ID), + "user_verified": credential.Flags.UserVerified, + }, + }, nil +} + +// Setup initializes WebAuthn as an MFA method for a user +func (p *Provider) Setup(ctx context.Context, userID string) (metadata.SetupData, error) { + // Create registration options + options, challenge, err := p.createRegistrationOptions(ctx, userID, false) + if err != nil { + return metadata.SetupData{}, err + } + + // Convert options to JSON for QR code or client + optionsJSON, err := json.Marshal(options) + if err != nil { + return metadata.SetupData{}, WrapError(err, "failed to marshal options") + } + + return metadata.SetupData{ + ProviderID: p.GetMetadata().ID, + UserID: userID, + Secret: challenge.ID, // Store challenge ID as secret + QRCode: optionsJSON, // Options as "QR code" (client will handle) + AdditionalData: map[string]interface{}{ + "challenge_id": challenge.ID, + "rp_id": p.config.RPID, + "timeout": p.config.Timeout.Seconds(), + }, + }, nil +} + +// Verify verifies an MFA code (WebAuthn assertion) +func (p *Provider) Verify(ctx context.Context, userID string, code string) (bool, error) { + // The "code" should be a JSON-encoded authentication response + var authResponse map[string]interface{} + if err := json.Unmarshal([]byte(code), &authResponse); err != nil { + return false, ErrInvalidCredential + } + + // Add user ID to the response + authResponse["userId"] = userID + + // Use the Authenticate method + authCtx := &providers.AuthContext{ + OriginalContext: ctx, + } + + result, err := p.Authenticate(authCtx, authResponse) + if err != nil { + return false, err + } + + return result.Success && result.UserID == userID, nil +} + +// AuthenticateMetadata implements metadata.AuthProvider interface +func (p *Provider) AuthenticateMetadata(ctx context.Context, credentials interface{}) (string, error) { + authCtx := &providers.AuthContext{ + OriginalContext: ctx, + } + + result, err := p.Authenticate(authCtx, credentials) + if err != nil { + return "", err + } + + if !result.Success { + return "", result.Error + } + + return result.UserID, nil +} + +// GenerateBackupCodes is not applicable for WebAuthn +func (p *Provider) GenerateBackupCodes(ctx context.Context, userID string, count int) ([]string, error) { + return nil, fmt.Errorf("backup codes are not supported for WebAuthn") +} + +// BeginRegistration starts the WebAuthn registration process +func (p *Provider) BeginRegistration(ctx context.Context, userID string, username string, displayName string) (*RegistrationOptions, error) { + options, _, err := p.createRegistrationOptions(ctx, userID, true) + if err != nil { + return nil, err + } + + // Set user information + if username != "" { + options.User.Name = username + } + if displayName != "" { + options.User.DisplayName = displayName + } + + return options, nil +} + +// CompleteRegistration completes the WebAuthn registration process +func (p *Provider) CompleteRegistration(ctx context.Context, userID string, challengeID string, response *RegistrationResponse) error { + // Validate challenge + challenge, err := p.challengeManager.ValidateChallenge(ctx, userID, challengeID) + if err != nil { + return ErrInvalidChallenge + } + + // Create webauthn user + user := &webauthnUser{ + id: userID, + name: userID, + displayName: userID, + } + + // Create session data + sessionData := &webauthn.SessionData{ + Challenge: base64.URLEncoding.EncodeToString(challenge.Challenge), + UserID: []byte(userID), + } + + // Parse credential creation response + parsedResponse, err := protocol.ParseCredentialCreationResponseBody(bytes.NewReader(response.ClientDataJSON)) + if err != nil { + return WrapError(err, "failed to parse response") + } + + // Verify the registration + credential, err := p.webauthn.CreateCredential(user, *sessionData, parsedResponse) + if err != nil { + return WrapError(err, "failed to create credential") + } + + // Store the credential + cred := &Credential{ + ID: credential.ID, + PublicKey: credential.PublicKey, + AttestationType: string(credential.AttestationType), + Transport: response.Transports, + Flags: CredentialFlags{ + UserPresent: credential.Flags.UserPresent, + UserVerified: credential.Flags.UserVerified, + BackupEligible: credential.Flags.BackupEligible, + BackupState: credential.Flags.BackupState, + }, + Authenticator: AuthenticatorData{ + AAGUID: credential.Authenticator.AAGUID, + SignCount: credential.Authenticator.SignCount, + CloneWarning: credential.Authenticator.CloneWarning, + }, + Counter: credential.Authenticator.SignCount, + Attachment: response.AuthenticatorAttachment, + } + + if err := p.credentialStore.AddCredential(ctx, userID, cred); err != nil { + return err + } + + return nil +} + +// BeginAuthentication starts the WebAuthn authentication process +func (p *Provider) BeginAuthentication(ctx context.Context, userID string) (*AuthenticationOptions, error) { + // Get user credentials + userCreds, err := p.credentialStore.GetUserCredentials(ctx, userID) + if err != nil || len(userCreds.Credentials) == 0 { + return nil, ErrUserNotFound + } + + // Create challenge + challenge, err := p.challengeManager.CreateChallenge(ctx, userID, "authentication") + if err != nil { + return nil, err + } + + // Build allowed credentials + allowedCreds := make([]PublicKeyCredentialDescriptor, len(userCreds.Credentials)) + for i, cred := range userCreds.Credentials { + allowedCreds[i] = PublicKeyCredentialDescriptor{ + Type: "public-key", + ID: cred.ID, + Transports: cred.Transport, + } + } + + options := &AuthenticationOptions{ + Challenge: challenge.Challenge, + Timeout: uint64(p.config.Timeout.Milliseconds()), + RelyingPartyID: p.config.RPID, + AllowCredentials: allowedCreds, + UserVerification: p.config.UserVerification, + } + + return options, nil +} + +// Helper functions + +func (p *Provider) createRegistrationOptions(ctx context.Context, userID string, includeUser bool) (*RegistrationOptions, *Challenge, error) { + // Create challenge + challenge, err := p.challengeManager.CreateChallenge(ctx, userID, "registration") + if err != nil { + return nil, nil, err + } + + // Get existing credentials to exclude + userCreds, _ := p.credentialStore.GetUserCredentials(ctx, userID) + excludeCreds := make([]PublicKeyCredentialDescriptor, len(userCreds.Credentials)) + for i, cred := range userCreds.Credentials { + excludeCreds[i] = PublicKeyCredentialDescriptor{ + Type: "public-key", + ID: cred.ID, + Transports: cred.Transport, + } + } + + // Build credential parameters + credParams := make([]PublicKeyCredentialParameters, len(p.config.SupportedAlgorithms)) + for i, alg := range p.config.SupportedAlgorithms { + credParams[i] = PublicKeyCredentialParameters{ + Type: "public-key", + Algorithm: alg, + } + } + + options := &RegistrationOptions{ + Challenge: challenge.Challenge, + RelyingParty: RelyingParty{ + ID: p.config.RPID, + Name: p.config.RPDisplayName, + }, + PubKeyCredParams: credParams, + Timeout: uint64(p.config.Timeout.Milliseconds()), + ExcludeCredentials: excludeCreds, + Attestation: p.config.AttestationPreference, + AuthenticatorSelection: &AuthenticatorSelection{ + AuthenticatorAttachment: p.config.AuthenticatorAttachment, + ResidentKey: p.config.ResidentKeyRequirement, + RequireResidentKey: p.config.RequireResidentKey, + UserVerification: p.config.UserVerification, + }, + } + + if includeUser { + // Generate user ID bytes + userIDBytes := make([]byte, 32) + copy(userIDBytes, []byte(userID)) + + options.User = User{ + ID: userIDBytes, + Name: userID, + DisplayName: userID, + } + } + + return options, challenge, nil +} + +func (p *Provider) parseAuthenticationCredentials(credentials interface{}) (*WebAuthnAuthenticationCredentials, error) { + switch creds := credentials.(type) { + case *WebAuthnAuthenticationCredentials: + return creds, nil + case WebAuthnAuthenticationCredentials: + return &creds, nil + case map[string]interface{}: + // Parse from map + result := &WebAuthnAuthenticationCredentials{} + + if v, ok := creds["credentialId"].(string); ok { + result.CredentialID = v + } + if v, ok := creds["response"].([]byte); ok { + result.Response = v + } else if v, ok := creds["response"].(string); ok { + result.Response = []byte(v) + } + if v, ok := creds["challengeId"].(string); ok { + result.ChallengeID = v + } + if v, ok := creds["userId"].(string); ok { + result.UserID = v + } + + if result.CredentialID == "" || len(result.Response) == 0 { + return nil, ErrInvalidCredential + } + + return result, nil + default: + return nil, ErrInvalidCredential + } +} + +// WebAuthnAuthenticationCredentials represents credentials for WebAuthn authentication +type WebAuthnAuthenticationCredentials struct { + CredentialID string `json:"credentialId"` + Response []byte `json:"response"` + ChallengeID string `json:"challengeId"` + UserID string `json:"userId"` +} + +// webauthnUser implements webauthn.User interface +type webauthnUser struct { + id string + name string + displayName string + credentials []Credential +} + +func (u *webauthnUser) WebAuthnID() []byte { + return []byte(u.id) +} + +func (u *webauthnUser) WebAuthnName() string { + if u.name != "" { + return u.name + } + return u.id +} + +func (u *webauthnUser) WebAuthnDisplayName() string { + if u.displayName != "" { + return u.displayName + } + return u.name +} + +func (u *webauthnUser) WebAuthnCredentials() []webauthn.Credential { + creds := make([]webauthn.Credential, len(u.credentials)) + for i, c := range u.credentials { + creds[i] = webauthn.Credential{ + ID: c.ID, + PublicKey: c.PublicKey, + AttestationType: c.AttestationType, + Transport: convertTransports(c.Transport), + Flags: webauthn.CredentialFlags{ + UserPresent: c.Flags.UserPresent, + UserVerified: c.Flags.UserVerified, + BackupEligible: c.Flags.BackupEligible, + BackupState: c.Flags.BackupState, + }, + Authenticator: webauthn.Authenticator{ + AAGUID: c.Authenticator.AAGUID, + SignCount: c.Authenticator.SignCount, + CloneWarning: c.Authenticator.CloneWarning, + }, + } + } + return creds +} + +func (u *webauthnUser) WebAuthnIcon() string { + return "" +} + +// convertTransports converts string transports to protocol.AuthenticatorTransport +func convertTransports(transports []string) []protocol.AuthenticatorTransport { + if len(transports) == 0 { + return nil + } + + result := make([]protocol.AuthenticatorTransport, len(transports)) + for i, t := range transports { + result[i] = protocol.AuthenticatorTransport(t) + } + return result +} \ No newline at end of file diff --git a/pkg/auth/providers/webauthn/provider_test.go b/pkg/auth/providers/webauthn/provider_test.go new file mode 100644 index 0000000..cddb85b --- /dev/null +++ b/pkg/auth/providers/webauthn/provider_test.go @@ -0,0 +1,428 @@ +package webauthn_test + +import ( + "context" + "testing" + "time" + + "github.com/Fishwaldo/auth2/pkg/auth/providers/webauthn" + "github.com/Fishwaldo/auth2/pkg/plugin/metadata" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +// Mock StateStore implementation +type mockStateStore struct { + mock.Mock +} + +func (m *mockStateStore) StoreState(ctx context.Context, namespace string, entityID string, key string, value interface{}) error { + args := m.Called(ctx, namespace, entityID, key, value) + return args.Error(0) +} + +func (m *mockStateStore) GetState(ctx context.Context, namespace string, entityID string, key string, valuePtr interface{}) error { + args := m.Called(ctx, namespace, entityID, key, valuePtr) + return args.Error(0) +} + +func (m *mockStateStore) DeleteState(ctx context.Context, namespace string, entityID string, key string) error { + args := m.Called(ctx, namespace, entityID, key) + return args.Error(0) +} + +func (m *mockStateStore) ListStateKeys(ctx context.Context, namespace string, entityID string) ([]string, error) { + args := m.Called(ctx, namespace, entityID) + return args.Get(0).([]string), args.Error(1) +} + +func TestProvider_New(t *testing.T) { + tests := []struct { + name string + config *webauthn.Config + expectedError string + }{ + { + name: "valid config", + config: &webauthn.Config{ + RPDisplayName: "Test App", + RPID: "localhost", + RPOrigins: []string{"http://localhost"}, + StateStore: &mockStateStore{}, + }, + expectedError: "", + }, + { + name: "missing rp_display_name", + config: &webauthn.Config{ + RPID: "localhost", + RPOrigins: []string{"http://localhost"}, + StateStore: &mockStateStore{}, + }, + expectedError: "rp_display_name is required", + }, + { + name: "missing rp_id", + config: &webauthn.Config{ + RPDisplayName: "Test App", + RPOrigins: []string{"http://localhost"}, + StateStore: &mockStateStore{}, + }, + expectedError: "rp_id is required", + }, + { + name: "missing rp_origins", + config: &webauthn.Config{ + RPDisplayName: "Test App", + RPID: "localhost", + StateStore: &mockStateStore{}, + }, + expectedError: "at least one rp_origin is required", + }, + { + name: "missing state_store", + config: &webauthn.Config{ + RPDisplayName: "Test App", + RPID: "localhost", + RPOrigins: []string{"http://localhost"}, + }, + expectedError: "state_store is required", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + provider, err := webauthn.New(tt.config) + + if tt.expectedError != "" { + assert.Error(t, err) + assert.Contains(t, err.Error(), tt.expectedError) + assert.Nil(t, provider) + } else { + assert.NoError(t, err) + assert.NotNil(t, provider) + } + }) + } +} + +func TestProvider_Initialize(t *testing.T) { + mockStore := &mockStateStore{} + + tests := []struct { + name string + config interface{} + expectedError string + }{ + { + name: "valid config", + config: &webauthn.Config{ + RPDisplayName: "Test App", + RPID: "localhost", + RPOrigins: []string{"http://localhost"}, + StateStore: mockStore, + }, + expectedError: "", + }, + { + name: "invalid config type", + config: "invalid", + expectedError: "expected *Config", + }, + { + name: "nil config", + config: nil, + expectedError: "expected *Config", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + provider, err := webauthn.New(&webauthn.Config{ + RPDisplayName: "Initial", + RPID: "initial", + RPOrigins: []string{"http://initial"}, + StateStore: mockStore, + }) + require.NoError(t, err) + + err = provider.Initialize(context.Background(), tt.config) + + if tt.expectedError != "" { + assert.Error(t, err) + assert.Contains(t, err.Error(), tt.expectedError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestProvider_Supports(t *testing.T) { + mockStore := &mockStateStore{} + provider, err := webauthn.New(&webauthn.Config{ + RPDisplayName: "Test App", + RPID: "localhost", + RPOrigins: []string{"http://localhost"}, + StateStore: mockStore, + }) + require.NoError(t, err) + + tests := []struct { + name string + credentials interface{} + expected bool + }{ + { + name: "WebAuthnAuthenticationCredentials pointer", + credentials: &webauthn.WebAuthnAuthenticationCredentials{ + CredentialID: "test", + Response: []byte("response"), + }, + expected: true, + }, + { + name: "WebAuthnAuthenticationCredentials value", + credentials: webauthn.WebAuthnAuthenticationCredentials{ + CredentialID: "test", + Response: []byte("response"), + }, + expected: true, + }, + { + name: "map with webauthn fields", + credentials: map[string]interface{}{ + "credentialId": "test", + "response": []byte("response"), + }, + expected: true, + }, + { + name: "map without webauthn fields", + credentials: map[string]interface{}{ + "username": "test", + "password": "password", + }, + expected: false, + }, + { + name: "unsupported type", + credentials: "invalid", + expected: false, + }, + { + name: "nil credentials", + credentials: nil, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := provider.Supports(tt.credentials) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestProvider_Setup(t *testing.T) { + mockStore := &mockStateStore{} + provider, err := webauthn.New(&webauthn.Config{ + RPDisplayName: "Test App", + RPID: "localhost", + RPOrigins: []string{"http://localhost"}, + StateStore: mockStore, + Timeout: 60 * time.Second, + ChallengeTimeout: 5 * time.Minute, + }) + require.NoError(t, err) + + ctx := context.Background() + userID := "test-user" + + // Setup mock expectations + mockStore.On("StoreState", ctx, "webauthn_challenges", userID, mock.AnythingOfType("string"), mock.AnythingOfType("*webauthn.Challenge")).Return(nil) + mockStore.On("GetState", ctx, "webauthn_credentials", userID, "credentials", mock.AnythingOfType("*webauthn.UserCredentials")).Return(nil) + + // Call Setup + setupData, err := provider.Setup(ctx, userID) + + // Assertions + assert.NoError(t, err) + assert.Equal(t, "webauthn", setupData.ProviderID) + assert.Equal(t, userID, setupData.UserID) + assert.NotEmpty(t, setupData.Secret) // Challenge ID + assert.NotEmpty(t, setupData.QRCode) // Options JSON + + // Verify additional data + assert.Contains(t, setupData.AdditionalData, "challenge_id") + assert.Contains(t, setupData.AdditionalData, "rp_id") + assert.Contains(t, setupData.AdditionalData, "timeout") + assert.Equal(t, "localhost", setupData.AdditionalData["rp_id"]) + assert.Equal(t, float64(60), setupData.AdditionalData["timeout"]) + + mockStore.AssertExpectations(t) +} + +func TestProvider_GenerateBackupCodes(t *testing.T) { + mockStore := &mockStateStore{} + provider, err := webauthn.New(&webauthn.Config{ + RPDisplayName: "Test App", + RPID: "localhost", + RPOrigins: []string{"http://localhost"}, + StateStore: mockStore, + }) + require.NoError(t, err) + + ctx := context.Background() + userID := "test-user" + + // GenerateBackupCodes should not be supported for WebAuthn + codes, err := provider.GenerateBackupCodes(ctx, userID, 10) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "backup codes are not supported for WebAuthn") + assert.Nil(t, codes) +} + +func TestProvider_BeginRegistration(t *testing.T) { + mockStore := &mockStateStore{} + provider, err := webauthn.New(&webauthn.Config{ + RPDisplayName: "Test App", + RPID: "localhost", + RPOrigins: []string{"http://localhost"}, + StateStore: mockStore, + Timeout: 60 * time.Second, + ChallengeTimeout: 5 * time.Minute, + SupportedAlgorithms: []int64{-7, -257}, + }) + require.NoError(t, err) + + ctx := context.Background() + userID := "test-user" + username := "testuser" + displayName := "Test User" + + // Setup mock expectations + mockStore.On("StoreState", ctx, "webauthn_challenges", userID, mock.AnythingOfType("string"), mock.AnythingOfType("*webauthn.Challenge")).Return(nil) + mockStore.On("GetState", ctx, "webauthn_credentials", userID, "credentials", mock.AnythingOfType("*webauthn.UserCredentials")).Return(nil) + + // Call BeginRegistration + options, err := provider.BeginRegistration(ctx, userID, username, displayName) + + // Assertions + assert.NoError(t, err) + assert.NotNil(t, options) + assert.NotEmpty(t, options.Challenge) + assert.Equal(t, "localhost", options.RelyingParty.ID) + assert.Equal(t, "Test App", options.RelyingParty.Name) + assert.Equal(t, username, options.User.Name) + assert.Equal(t, displayName, options.User.DisplayName) + assert.Len(t, options.PubKeyCredParams, 2) + assert.Equal(t, int64(-7), options.PubKeyCredParams[0].Algorithm) + assert.Equal(t, int64(-257), options.PubKeyCredParams[1].Algorithm) + + mockStore.AssertExpectations(t) +} + +func TestProvider_BeginAuthentication(t *testing.T) { + mockStore := &mockStateStore{} + provider, err := webauthn.New(&webauthn.Config{ + RPDisplayName: "Test App", + RPID: "localhost", + RPOrigins: []string{"http://localhost"}, + StateStore: mockStore, + Timeout: 60 * time.Second, + ChallengeTimeout: 5 * time.Minute, + UserVerification: webauthn.UserVerificationPreferred, + }) + require.NoError(t, err) + + ctx := context.Background() + userID := "test-user" + + t.Run("user with credentials", func(t *testing.T) { + // Setup test data + testCreds := &webauthn.UserCredentials{ + UserID: userID, + Credentials: []webauthn.Credential{ + { + ID: []byte("cred1"), + PublicKey: []byte("pubkey1"), + Transport: []string{"usb", "nfc"}, + }, + { + ID: []byte("cred2"), + PublicKey: []byte("pubkey2"), + Transport: []string{"internal"}, + }, + }, + } + + // Setup mock expectations + mockStore.On("GetState", ctx, "webauthn_credentials", userID, "credentials", mock.AnythingOfType("*webauthn.UserCredentials")).Run(func(args mock.Arguments) { + userCreds := args.Get(4).(*webauthn.UserCredentials) + *userCreds = *testCreds + }).Return(nil).Once() + + mockStore.On("StoreState", ctx, "webauthn_challenges", userID, mock.AnythingOfType("string"), mock.AnythingOfType("*webauthn.Challenge")).Return(nil).Once() + + // Call BeginAuthentication + options, err := provider.BeginAuthentication(ctx, userID) + + // Assertions + assert.NoError(t, err) + assert.NotNil(t, options) + assert.NotEmpty(t, options.Challenge) + assert.Equal(t, "localhost", options.RelyingPartyID) + assert.Equal(t, webauthn.UserVerificationPreferred, options.UserVerification) + assert.Len(t, options.AllowCredentials, 2) + + // Verify allowed credentials + assert.Equal(t, "public-key", options.AllowCredentials[0].Type) + assert.Equal(t, []byte("cred1"), options.AllowCredentials[0].ID) + assert.Equal(t, []string{"usb", "nfc"}, options.AllowCredentials[0].Transports) + + assert.Equal(t, "public-key", options.AllowCredentials[1].Type) + assert.Equal(t, []byte("cred2"), options.AllowCredentials[1].ID) + assert.Equal(t, []string{"internal"}, options.AllowCredentials[1].Transports) + + mockStore.AssertExpectations(t) + }) + + t.Run("user without credentials", func(t *testing.T) { + // Setup mock expectations - return empty credentials + mockStore.On("GetState", ctx, "webauthn_credentials", userID, "credentials", mock.AnythingOfType("*webauthn.UserCredentials")).Return(nil).Once() + + // Call BeginAuthentication + options, err := provider.BeginAuthentication(ctx, userID) + + // Should fail with user not found + assert.Error(t, err) + assert.Equal(t, webauthn.ErrUserNotFound, err) + assert.Nil(t, options) + + mockStore.AssertExpectations(t) + }) +} + +func TestProvider_GetMetadata(t *testing.T) { + mockStore := &mockStateStore{} + provider, err := webauthn.New(&webauthn.Config{ + RPDisplayName: "Test App", + RPID: "localhost", + RPOrigins: []string{"http://localhost"}, + StateStore: mockStore, + }) + require.NoError(t, err) + + meta := provider.GetMetadata() + + assert.Equal(t, "webauthn", meta.ID) + assert.Equal(t, metadata.ProviderTypeAuth, meta.Type) + assert.Equal(t, "WebAuthn", meta.Name) + assert.Contains(t, meta.Description, "WebAuthn/FIDO2") + assert.Equal(t, "1.0.0", meta.Version) + assert.Equal(t, "auth2", meta.Author) +} \ No newline at end of file diff --git a/pkg/auth/providers/webauthn/types.go b/pkg/auth/providers/webauthn/types.go new file mode 100644 index 0000000..5189066 --- /dev/null +++ b/pkg/auth/providers/webauthn/types.go @@ -0,0 +1,181 @@ +package webauthn + +import ( + "time" +) + +// AttestationConveyancePreference represents the attestation preference +type AttestationConveyancePreference string + +const ( + // AttestationNone indicates no attestation is required + AttestationNone AttestationConveyancePreference = "none" + // AttestationIndirect indicates indirect attestation is preferred + AttestationIndirect AttestationConveyancePreference = "indirect" + // AttestationDirect indicates direct attestation is preferred + AttestationDirect AttestationConveyancePreference = "direct" + // AttestationEnterprise indicates enterprise attestation is preferred + AttestationEnterprise AttestationConveyancePreference = "enterprise" +) + +// UserVerificationRequirement represents the user verification requirement +type UserVerificationRequirement string + +const ( + // UserVerificationRequired requires user verification + UserVerificationRequired UserVerificationRequirement = "required" + // UserVerificationPreferred prefers user verification but doesn't require it + UserVerificationPreferred UserVerificationRequirement = "preferred" + // UserVerificationDiscouraged discourages user verification + UserVerificationDiscouraged UserVerificationRequirement = "discouraged" +) + +// ResidentKeyRequirement represents the resident key requirement +type ResidentKeyRequirement string + +const ( + // ResidentKeyDiscouraged discourages resident keys + ResidentKeyDiscouraged ResidentKeyRequirement = "discouraged" + // ResidentKeyPreferred prefers resident keys + ResidentKeyPreferred ResidentKeyRequirement = "preferred" + // ResidentKeyRequired requires resident keys + ResidentKeyRequired ResidentKeyRequirement = "required" +) + +// AuthenticatorAttachment represents the authenticator attachment +type AuthenticatorAttachment string + +const ( + // AttachmentPlatform indicates a platform authenticator + AttachmentPlatform AuthenticatorAttachment = "platform" + // AttachmentCrossPlatform indicates a cross-platform authenticator + AttachmentCrossPlatform AuthenticatorAttachment = "cross-platform" +) + +// Challenge represents a WebAuthn challenge +type Challenge struct { + ID string `json:"id"` + UserID string `json:"user_id"` + Challenge []byte `json:"challenge"` + Type string `json:"type"` // "registration" or "authentication" + CreatedAt time.Time `json:"created_at"` + ExpiresAt time.Time `json:"expires_at"` +} + +// Credential represents a stored WebAuthn credential +type Credential struct { + ID []byte `json:"id"` + PublicKey []byte `json:"public_key"` + AttestationType string `json:"attestation_type"` + Transport []string `json:"transport"` + Flags CredentialFlags `json:"flags"` + Authenticator AuthenticatorData `json:"authenticator"` + CreatedAt time.Time `json:"created_at"` + LastUsedAt time.Time `json:"last_used_at"` + Counter uint32 `json:"counter"` + BackupEligible bool `json:"backup_eligible"` + BackupState bool `json:"backup_state"` + Attachment AuthenticatorAttachment `json:"attachment,omitempty"` +} + +// CredentialFlags represents credential flags +type CredentialFlags struct { + UserPresent bool `json:"user_present"` + UserVerified bool `json:"user_verified"` + BackupEligible bool `json:"backup_eligible"` + BackupState bool `json:"backup_state"` +} + +// AuthenticatorData represents authenticator data +type AuthenticatorData struct { + AAGUID []byte `json:"aaguid"` + SignCount uint32 `json:"sign_count"` + CloneWarning bool `json:"clone_warning"` +} + +// UserCredentials represents all credentials for a user +type UserCredentials struct { + UserID string `json:"user_id"` + Credentials []Credential `json:"credentials"` +} + +// RegistrationOptions represents options for credential creation +type RegistrationOptions struct { + Challenge []byte `json:"challenge"` + RelyingParty RelyingParty `json:"rp"` + User User `json:"user"` + PubKeyCredParams []PublicKeyCredentialParameters `json:"pubKeyCredParams"` + Timeout uint64 `json:"timeout,omitempty"` + ExcludeCredentials []PublicKeyCredentialDescriptor `json:"excludeCredentials,omitempty"` + AuthenticatorSelection *AuthenticatorSelection `json:"authenticatorSelection,omitempty"` + Attestation AttestationConveyancePreference `json:"attestation,omitempty"` + Extensions map[string]interface{} `json:"extensions,omitempty"` +} + +// AuthenticationOptions represents options for authentication +type AuthenticationOptions struct { + Challenge []byte `json:"challenge"` + Timeout uint64 `json:"timeout,omitempty"` + RelyingPartyID string `json:"rpId,omitempty"` + AllowCredentials []PublicKeyCredentialDescriptor `json:"allowCredentials,omitempty"` + UserVerification UserVerificationRequirement `json:"userVerification,omitempty"` + Extensions map[string]interface{} `json:"extensions,omitempty"` +} + +// RelyingParty represents the relying party +type RelyingParty struct { + ID string `json:"id"` + Name string `json:"name"` +} + +// User represents a WebAuthn user +type User struct { + ID []byte `json:"id"` + Name string `json:"name"` + DisplayName string `json:"displayName"` +} + +// PublicKeyCredentialParameters represents credential parameters +type PublicKeyCredentialParameters struct { + Type string `json:"type"` + Algorithm int64 `json:"alg"` +} + +// PublicKeyCredentialDescriptor describes a credential +type PublicKeyCredentialDescriptor struct { + Type string `json:"type"` + ID []byte `json:"id"` + Transports []string `json:"transports,omitempty"` +} + +// AuthenticatorSelection represents authenticator selection criteria +type AuthenticatorSelection struct { + AuthenticatorAttachment AuthenticatorAttachment `json:"authenticatorAttachment,omitempty"` + ResidentKey ResidentKeyRequirement `json:"residentKey,omitempty"` + RequireResidentKey bool `json:"requireResidentKey,omitempty"` + UserVerification UserVerificationRequirement `json:"userVerification,omitempty"` +} + +// RegistrationResponse represents the response from credential creation +type RegistrationResponse struct { + ID string `json:"id"` + RawID []byte `json:"rawId"` + Type string `json:"type"` + AttestationObject []byte `json:"attestationObject"` + ClientDataJSON []byte `json:"clientDataJSON"` + Transports []string `json:"transports,omitempty"` + PublicKeyAlgorithm int64 `json:"publicKeyAlgorithm,omitempty"` + PublicKey []byte `json:"publicKey,omitempty"` + AuthenticatorAttachment AuthenticatorAttachment `json:"authenticatorAttachment,omitempty"` +} + +// AuthenticationResponse represents the response from authentication +type AuthenticationResponse struct { + ID string `json:"id"` + RawID []byte `json:"rawId"` + Type string `json:"type"` + AuthenticatorData []byte `json:"authenticatorData"` + ClientDataJSON []byte `json:"clientDataJSON"` + Signature []byte `json:"signature"` + UserHandle []byte `json:"userHandle,omitempty"` +} \ No newline at end of file