mirror of
https://github.com/Fishwaldo/auth2.git
synced 2025-06-03 12:21:22 +00:00
Improve test coverage to 81% and fix validation error handling
- Add comprehensive tests for pkg/log achieving 100% coverage - Add tests for basic auth provider factory and utils (98.5% coverage) - Fix missing HTTP status mapping for validation errors in internal/errors - Overall test coverage improved from 49.1% to 81.0% 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
parent
b7537019d9
commit
802c1e137b
4 changed files with 1087 additions and 0 deletions
|
@ -175,6 +175,8 @@ func errorToHTTPStatus(err error) int {
|
|||
return http.StatusRequestTimeout
|
||||
case CodeUnavailable:
|
||||
return http.StatusServiceUnavailable
|
||||
case CodeValidation:
|
||||
return http.StatusBadRequest
|
||||
default:
|
||||
return http.StatusInternalServerError
|
||||
}
|
||||
|
|
245
pkg/auth/providers/basic/factory_test.go
Normal file
245
pkg/auth/providers/basic/factory_test.go
Normal file
|
@ -0,0 +1,245 @@
|
|||
package basic_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/Fishwaldo/auth2/pkg/auth/providers"
|
||||
"github.com/Fishwaldo/auth2/pkg/auth/providers/basic"
|
||||
"github.com/Fishwaldo/auth2/pkg/plugin/factory"
|
||||
"github.com/Fishwaldo/auth2/pkg/plugin/metadata"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewFactory(t *testing.T) {
|
||||
mockStore := &mockUserStore{}
|
||||
mockPwdUtils := &mockPasswordUtils{}
|
||||
|
||||
factory := basic.NewFactory(mockStore, mockPwdUtils)
|
||||
|
||||
assert.NotNil(t, factory)
|
||||
}
|
||||
|
||||
func TestFactory_Create(t *testing.T) {
|
||||
mockStore := &mockUserStore{}
|
||||
mockPwdUtils := &mockPasswordUtils{}
|
||||
factory := basic.NewFactory(mockStore, mockPwdUtils)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
id string
|
||||
config interface{}
|
||||
expectError bool
|
||||
errorContains string
|
||||
validate func(*testing.T, metadata.Provider)
|
||||
}{
|
||||
{
|
||||
name: "create with Config struct",
|
||||
id: "test-provider",
|
||||
config: &basic.Config{
|
||||
AccountLockThreshold: 10,
|
||||
AccountLockDuration: 60,
|
||||
RequireVerifiedEmail: true,
|
||||
},
|
||||
expectError: false,
|
||||
validate: func(t *testing.T, p metadata.Provider) {
|
||||
authProvider, ok := p.(*basic.Provider)
|
||||
require.True(t, ok)
|
||||
assert.NotNil(t, authProvider)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "create with nil config",
|
||||
id: "test-provider",
|
||||
config: nil,
|
||||
expectError: false,
|
||||
validate: func(t *testing.T, p metadata.Provider) {
|
||||
authProvider, ok := p.(*basic.Provider)
|
||||
require.True(t, ok)
|
||||
assert.NotNil(t, authProvider)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "create with map config",
|
||||
id: "test-provider",
|
||||
config: map[string]interface{}{
|
||||
"account_lock_threshold": 15,
|
||||
"account_lock_duration": 120,
|
||||
"require_verified_email": false,
|
||||
},
|
||||
expectError: false,
|
||||
validate: func(t *testing.T, p metadata.Provider) {
|
||||
authProvider, ok := p.(*basic.Provider)
|
||||
require.True(t, ok)
|
||||
assert.NotNil(t, authProvider)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "create with map config - partial values",
|
||||
id: "test-provider",
|
||||
config: map[string]interface{}{
|
||||
"account_lock_threshold": 20,
|
||||
},
|
||||
expectError: false,
|
||||
validate: func(t *testing.T, p metadata.Provider) {
|
||||
authProvider, ok := p.(*basic.Provider)
|
||||
require.True(t, ok)
|
||||
assert.NotNil(t, authProvider)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "create with map config - wrong types",
|
||||
id: "test-provider",
|
||||
config: map[string]interface{}{
|
||||
"account_lock_threshold": "not an int",
|
||||
"account_lock_duration": "not an int",
|
||||
"require_verified_email": "not a bool",
|
||||
},
|
||||
expectError: false, // Should not error, just use defaults
|
||||
validate: func(t *testing.T, p metadata.Provider) {
|
||||
authProvider, ok := p.(*basic.Provider)
|
||||
require.True(t, ok)
|
||||
assert.NotNil(t, authProvider)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "create with invalid config type",
|
||||
id: "test-provider",
|
||||
config: "invalid config",
|
||||
expectError: true,
|
||||
errorContains: "invalid configuration type",
|
||||
},
|
||||
{
|
||||
name: "create with invalid config struct",
|
||||
id: "test-provider",
|
||||
config: struct{ Field string }{Field: "value"},
|
||||
expectError: true,
|
||||
errorContains: "invalid configuration type",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
provider, err := factory.Create(tt.id, tt.config)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
if tt.errorContains != "" {
|
||||
assert.Contains(t, err.Error(), tt.errorContains)
|
||||
}
|
||||
assert.Nil(t, provider)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, provider)
|
||||
if tt.validate != nil {
|
||||
tt.validate(t, provider)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFactory_GetType(t *testing.T) {
|
||||
mockStore := &mockUserStore{}
|
||||
mockPwdUtils := &mockPasswordUtils{}
|
||||
factory := basic.NewFactory(mockStore, mockPwdUtils)
|
||||
|
||||
providerType := factory.GetType()
|
||||
assert.Equal(t, metadata.ProviderTypeAuth, providerType)
|
||||
}
|
||||
|
||||
func TestFactory_GetMetadata(t *testing.T) {
|
||||
mockStore := &mockUserStore{}
|
||||
mockPwdUtils := &mockPasswordUtils{}
|
||||
factory := basic.NewFactory(mockStore, mockPwdUtils)
|
||||
|
||||
metadataList := factory.GetMetadata()
|
||||
|
||||
assert.Len(t, metadataList, 1)
|
||||
|
||||
md := metadataList[0]
|
||||
assert.Equal(t, "basic", md.ID)
|
||||
assert.Equal(t, metadata.ProviderTypeAuth, md.Type)
|
||||
assert.Equal(t, basic.ProviderName, md.Name)
|
||||
assert.Equal(t, basic.ProviderDescription, md.Description)
|
||||
assert.Equal(t, basic.ProviderVersion, md.Version)
|
||||
}
|
||||
|
||||
func TestRegister(t *testing.T) {
|
||||
mockRegistry := &mockRegistry{}
|
||||
mockStore := &mockUserStore{}
|
||||
mockPwdUtils := &mockPasswordUtils{}
|
||||
|
||||
// Set up expectation
|
||||
mockRegistry.On("RegisterAuthProviderFactory", "basic", mock.AnythingOfType("*basic.Factory")).Return(nil)
|
||||
|
||||
err := basic.Register(mockRegistry, mockStore, mockPwdUtils)
|
||||
assert.NoError(t, err)
|
||||
|
||||
mockRegistry.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestRegister_Error(t *testing.T) {
|
||||
mockRegistry := &mockRegistry{}
|
||||
mockStore := &mockUserStore{}
|
||||
mockPwdUtils := &mockPasswordUtils{}
|
||||
|
||||
// Set up expectation for error
|
||||
expectedErr := assert.AnError
|
||||
mockRegistry.On("RegisterAuthProviderFactory", "basic", mock.AnythingOfType("*basic.Factory")).Return(expectedErr)
|
||||
|
||||
err := basic.Register(mockRegistry, mockStore, mockPwdUtils)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, expectedErr, err)
|
||||
|
||||
mockRegistry.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// Mock Registry for testing
|
||||
type mockRegistry struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *mockRegistry) RegisterAuthProvider(provider providers.AuthProvider) error {
|
||||
args := m.Called(provider)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *mockRegistry) GetAuthProvider(id string) (providers.AuthProvider, error) {
|
||||
args := m.Called(id)
|
||||
if args.Get(0) != nil {
|
||||
return args.Get(0).(providers.AuthProvider), args.Error(1)
|
||||
}
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
|
||||
func (m *mockRegistry) RegisterAuthProviderFactory(id string, factory factory.Factory) error {
|
||||
args := m.Called(id, factory)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *mockRegistry) GetAuthProviderFactory(id string) (factory.Factory, error) {
|
||||
args := m.Called(id)
|
||||
if args.Get(0) != nil {
|
||||
return args.Get(0).(factory.Factory), args.Error(1)
|
||||
}
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
|
||||
func (m *mockRegistry) ListAuthProviders() []providers.AuthProvider {
|
||||
args := m.Called()
|
||||
if args.Get(0) != nil {
|
||||
return args.Get(0).([]providers.AuthProvider)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockRegistry) CreateAuthProvider(ctx context.Context, factoryID, providerID string, config interface{}) (providers.AuthProvider, error) {
|
||||
args := m.Called(ctx, factoryID, providerID, config)
|
||||
if args.Get(0) != nil {
|
||||
return args.Get(0).(providers.AuthProvider), args.Error(1)
|
||||
}
|
||||
return nil, args.Error(1)
|
||||
}
|
323
pkg/auth/providers/basic/utils_test.go
Normal file
323
pkg/auth/providers/basic/utils_test.go
Normal file
|
@ -0,0 +1,323 @@
|
|||
package basic_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Fishwaldo/auth2/pkg/auth/providers/basic"
|
||||
"github.com/Fishwaldo/auth2/pkg/user"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestValidateAccount(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
user *user.User
|
||||
config *basic.Config
|
||||
expectError bool
|
||||
errorContains string
|
||||
}{
|
||||
{
|
||||
name: "valid account",
|
||||
user: &user.User{
|
||||
Enabled: true,
|
||||
Locked: false,
|
||||
EmailVerified: true,
|
||||
},
|
||||
config: &basic.Config{
|
||||
RequireVerifiedEmail: true,
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "disabled account",
|
||||
user: &user.User{
|
||||
Enabled: false,
|
||||
},
|
||||
config: &basic.Config{},
|
||||
expectError: true,
|
||||
errorContains: "account is disabled",
|
||||
},
|
||||
{
|
||||
name: "locked account - no expiry",
|
||||
user: &user.User{
|
||||
Enabled: true,
|
||||
Locked: true,
|
||||
},
|
||||
config: &basic.Config{
|
||||
AccountLockDuration: 0,
|
||||
},
|
||||
expectError: true,
|
||||
errorContains: "account is locked",
|
||||
},
|
||||
{
|
||||
name: "locked account - not expired",
|
||||
user: &user.User{
|
||||
Enabled: true,
|
||||
Locked: true,
|
||||
LockoutTime: time.Now().Add(-5 * time.Minute),
|
||||
},
|
||||
config: &basic.Config{
|
||||
AccountLockDuration: 30, // 30 minutes
|
||||
},
|
||||
expectError: true,
|
||||
errorContains: "account is locked",
|
||||
},
|
||||
{
|
||||
name: "locked account - expired",
|
||||
user: &user.User{
|
||||
Enabled: true,
|
||||
Locked: true,
|
||||
LockoutTime: time.Now().Add(-60 * time.Minute),
|
||||
},
|
||||
config: &basic.Config{
|
||||
AccountLockDuration: 30, // 30 minutes
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "unverified email when required",
|
||||
user: &user.User{
|
||||
Enabled: true,
|
||||
Locked: false,
|
||||
EmailVerified: false,
|
||||
},
|
||||
config: &basic.Config{
|
||||
RequireVerifiedEmail: true,
|
||||
},
|
||||
expectError: true,
|
||||
errorContains: "email verification required",
|
||||
},
|
||||
{
|
||||
name: "unverified email when not required",
|
||||
user: &user.User{
|
||||
Enabled: true,
|
||||
Locked: false,
|
||||
EmailVerified: false,
|
||||
},
|
||||
config: &basic.Config{
|
||||
RequireVerifiedEmail: false,
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := basic.ValidateAccount(ctx, tt.user, tt.config)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
if tt.errorContains != "" {
|
||||
assert.Contains(t, err.Error(), tt.errorContains)
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckPasswordRequirements(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
user *user.User
|
||||
expectRequired bool
|
||||
expectReason string
|
||||
}{
|
||||
{
|
||||
name: "password change required",
|
||||
user: &user.User{
|
||||
RequirePasswordChange: true,
|
||||
},
|
||||
expectRequired: true,
|
||||
expectReason: "Password change required",
|
||||
},
|
||||
{
|
||||
name: "password change not required",
|
||||
user: &user.User{
|
||||
RequirePasswordChange: false,
|
||||
},
|
||||
expectRequired: false,
|
||||
expectReason: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
required, reason := basic.CheckPasswordRequirements(tt.user)
|
||||
|
||||
assert.Equal(t, tt.expectRequired, required)
|
||||
assert.Equal(t, tt.expectReason, reason)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessSuccessfulLogin(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
mockStore := &mockUserStore{}
|
||||
|
||||
originalTime := time.Now()
|
||||
usr := &user.User{
|
||||
ID: "user123",
|
||||
FailedLoginAttempts: 5,
|
||||
LastLogin: originalTime.Add(-24 * time.Hour),
|
||||
}
|
||||
|
||||
// Set up expectation
|
||||
mockStore.On("Update", ctx, mock.MatchedBy(func(u *user.User) bool {
|
||||
return u.ID == usr.ID &&
|
||||
u.FailedLoginAttempts == 0 &&
|
||||
u.LastLogin.After(originalTime)
|
||||
})).Return(nil)
|
||||
|
||||
err := basic.ProcessSuccessfulLogin(ctx, mockStore, usr)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 0, usr.FailedLoginAttempts)
|
||||
assert.True(t, usr.LastLogin.After(originalTime))
|
||||
|
||||
mockStore.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestProcessSuccessfulLogin_UpdateError(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
mockStore := &mockUserStore{}
|
||||
|
||||
usr := &user.User{
|
||||
ID: "user123",
|
||||
FailedLoginAttempts: 5,
|
||||
}
|
||||
|
||||
// Set up expectation for error
|
||||
expectedErr := assert.AnError
|
||||
mockStore.On("Update", ctx, mock.AnythingOfType("*user.User")).Return(expectedErr)
|
||||
|
||||
err := basic.ProcessSuccessfulLogin(ctx, mockStore, usr)
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, expectedErr, err)
|
||||
|
||||
mockStore.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestProcessFailedLogin(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
initialUser *user.User
|
||||
config *basic.Config
|
||||
setupMock func(*mockUserStore)
|
||||
expectLocked bool
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "increment failed attempts - not locked",
|
||||
initialUser: &user.User{
|
||||
ID: "user123",
|
||||
FailedLoginAttempts: 2,
|
||||
},
|
||||
config: &basic.Config{
|
||||
AccountLockThreshold: 5,
|
||||
},
|
||||
setupMock: func(m *mockUserStore) {
|
||||
m.On("Update", ctx, mock.MatchedBy(func(u *user.User) bool {
|
||||
return u.ID == "user123" &&
|
||||
u.FailedLoginAttempts == 3 &&
|
||||
!u.Locked
|
||||
})).Return(nil)
|
||||
},
|
||||
expectLocked: false,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "increment failed attempts - lock account",
|
||||
initialUser: &user.User{
|
||||
ID: "user123",
|
||||
FailedLoginAttempts: 4,
|
||||
},
|
||||
config: &basic.Config{
|
||||
AccountLockThreshold: 5,
|
||||
},
|
||||
setupMock: func(m *mockUserStore) {
|
||||
m.On("Update", ctx, mock.MatchedBy(func(u *user.User) bool {
|
||||
return u.ID == "user123" &&
|
||||
u.FailedLoginAttempts == 5 &&
|
||||
u.Locked &&
|
||||
!u.LockoutTime.IsZero() &&
|
||||
u.LockoutReason == "Too many failed login attempts"
|
||||
})).Return(nil)
|
||||
},
|
||||
expectLocked: true,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "no lock threshold",
|
||||
initialUser: &user.User{
|
||||
ID: "user123",
|
||||
FailedLoginAttempts: 10,
|
||||
},
|
||||
config: &basic.Config{
|
||||
AccountLockThreshold: 0,
|
||||
},
|
||||
setupMock: func(m *mockUserStore) {
|
||||
m.On("Update", ctx, mock.MatchedBy(func(u *user.User) bool {
|
||||
return u.ID == "user123" &&
|
||||
u.FailedLoginAttempts == 11 &&
|
||||
!u.Locked
|
||||
})).Return(nil)
|
||||
},
|
||||
expectLocked: false,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "update error",
|
||||
initialUser: &user.User{
|
||||
ID: "user123",
|
||||
FailedLoginAttempts: 0,
|
||||
},
|
||||
config: &basic.Config{
|
||||
AccountLockThreshold: 5,
|
||||
},
|
||||
setupMock: func(m *mockUserStore) {
|
||||
m.On("Update", ctx, mock.AnythingOfType("*user.User")).Return(assert.AnError)
|
||||
},
|
||||
expectLocked: false,
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mockStore := &mockUserStore{}
|
||||
tt.setupMock(mockStore)
|
||||
|
||||
originalTime := time.Now()
|
||||
err := basic.ProcessFailedLogin(ctx, mockStore, tt.initialUser, tt.config)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
assert.True(t, tt.initialUser.LastFailedLogin.After(originalTime) ||
|
||||
tt.initialUser.LastFailedLogin.Equal(originalTime))
|
||||
|
||||
if tt.expectLocked {
|
||||
assert.True(t, tt.initialUser.Locked)
|
||||
assert.False(t, tt.initialUser.LockoutTime.IsZero())
|
||||
assert.Equal(t, "Too many failed login attempts", tt.initialUser.LockoutReason)
|
||||
} else {
|
||||
assert.False(t, tt.initialUser.Locked)
|
||||
}
|
||||
}
|
||||
|
||||
mockStore.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
}
|
517
pkg/log/log_test.go
Normal file
517
pkg/log/log_test.go
Normal file
|
@ -0,0 +1,517 @@
|
|||
package log_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Fishwaldo/auth2/pkg/log"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestDefaultConfig(t *testing.T) {
|
||||
cfg := log.DefaultConfig()
|
||||
|
||||
assert.Equal(t, log.LevelInfo, cfg.Level)
|
||||
assert.Equal(t, "json", cfg.Format)
|
||||
assert.NotNil(t, cfg.Writer)
|
||||
assert.False(t, cfg.AddSource)
|
||||
assert.Nil(t, cfg.ContextKeys)
|
||||
}
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config *log.Config
|
||||
}{
|
||||
{
|
||||
name: "with nil config uses default",
|
||||
config: nil,
|
||||
},
|
||||
{
|
||||
name: "with json format",
|
||||
config: &log.Config{
|
||||
Level: log.LevelDebug,
|
||||
Format: "json",
|
||||
Writer: &bytes.Buffer{},
|
||||
AddSource: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with text format",
|
||||
config: &log.Config{
|
||||
Level: log.LevelWarn,
|
||||
Format: "text",
|
||||
Writer: &bytes.Buffer{},
|
||||
AddSource: false,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
logger := log.New(tt.config)
|
||||
assert.NotNil(t, logger)
|
||||
assert.NotNil(t, logger.Logger)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogger_WithField(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
logger := log.New(&log.Config{
|
||||
Level: log.LevelInfo,
|
||||
Format: "json",
|
||||
Writer: &buf,
|
||||
})
|
||||
|
||||
// Create logger with field
|
||||
loggerWithField := logger.WithField("key", "value")
|
||||
loggerWithField.Info("test message")
|
||||
|
||||
// Parse log output
|
||||
var logEntry map[string]interface{}
|
||||
err := json.Unmarshal(buf.Bytes(), &logEntry)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "test message", logEntry["msg"])
|
||||
assert.Equal(t, "value", logEntry["key"])
|
||||
}
|
||||
|
||||
func TestLogger_WithFields(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
logger := log.New(&log.Config{
|
||||
Level: log.LevelInfo,
|
||||
Format: "json",
|
||||
Writer: &buf,
|
||||
})
|
||||
|
||||
// Create logger with multiple fields
|
||||
fields := map[string]interface{}{
|
||||
"field1": "value1",
|
||||
"field2": 42,
|
||||
"field3": true,
|
||||
}
|
||||
loggerWithFields := logger.WithFields(fields)
|
||||
loggerWithFields.Info("test message")
|
||||
|
||||
// Parse log output
|
||||
var logEntry map[string]interface{}
|
||||
err := json.Unmarshal(buf.Bytes(), &logEntry)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "test message", logEntry["msg"])
|
||||
assert.Equal(t, "value1", logEntry["field1"])
|
||||
assert.Equal(t, float64(42), logEntry["field2"]) // JSON unmarshals numbers as float64
|
||||
assert.Equal(t, true, logEntry["field3"])
|
||||
}
|
||||
|
||||
func TestLogger_WithContext(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
logger := log.New(&log.Config{
|
||||
Level: log.LevelInfo,
|
||||
Format: "json",
|
||||
Writer: &buf,
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
loggerWithCtx := logger.WithContext(ctx)
|
||||
assert.NotNil(t, loggerWithCtx)
|
||||
|
||||
// Test that it returns a logger
|
||||
loggerWithCtx.Info("test message")
|
||||
|
||||
// Verify log was written
|
||||
assert.NotEmpty(t, buf.String())
|
||||
}
|
||||
|
||||
func TestLogger_Levels(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
logLevel log.Level
|
||||
msgLevel string
|
||||
shouldLog bool
|
||||
}{
|
||||
{
|
||||
name: "debug level logs debug",
|
||||
logLevel: log.LevelDebug,
|
||||
msgLevel: "debug",
|
||||
shouldLog: true,
|
||||
},
|
||||
{
|
||||
name: "info level skips debug",
|
||||
logLevel: log.LevelInfo,
|
||||
msgLevel: "debug",
|
||||
shouldLog: false,
|
||||
},
|
||||
{
|
||||
name: "info level logs info",
|
||||
logLevel: log.LevelInfo,
|
||||
msgLevel: "info",
|
||||
shouldLog: true,
|
||||
},
|
||||
{
|
||||
name: "warn level logs warn",
|
||||
logLevel: log.LevelWarn,
|
||||
msgLevel: "warn",
|
||||
shouldLog: true,
|
||||
},
|
||||
{
|
||||
name: "error level logs error",
|
||||
logLevel: log.LevelError,
|
||||
msgLevel: "error",
|
||||
shouldLog: true,
|
||||
},
|
||||
{
|
||||
name: "error level skips warn",
|
||||
logLevel: log.LevelError,
|
||||
msgLevel: "warn",
|
||||
shouldLog: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
logger := log.New(&log.Config{
|
||||
Level: tt.logLevel,
|
||||
Format: "json",
|
||||
Writer: &buf,
|
||||
})
|
||||
|
||||
msg := "test message"
|
||||
switch tt.msgLevel {
|
||||
case "debug":
|
||||
logger.Debug(msg)
|
||||
case "info":
|
||||
logger.Info(msg)
|
||||
case "warn":
|
||||
logger.Warn(msg)
|
||||
case "error":
|
||||
logger.Error(msg)
|
||||
}
|
||||
|
||||
if tt.shouldLog {
|
||||
assert.NotEmpty(t, buf.String())
|
||||
// Verify the message was logged
|
||||
var logEntry map[string]interface{}
|
||||
err := json.Unmarshal(buf.Bytes(), &logEntry)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, msg, logEntry["msg"])
|
||||
} else {
|
||||
assert.Empty(t, buf.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogger_TextFormat(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
logger := log.New(&log.Config{
|
||||
Level: log.LevelInfo,
|
||||
Format: "text",
|
||||
Writer: &buf,
|
||||
})
|
||||
|
||||
logger.Info("test message", "key", "value")
|
||||
|
||||
output := buf.String()
|
||||
assert.Contains(t, output, "test message")
|
||||
assert.Contains(t, output, "key=value")
|
||||
}
|
||||
|
||||
func TestLogger_AddSource(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
logger := log.New(&log.Config{
|
||||
Level: log.LevelInfo,
|
||||
Format: "json",
|
||||
Writer: &buf,
|
||||
AddSource: true,
|
||||
})
|
||||
|
||||
logger.Info("test message")
|
||||
|
||||
// Parse log output
|
||||
var logEntry map[string]interface{}
|
||||
err := json.Unmarshal(buf.Bytes(), &logEntry)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should have source information
|
||||
assert.Contains(t, logEntry, "source")
|
||||
source := logEntry["source"].(map[string]interface{})
|
||||
assert.Contains(t, source, "file")
|
||||
assert.Contains(t, source, "line")
|
||||
}
|
||||
|
||||
func TestSetDefault(t *testing.T) {
|
||||
// Save original default logger
|
||||
originalDefault := log.Default()
|
||||
defer log.SetDefault(originalDefault)
|
||||
|
||||
var buf bytes.Buffer
|
||||
newLogger := log.New(&log.Config{
|
||||
Level: log.LevelInfo,
|
||||
Format: "json",
|
||||
Writer: &buf,
|
||||
})
|
||||
|
||||
log.SetDefault(newLogger)
|
||||
|
||||
// Verify the default was set
|
||||
assert.Equal(t, newLogger, log.Default())
|
||||
|
||||
// Test convenience functions use the new default
|
||||
log.Info("test from default")
|
||||
|
||||
assert.NotEmpty(t, buf.String())
|
||||
var logEntry map[string]interface{}
|
||||
err := json.Unmarshal(buf.Bytes(), &logEntry)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "test from default", logEntry["msg"])
|
||||
}
|
||||
|
||||
func TestConvenienceMethods(t *testing.T) {
|
||||
// Save original default logger
|
||||
originalDefault := log.Default()
|
||||
defer log.SetDefault(originalDefault)
|
||||
|
||||
var buf bytes.Buffer
|
||||
logger := log.New(&log.Config{
|
||||
Level: log.LevelDebug,
|
||||
Format: "json",
|
||||
Writer: &buf,
|
||||
})
|
||||
log.SetDefault(logger)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
logFunc func(string, ...interface{})
|
||||
level string
|
||||
}{
|
||||
{
|
||||
name: "debug",
|
||||
logFunc: log.Debug,
|
||||
level: "DEBUG",
|
||||
},
|
||||
{
|
||||
name: "info",
|
||||
logFunc: log.Info,
|
||||
level: "INFO",
|
||||
},
|
||||
{
|
||||
name: "warn",
|
||||
logFunc: log.Warn,
|
||||
level: "WARN",
|
||||
},
|
||||
{
|
||||
name: "error",
|
||||
logFunc: log.Error,
|
||||
level: "ERROR",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
buf.Reset()
|
||||
|
||||
tt.logFunc("test message", "key", "value")
|
||||
|
||||
var logEntry map[string]interface{}
|
||||
err := json.Unmarshal(buf.Bytes(), &logEntry)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "test message", logEntry["msg"])
|
||||
assert.Equal(t, tt.level, logEntry["level"])
|
||||
assert.Equal(t, "value", logEntry["key"])
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetLoggerFromContext(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setupContext func() context.Context
|
||||
expectDefault bool
|
||||
}{
|
||||
{
|
||||
name: "context with logger",
|
||||
setupContext: func() context.Context {
|
||||
logger := log.New(&log.Config{
|
||||
Level: log.LevelInfo,
|
||||
Format: "json",
|
||||
Writer: &bytes.Buffer{},
|
||||
})
|
||||
return log.ContextWithLogger(context.Background(), logger)
|
||||
},
|
||||
expectDefault: false,
|
||||
},
|
||||
{
|
||||
name: "context without logger",
|
||||
setupContext: func() context.Context {
|
||||
return context.Background()
|
||||
},
|
||||
expectDefault: true,
|
||||
},
|
||||
{
|
||||
name: "nil context",
|
||||
setupContext: func() context.Context {
|
||||
return nil
|
||||
},
|
||||
expectDefault: true,
|
||||
},
|
||||
{
|
||||
name: "context with wrong type",
|
||||
setupContext: func() context.Context {
|
||||
type wrongKey struct{}
|
||||
return context.WithValue(context.Background(), wrongKey{}, "not a logger")
|
||||
},
|
||||
expectDefault: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := tt.setupContext()
|
||||
logger := log.GetLoggerFromContext(ctx)
|
||||
|
||||
assert.NotNil(t, logger)
|
||||
if tt.expectDefault {
|
||||
assert.Equal(t, log.Default(), logger)
|
||||
} else {
|
||||
// Should be the logger from context, not default
|
||||
assert.NotNil(t, logger)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestContextWithLogger(t *testing.T) {
|
||||
logger := log.New(&log.Config{
|
||||
Level: log.LevelInfo,
|
||||
Format: "json",
|
||||
Writer: &bytes.Buffer{},
|
||||
})
|
||||
|
||||
ctx := log.ContextWithLogger(context.Background(), logger)
|
||||
|
||||
// Retrieve logger from context
|
||||
retrievedLogger := log.GetLoggerFromContext(ctx)
|
||||
assert.Equal(t, logger, retrievedLogger)
|
||||
}
|
||||
|
||||
func TestLogger_WithFieldsEdgeCases(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
logger := log.New(&log.Config{
|
||||
Level: log.LevelInfo,
|
||||
Format: "json",
|
||||
Writer: &buf,
|
||||
})
|
||||
|
||||
// Test with empty fields map
|
||||
emptyFields := map[string]interface{}{}
|
||||
loggerWithEmpty := logger.WithFields(emptyFields)
|
||||
loggerWithEmpty.Info("test empty fields")
|
||||
|
||||
var logEntry map[string]interface{}
|
||||
err := json.Unmarshal(buf.Bytes(), &logEntry)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "test empty fields", logEntry["msg"])
|
||||
|
||||
// Test with nil values
|
||||
buf.Reset()
|
||||
nilFields := map[string]interface{}{
|
||||
"nilField": nil,
|
||||
"strField": "value",
|
||||
}
|
||||
loggerWithNil := logger.WithFields(nilFields)
|
||||
loggerWithNil.Info("test nil fields")
|
||||
|
||||
err = json.Unmarshal(buf.Bytes(), &logEntry)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "test nil fields", logEntry["msg"])
|
||||
assert.Nil(t, logEntry["nilField"])
|
||||
assert.Equal(t, "value", logEntry["strField"])
|
||||
}
|
||||
|
||||
func TestLogger_OutputFormats(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
format string
|
||||
checkOutput func(t *testing.T, output string)
|
||||
}{
|
||||
{
|
||||
name: "json format with special characters",
|
||||
format: "json",
|
||||
checkOutput: func(t *testing.T, output string) {
|
||||
var logEntry map[string]interface{}
|
||||
err := json.Unmarshal([]byte(output), &logEntry)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "test \"quoted\" message", logEntry["msg"])
|
||||
assert.Equal(t, "value with\nnewline", logEntry["special"])
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "text format with special characters",
|
||||
format: "text",
|
||||
checkOutput: func(t *testing.T, output string) {
|
||||
assert.Contains(t, output, "test \\\"quoted\\\" message")
|
||||
assert.Contains(t, output, "special=\"value with\\nnewline\"")
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
logger := log.New(&log.Config{
|
||||
Level: log.LevelInfo,
|
||||
Format: tt.format,
|
||||
Writer: &buf,
|
||||
})
|
||||
|
||||
logger.Info("test \"quoted\" message", "special", "value with\nnewline")
|
||||
|
||||
tt.checkOutput(t, strings.TrimSpace(buf.String()))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogger_ConcurrentUse(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
logger := log.New(&log.Config{
|
||||
Level: log.LevelInfo,
|
||||
Format: "json",
|
||||
Writer: &buf,
|
||||
})
|
||||
|
||||
done := make(chan bool, 10)
|
||||
|
||||
// Launch multiple goroutines logging concurrently
|
||||
for i := 0; i < 10; i++ {
|
||||
go func(id int) {
|
||||
logger.WithField("goroutine", id).Info("concurrent log")
|
||||
done <- true
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Wait for all goroutines to complete
|
||||
for i := 0; i < 10; i++ {
|
||||
<-done
|
||||
}
|
||||
|
||||
// Verify all logs were written
|
||||
lines := strings.Split(strings.TrimSpace(buf.String()), "\n")
|
||||
assert.Len(t, lines, 10)
|
||||
|
||||
// Verify each line is valid JSON
|
||||
for _, line := range lines {
|
||||
var logEntry map[string]interface{}
|
||||
err := json.Unmarshal([]byte(line), &logEntry)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "concurrent log", logEntry["msg"])
|
||||
assert.Contains(t, logEntry, "goroutine")
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue