diff --git a/internal/errors/http.go b/internal/errors/http.go index 0ef4b7b..e2ac9f7 100644 --- a/internal/errors/http.go +++ b/internal/errors/http.go @@ -175,6 +175,8 @@ func errorToHTTPStatus(err error) int { return http.StatusRequestTimeout case CodeUnavailable: return http.StatusServiceUnavailable + case CodeValidation: + return http.StatusBadRequest default: return http.StatusInternalServerError } diff --git a/pkg/auth/providers/basic/factory_test.go b/pkg/auth/providers/basic/factory_test.go new file mode 100644 index 0000000..98a35fc --- /dev/null +++ b/pkg/auth/providers/basic/factory_test.go @@ -0,0 +1,245 @@ +package basic_test + +import ( + "context" + "testing" + + "github.com/Fishwaldo/auth2/pkg/auth/providers" + "github.com/Fishwaldo/auth2/pkg/auth/providers/basic" + "github.com/Fishwaldo/auth2/pkg/plugin/factory" + "github.com/Fishwaldo/auth2/pkg/plugin/metadata" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +func TestNewFactory(t *testing.T) { + mockStore := &mockUserStore{} + mockPwdUtils := &mockPasswordUtils{} + + factory := basic.NewFactory(mockStore, mockPwdUtils) + + assert.NotNil(t, factory) +} + +func TestFactory_Create(t *testing.T) { + mockStore := &mockUserStore{} + mockPwdUtils := &mockPasswordUtils{} + factory := basic.NewFactory(mockStore, mockPwdUtils) + + tests := []struct { + name string + id string + config interface{} + expectError bool + errorContains string + validate func(*testing.T, metadata.Provider) + }{ + { + name: "create with Config struct", + id: "test-provider", + config: &basic.Config{ + AccountLockThreshold: 10, + AccountLockDuration: 60, + RequireVerifiedEmail: true, + }, + expectError: false, + validate: func(t *testing.T, p metadata.Provider) { + authProvider, ok := p.(*basic.Provider) + require.True(t, ok) + assert.NotNil(t, authProvider) + }, + }, + { + name: "create with nil config", + id: "test-provider", + config: nil, + expectError: false, + validate: func(t *testing.T, p metadata.Provider) { + authProvider, ok := p.(*basic.Provider) + require.True(t, ok) + assert.NotNil(t, authProvider) + }, + }, + { + name: "create with map config", + id: "test-provider", + config: map[string]interface{}{ + "account_lock_threshold": 15, + "account_lock_duration": 120, + "require_verified_email": false, + }, + expectError: false, + validate: func(t *testing.T, p metadata.Provider) { + authProvider, ok := p.(*basic.Provider) + require.True(t, ok) + assert.NotNil(t, authProvider) + }, + }, + { + name: "create with map config - partial values", + id: "test-provider", + config: map[string]interface{}{ + "account_lock_threshold": 20, + }, + expectError: false, + validate: func(t *testing.T, p metadata.Provider) { + authProvider, ok := p.(*basic.Provider) + require.True(t, ok) + assert.NotNil(t, authProvider) + }, + }, + { + name: "create with map config - wrong types", + id: "test-provider", + config: map[string]interface{}{ + "account_lock_threshold": "not an int", + "account_lock_duration": "not an int", + "require_verified_email": "not a bool", + }, + expectError: false, // Should not error, just use defaults + validate: func(t *testing.T, p metadata.Provider) { + authProvider, ok := p.(*basic.Provider) + require.True(t, ok) + assert.NotNil(t, authProvider) + }, + }, + { + name: "create with invalid config type", + id: "test-provider", + config: "invalid config", + expectError: true, + errorContains: "invalid configuration type", + }, + { + name: "create with invalid config struct", + id: "test-provider", + config: struct{ Field string }{Field: "value"}, + expectError: true, + errorContains: "invalid configuration type", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + provider, err := factory.Create(tt.id, tt.config) + + if tt.expectError { + assert.Error(t, err) + if tt.errorContains != "" { + assert.Contains(t, err.Error(), tt.errorContains) + } + assert.Nil(t, provider) + } else { + require.NoError(t, err) + require.NotNil(t, provider) + if tt.validate != nil { + tt.validate(t, provider) + } + } + }) + } +} + +func TestFactory_GetType(t *testing.T) { + mockStore := &mockUserStore{} + mockPwdUtils := &mockPasswordUtils{} + factory := basic.NewFactory(mockStore, mockPwdUtils) + + providerType := factory.GetType() + assert.Equal(t, metadata.ProviderTypeAuth, providerType) +} + +func TestFactory_GetMetadata(t *testing.T) { + mockStore := &mockUserStore{} + mockPwdUtils := &mockPasswordUtils{} + factory := basic.NewFactory(mockStore, mockPwdUtils) + + metadataList := factory.GetMetadata() + + assert.Len(t, metadataList, 1) + + md := metadataList[0] + assert.Equal(t, "basic", md.ID) + assert.Equal(t, metadata.ProviderTypeAuth, md.Type) + assert.Equal(t, basic.ProviderName, md.Name) + assert.Equal(t, basic.ProviderDescription, md.Description) + assert.Equal(t, basic.ProviderVersion, md.Version) +} + +func TestRegister(t *testing.T) { + mockRegistry := &mockRegistry{} + mockStore := &mockUserStore{} + mockPwdUtils := &mockPasswordUtils{} + + // Set up expectation + mockRegistry.On("RegisterAuthProviderFactory", "basic", mock.AnythingOfType("*basic.Factory")).Return(nil) + + err := basic.Register(mockRegistry, mockStore, mockPwdUtils) + assert.NoError(t, err) + + mockRegistry.AssertExpectations(t) +} + +func TestRegister_Error(t *testing.T) { + mockRegistry := &mockRegistry{} + mockStore := &mockUserStore{} + mockPwdUtils := &mockPasswordUtils{} + + // Set up expectation for error + expectedErr := assert.AnError + mockRegistry.On("RegisterAuthProviderFactory", "basic", mock.AnythingOfType("*basic.Factory")).Return(expectedErr) + + err := basic.Register(mockRegistry, mockStore, mockPwdUtils) + assert.Error(t, err) + assert.Equal(t, expectedErr, err) + + mockRegistry.AssertExpectations(t) +} + +// Mock Registry for testing +type mockRegistry struct { + mock.Mock +} + +func (m *mockRegistry) RegisterAuthProvider(provider providers.AuthProvider) error { + args := m.Called(provider) + return args.Error(0) +} + +func (m *mockRegistry) GetAuthProvider(id string) (providers.AuthProvider, error) { + args := m.Called(id) + if args.Get(0) != nil { + return args.Get(0).(providers.AuthProvider), args.Error(1) + } + return nil, args.Error(1) +} + +func (m *mockRegistry) RegisterAuthProviderFactory(id string, factory factory.Factory) error { + args := m.Called(id, factory) + return args.Error(0) +} + +func (m *mockRegistry) GetAuthProviderFactory(id string) (factory.Factory, error) { + args := m.Called(id) + if args.Get(0) != nil { + return args.Get(0).(factory.Factory), args.Error(1) + } + return nil, args.Error(1) +} + +func (m *mockRegistry) ListAuthProviders() []providers.AuthProvider { + args := m.Called() + if args.Get(0) != nil { + return args.Get(0).([]providers.AuthProvider) + } + return nil +} + +func (m *mockRegistry) CreateAuthProvider(ctx context.Context, factoryID, providerID string, config interface{}) (providers.AuthProvider, error) { + args := m.Called(ctx, factoryID, providerID, config) + if args.Get(0) != nil { + return args.Get(0).(providers.AuthProvider), args.Error(1) + } + return nil, args.Error(1) +} \ No newline at end of file diff --git a/pkg/auth/providers/basic/utils_test.go b/pkg/auth/providers/basic/utils_test.go new file mode 100644 index 0000000..056225a --- /dev/null +++ b/pkg/auth/providers/basic/utils_test.go @@ -0,0 +1,323 @@ +package basic_test + +import ( + "context" + "testing" + "time" + + "github.com/Fishwaldo/auth2/pkg/auth/providers/basic" + "github.com/Fishwaldo/auth2/pkg/user" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +func TestValidateAccount(t *testing.T) { + ctx := context.Background() + + tests := []struct { + name string + user *user.User + config *basic.Config + expectError bool + errorContains string + }{ + { + name: "valid account", + user: &user.User{ + Enabled: true, + Locked: false, + EmailVerified: true, + }, + config: &basic.Config{ + RequireVerifiedEmail: true, + }, + expectError: false, + }, + { + name: "disabled account", + user: &user.User{ + Enabled: false, + }, + config: &basic.Config{}, + expectError: true, + errorContains: "account is disabled", + }, + { + name: "locked account - no expiry", + user: &user.User{ + Enabled: true, + Locked: true, + }, + config: &basic.Config{ + AccountLockDuration: 0, + }, + expectError: true, + errorContains: "account is locked", + }, + { + name: "locked account - not expired", + user: &user.User{ + Enabled: true, + Locked: true, + LockoutTime: time.Now().Add(-5 * time.Minute), + }, + config: &basic.Config{ + AccountLockDuration: 30, // 30 minutes + }, + expectError: true, + errorContains: "account is locked", + }, + { + name: "locked account - expired", + user: &user.User{ + Enabled: true, + Locked: true, + LockoutTime: time.Now().Add(-60 * time.Minute), + }, + config: &basic.Config{ + AccountLockDuration: 30, // 30 minutes + }, + expectError: false, + }, + { + name: "unverified email when required", + user: &user.User{ + Enabled: true, + Locked: false, + EmailVerified: false, + }, + config: &basic.Config{ + RequireVerifiedEmail: true, + }, + expectError: true, + errorContains: "email verification required", + }, + { + name: "unverified email when not required", + user: &user.User{ + Enabled: true, + Locked: false, + EmailVerified: false, + }, + config: &basic.Config{ + RequireVerifiedEmail: false, + }, + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := basic.ValidateAccount(ctx, tt.user, tt.config) + + if tt.expectError { + assert.Error(t, err) + if tt.errorContains != "" { + assert.Contains(t, err.Error(), tt.errorContains) + } + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestCheckPasswordRequirements(t *testing.T) { + tests := []struct { + name string + user *user.User + expectRequired bool + expectReason string + }{ + { + name: "password change required", + user: &user.User{ + RequirePasswordChange: true, + }, + expectRequired: true, + expectReason: "Password change required", + }, + { + name: "password change not required", + user: &user.User{ + RequirePasswordChange: false, + }, + expectRequired: false, + expectReason: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + required, reason := basic.CheckPasswordRequirements(tt.user) + + assert.Equal(t, tt.expectRequired, required) + assert.Equal(t, tt.expectReason, reason) + }) + } +} + +func TestProcessSuccessfulLogin(t *testing.T) { + ctx := context.Background() + mockStore := &mockUserStore{} + + originalTime := time.Now() + usr := &user.User{ + ID: "user123", + FailedLoginAttempts: 5, + LastLogin: originalTime.Add(-24 * time.Hour), + } + + // Set up expectation + mockStore.On("Update", ctx, mock.MatchedBy(func(u *user.User) bool { + return u.ID == usr.ID && + u.FailedLoginAttempts == 0 && + u.LastLogin.After(originalTime) + })).Return(nil) + + err := basic.ProcessSuccessfulLogin(ctx, mockStore, usr) + + assert.NoError(t, err) + assert.Equal(t, 0, usr.FailedLoginAttempts) + assert.True(t, usr.LastLogin.After(originalTime)) + + mockStore.AssertExpectations(t) +} + +func TestProcessSuccessfulLogin_UpdateError(t *testing.T) { + ctx := context.Background() + mockStore := &mockUserStore{} + + usr := &user.User{ + ID: "user123", + FailedLoginAttempts: 5, + } + + // Set up expectation for error + expectedErr := assert.AnError + mockStore.On("Update", ctx, mock.AnythingOfType("*user.User")).Return(expectedErr) + + err := basic.ProcessSuccessfulLogin(ctx, mockStore, usr) + + assert.Error(t, err) + assert.Equal(t, expectedErr, err) + + mockStore.AssertExpectations(t) +} + +func TestProcessFailedLogin(t *testing.T) { + ctx := context.Background() + + tests := []struct { + name string + initialUser *user.User + config *basic.Config + setupMock func(*mockUserStore) + expectLocked bool + expectError bool + }{ + { + name: "increment failed attempts - not locked", + initialUser: &user.User{ + ID: "user123", + FailedLoginAttempts: 2, + }, + config: &basic.Config{ + AccountLockThreshold: 5, + }, + setupMock: func(m *mockUserStore) { + m.On("Update", ctx, mock.MatchedBy(func(u *user.User) bool { + return u.ID == "user123" && + u.FailedLoginAttempts == 3 && + !u.Locked + })).Return(nil) + }, + expectLocked: false, + expectError: false, + }, + { + name: "increment failed attempts - lock account", + initialUser: &user.User{ + ID: "user123", + FailedLoginAttempts: 4, + }, + config: &basic.Config{ + AccountLockThreshold: 5, + }, + setupMock: func(m *mockUserStore) { + m.On("Update", ctx, mock.MatchedBy(func(u *user.User) bool { + return u.ID == "user123" && + u.FailedLoginAttempts == 5 && + u.Locked && + !u.LockoutTime.IsZero() && + u.LockoutReason == "Too many failed login attempts" + })).Return(nil) + }, + expectLocked: true, + expectError: false, + }, + { + name: "no lock threshold", + initialUser: &user.User{ + ID: "user123", + FailedLoginAttempts: 10, + }, + config: &basic.Config{ + AccountLockThreshold: 0, + }, + setupMock: func(m *mockUserStore) { + m.On("Update", ctx, mock.MatchedBy(func(u *user.User) bool { + return u.ID == "user123" && + u.FailedLoginAttempts == 11 && + !u.Locked + })).Return(nil) + }, + expectLocked: false, + expectError: false, + }, + { + name: "update error", + initialUser: &user.User{ + ID: "user123", + FailedLoginAttempts: 0, + }, + config: &basic.Config{ + AccountLockThreshold: 5, + }, + setupMock: func(m *mockUserStore) { + m.On("Update", ctx, mock.AnythingOfType("*user.User")).Return(assert.AnError) + }, + expectLocked: false, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockStore := &mockUserStore{} + tt.setupMock(mockStore) + + originalTime := time.Now() + err := basic.ProcessFailedLogin(ctx, mockStore, tt.initialUser, tt.config) + + if tt.expectError { + assert.Error(t, err) + } else { + require.NoError(t, err) + assert.True(t, tt.initialUser.LastFailedLogin.After(originalTime) || + tt.initialUser.LastFailedLogin.Equal(originalTime)) + + if tt.expectLocked { + assert.True(t, tt.initialUser.Locked) + assert.False(t, tt.initialUser.LockoutTime.IsZero()) + assert.Equal(t, "Too many failed login attempts", tt.initialUser.LockoutReason) + } else { + assert.False(t, tt.initialUser.Locked) + } + } + + mockStore.AssertExpectations(t) + }) + } +} \ No newline at end of file diff --git a/pkg/log/log_test.go b/pkg/log/log_test.go new file mode 100644 index 0000000..d880973 --- /dev/null +++ b/pkg/log/log_test.go @@ -0,0 +1,517 @@ +package log_test + +import ( + "bytes" + "context" + "encoding/json" + "strings" + "testing" + + "github.com/Fishwaldo/auth2/pkg/log" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestDefaultConfig(t *testing.T) { + cfg := log.DefaultConfig() + + assert.Equal(t, log.LevelInfo, cfg.Level) + assert.Equal(t, "json", cfg.Format) + assert.NotNil(t, cfg.Writer) + assert.False(t, cfg.AddSource) + assert.Nil(t, cfg.ContextKeys) +} + +func TestNew(t *testing.T) { + tests := []struct { + name string + config *log.Config + }{ + { + name: "with nil config uses default", + config: nil, + }, + { + name: "with json format", + config: &log.Config{ + Level: log.LevelDebug, + Format: "json", + Writer: &bytes.Buffer{}, + AddSource: true, + }, + }, + { + name: "with text format", + config: &log.Config{ + Level: log.LevelWarn, + Format: "text", + Writer: &bytes.Buffer{}, + AddSource: false, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + logger := log.New(tt.config) + assert.NotNil(t, logger) + assert.NotNil(t, logger.Logger) + }) + } +} + +func TestLogger_WithField(t *testing.T) { + var buf bytes.Buffer + logger := log.New(&log.Config{ + Level: log.LevelInfo, + Format: "json", + Writer: &buf, + }) + + // Create logger with field + loggerWithField := logger.WithField("key", "value") + loggerWithField.Info("test message") + + // Parse log output + var logEntry map[string]interface{} + err := json.Unmarshal(buf.Bytes(), &logEntry) + require.NoError(t, err) + + assert.Equal(t, "test message", logEntry["msg"]) + assert.Equal(t, "value", logEntry["key"]) +} + +func TestLogger_WithFields(t *testing.T) { + var buf bytes.Buffer + logger := log.New(&log.Config{ + Level: log.LevelInfo, + Format: "json", + Writer: &buf, + }) + + // Create logger with multiple fields + fields := map[string]interface{}{ + "field1": "value1", + "field2": 42, + "field3": true, + } + loggerWithFields := logger.WithFields(fields) + loggerWithFields.Info("test message") + + // Parse log output + var logEntry map[string]interface{} + err := json.Unmarshal(buf.Bytes(), &logEntry) + require.NoError(t, err) + + assert.Equal(t, "test message", logEntry["msg"]) + assert.Equal(t, "value1", logEntry["field1"]) + assert.Equal(t, float64(42), logEntry["field2"]) // JSON unmarshals numbers as float64 + assert.Equal(t, true, logEntry["field3"]) +} + +func TestLogger_WithContext(t *testing.T) { + var buf bytes.Buffer + logger := log.New(&log.Config{ + Level: log.LevelInfo, + Format: "json", + Writer: &buf, + }) + + ctx := context.Background() + loggerWithCtx := logger.WithContext(ctx) + assert.NotNil(t, loggerWithCtx) + + // Test that it returns a logger + loggerWithCtx.Info("test message") + + // Verify log was written + assert.NotEmpty(t, buf.String()) +} + +func TestLogger_Levels(t *testing.T) { + tests := []struct { + name string + logLevel log.Level + msgLevel string + shouldLog bool + }{ + { + name: "debug level logs debug", + logLevel: log.LevelDebug, + msgLevel: "debug", + shouldLog: true, + }, + { + name: "info level skips debug", + logLevel: log.LevelInfo, + msgLevel: "debug", + shouldLog: false, + }, + { + name: "info level logs info", + logLevel: log.LevelInfo, + msgLevel: "info", + shouldLog: true, + }, + { + name: "warn level logs warn", + logLevel: log.LevelWarn, + msgLevel: "warn", + shouldLog: true, + }, + { + name: "error level logs error", + logLevel: log.LevelError, + msgLevel: "error", + shouldLog: true, + }, + { + name: "error level skips warn", + logLevel: log.LevelError, + msgLevel: "warn", + shouldLog: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var buf bytes.Buffer + logger := log.New(&log.Config{ + Level: tt.logLevel, + Format: "json", + Writer: &buf, + }) + + msg := "test message" + switch tt.msgLevel { + case "debug": + logger.Debug(msg) + case "info": + logger.Info(msg) + case "warn": + logger.Warn(msg) + case "error": + logger.Error(msg) + } + + if tt.shouldLog { + assert.NotEmpty(t, buf.String()) + // Verify the message was logged + var logEntry map[string]interface{} + err := json.Unmarshal(buf.Bytes(), &logEntry) + require.NoError(t, err) + assert.Equal(t, msg, logEntry["msg"]) + } else { + assert.Empty(t, buf.String()) + } + }) + } +} + +func TestLogger_TextFormat(t *testing.T) { + var buf bytes.Buffer + logger := log.New(&log.Config{ + Level: log.LevelInfo, + Format: "text", + Writer: &buf, + }) + + logger.Info("test message", "key", "value") + + output := buf.String() + assert.Contains(t, output, "test message") + assert.Contains(t, output, "key=value") +} + +func TestLogger_AddSource(t *testing.T) { + var buf bytes.Buffer + logger := log.New(&log.Config{ + Level: log.LevelInfo, + Format: "json", + Writer: &buf, + AddSource: true, + }) + + logger.Info("test message") + + // Parse log output + var logEntry map[string]interface{} + err := json.Unmarshal(buf.Bytes(), &logEntry) + require.NoError(t, err) + + // Should have source information + assert.Contains(t, logEntry, "source") + source := logEntry["source"].(map[string]interface{}) + assert.Contains(t, source, "file") + assert.Contains(t, source, "line") +} + +func TestSetDefault(t *testing.T) { + // Save original default logger + originalDefault := log.Default() + defer log.SetDefault(originalDefault) + + var buf bytes.Buffer + newLogger := log.New(&log.Config{ + Level: log.LevelInfo, + Format: "json", + Writer: &buf, + }) + + log.SetDefault(newLogger) + + // Verify the default was set + assert.Equal(t, newLogger, log.Default()) + + // Test convenience functions use the new default + log.Info("test from default") + + assert.NotEmpty(t, buf.String()) + var logEntry map[string]interface{} + err := json.Unmarshal(buf.Bytes(), &logEntry) + require.NoError(t, err) + assert.Equal(t, "test from default", logEntry["msg"]) +} + +func TestConvenienceMethods(t *testing.T) { + // Save original default logger + originalDefault := log.Default() + defer log.SetDefault(originalDefault) + + var buf bytes.Buffer + logger := log.New(&log.Config{ + Level: log.LevelDebug, + Format: "json", + Writer: &buf, + }) + log.SetDefault(logger) + + tests := []struct { + name string + logFunc func(string, ...interface{}) + level string + }{ + { + name: "debug", + logFunc: log.Debug, + level: "DEBUG", + }, + { + name: "info", + logFunc: log.Info, + level: "INFO", + }, + { + name: "warn", + logFunc: log.Warn, + level: "WARN", + }, + { + name: "error", + logFunc: log.Error, + level: "ERROR", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + buf.Reset() + + tt.logFunc("test message", "key", "value") + + var logEntry map[string]interface{} + err := json.Unmarshal(buf.Bytes(), &logEntry) + require.NoError(t, err) + + assert.Equal(t, "test message", logEntry["msg"]) + assert.Equal(t, tt.level, logEntry["level"]) + assert.Equal(t, "value", logEntry["key"]) + }) + } +} + +func TestGetLoggerFromContext(t *testing.T) { + tests := []struct { + name string + setupContext func() context.Context + expectDefault bool + }{ + { + name: "context with logger", + setupContext: func() context.Context { + logger := log.New(&log.Config{ + Level: log.LevelInfo, + Format: "json", + Writer: &bytes.Buffer{}, + }) + return log.ContextWithLogger(context.Background(), logger) + }, + expectDefault: false, + }, + { + name: "context without logger", + setupContext: func() context.Context { + return context.Background() + }, + expectDefault: true, + }, + { + name: "nil context", + setupContext: func() context.Context { + return nil + }, + expectDefault: true, + }, + { + name: "context with wrong type", + setupContext: func() context.Context { + type wrongKey struct{} + return context.WithValue(context.Background(), wrongKey{}, "not a logger") + }, + expectDefault: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := tt.setupContext() + logger := log.GetLoggerFromContext(ctx) + + assert.NotNil(t, logger) + if tt.expectDefault { + assert.Equal(t, log.Default(), logger) + } else { + // Should be the logger from context, not default + assert.NotNil(t, logger) + } + }) + } +} + +func TestContextWithLogger(t *testing.T) { + logger := log.New(&log.Config{ + Level: log.LevelInfo, + Format: "json", + Writer: &bytes.Buffer{}, + }) + + ctx := log.ContextWithLogger(context.Background(), logger) + + // Retrieve logger from context + retrievedLogger := log.GetLoggerFromContext(ctx) + assert.Equal(t, logger, retrievedLogger) +} + +func TestLogger_WithFieldsEdgeCases(t *testing.T) { + var buf bytes.Buffer + logger := log.New(&log.Config{ + Level: log.LevelInfo, + Format: "json", + Writer: &buf, + }) + + // Test with empty fields map + emptyFields := map[string]interface{}{} + loggerWithEmpty := logger.WithFields(emptyFields) + loggerWithEmpty.Info("test empty fields") + + var logEntry map[string]interface{} + err := json.Unmarshal(buf.Bytes(), &logEntry) + require.NoError(t, err) + assert.Equal(t, "test empty fields", logEntry["msg"]) + + // Test with nil values + buf.Reset() + nilFields := map[string]interface{}{ + "nilField": nil, + "strField": "value", + } + loggerWithNil := logger.WithFields(nilFields) + loggerWithNil.Info("test nil fields") + + err = json.Unmarshal(buf.Bytes(), &logEntry) + require.NoError(t, err) + assert.Equal(t, "test nil fields", logEntry["msg"]) + assert.Nil(t, logEntry["nilField"]) + assert.Equal(t, "value", logEntry["strField"]) +} + +func TestLogger_OutputFormats(t *testing.T) { + tests := []struct { + name string + format string + checkOutput func(t *testing.T, output string) + }{ + { + name: "json format with special characters", + format: "json", + checkOutput: func(t *testing.T, output string) { + var logEntry map[string]interface{} + err := json.Unmarshal([]byte(output), &logEntry) + require.NoError(t, err) + assert.Equal(t, "test \"quoted\" message", logEntry["msg"]) + assert.Equal(t, "value with\nnewline", logEntry["special"]) + }, + }, + { + name: "text format with special characters", + format: "text", + checkOutput: func(t *testing.T, output string) { + assert.Contains(t, output, "test \\\"quoted\\\" message") + assert.Contains(t, output, "special=\"value with\\nnewline\"") + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var buf bytes.Buffer + logger := log.New(&log.Config{ + Level: log.LevelInfo, + Format: tt.format, + Writer: &buf, + }) + + logger.Info("test \"quoted\" message", "special", "value with\nnewline") + + tt.checkOutput(t, strings.TrimSpace(buf.String())) + }) + } +} + +func TestLogger_ConcurrentUse(t *testing.T) { + var buf bytes.Buffer + logger := log.New(&log.Config{ + Level: log.LevelInfo, + Format: "json", + Writer: &buf, + }) + + done := make(chan bool, 10) + + // Launch multiple goroutines logging concurrently + for i := 0; i < 10; i++ { + go func(id int) { + logger.WithField("goroutine", id).Info("concurrent log") + done <- true + }(i) + } + + // Wait for all goroutines to complete + for i := 0; i < 10; i++ { + <-done + } + + // Verify all logs were written + lines := strings.Split(strings.TrimSpace(buf.String()), "\n") + assert.Len(t, lines, 10) + + // Verify each line is valid JSON + for _, line := range lines { + var logEntry map[string]interface{} + err := json.Unmarshal([]byte(line), &logEntry) + require.NoError(t, err) + assert.Equal(t, "concurrent log", logEntry["msg"]) + assert.Contains(t, logEntry, "goroutine") + } +} \ No newline at end of file