Merge branch 'feat/cache-hashed-api-key-storage' into 'main'
feat(cache): implement hashed API key storage and retrieval See merge request unboundsoftware/schemas!628
This commit was merged in pull request #632.
This commit is contained in:
Vendored
+26
-16
@@ -14,7 +14,7 @@ import (
|
||||
type Cache struct {
|
||||
organizations map[string]domain.Organization
|
||||
users map[string][]string
|
||||
apiKeys map[string]domain.APIKey
|
||||
apiKeys map[string]domain.APIKey // keyed by organizationId-name
|
||||
services map[string]map[string]map[string]struct{}
|
||||
subGraphs map[string]string
|
||||
lastUpdate map[string]string
|
||||
@@ -22,15 +22,17 @@ type Cache struct {
|
||||
}
|
||||
|
||||
func (c *Cache) OrganizationByAPIKey(apiKey string) *domain.Organization {
|
||||
key, exists := c.apiKeys[apiKey]
|
||||
if !exists {
|
||||
return nil
|
||||
// Find the API key by comparing hashes
|
||||
for _, key := range c.apiKeys {
|
||||
if hash.CompareAPIKey(key.Key, apiKey) {
|
||||
org, exists := c.organizations[key.OrganizationId]
|
||||
if !exists {
|
||||
return nil
|
||||
}
|
||||
return &org
|
||||
}
|
||||
}
|
||||
org, exists := c.organizations[key.OrganizationId]
|
||||
if !exists {
|
||||
return nil
|
||||
}
|
||||
return &org
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Cache) OrganizationsByUser(sub string) []domain.Organization {
|
||||
@@ -43,11 +45,13 @@ func (c *Cache) OrganizationsByUser(sub string) []domain.Organization {
|
||||
}
|
||||
|
||||
func (c *Cache) ApiKeyByKey(key string) *domain.APIKey {
|
||||
k, exists := c.apiKeys[hash.String(key)]
|
||||
if !exists {
|
||||
return nil
|
||||
// Find the API key by comparing hashes
|
||||
for _, apiKey := range c.apiKeys {
|
||||
if hash.CompareAPIKey(apiKey.Key, key) {
|
||||
return &apiKey
|
||||
}
|
||||
}
|
||||
return &k
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Cache) Services(orgId, ref, lastUpdate string) ([]string, string) {
|
||||
@@ -76,14 +80,15 @@ func (c *Cache) Update(msg any, _ goamqp.Headers) (any, error) {
|
||||
key := domain.APIKey{
|
||||
Name: m.Name,
|
||||
OrganizationId: m.OrganizationId,
|
||||
Key: m.Key,
|
||||
Key: m.Key, // This is now the hashed key
|
||||
Refs: m.Refs,
|
||||
Read: m.Read,
|
||||
Publish: m.Publish,
|
||||
CreatedBy: m.Initiator,
|
||||
CreatedAt: m.When(),
|
||||
}
|
||||
c.apiKeys[m.Key] = key
|
||||
// Use composite key: organizationId-name
|
||||
c.apiKeys[apiKeyId(m.OrganizationId, m.Name)] = key
|
||||
org := c.organizations[m.OrganizationId]
|
||||
org.APIKeys = append(org.APIKeys, key)
|
||||
c.organizations[m.OrganizationId] = org
|
||||
@@ -93,7 +98,8 @@ func (c *Cache) Update(msg any, _ goamqp.Headers) (any, error) {
|
||||
c.organizations[m.ID.String()] = *m
|
||||
c.addUser(m.CreatedBy, *m)
|
||||
for _, k := range m.APIKeys {
|
||||
c.apiKeys[k.Key] = k
|
||||
// Use composite key: organizationId-name
|
||||
c.apiKeys[apiKeyId(k.OrganizationId, k.Name)] = k
|
||||
}
|
||||
case *domain.SubGraph:
|
||||
c.updateSubGraph(m.OrganizationId, m.Ref, m.ID.String(), m.Service, m.ChangedAt)
|
||||
@@ -143,3 +149,7 @@ func refKey(orgId string, ref string) string {
|
||||
func subGraphKey(orgId string, ref string, service string) string {
|
||||
return fmt.Sprintf("%s<->%s<->%s", orgId, ref, service)
|
||||
}
|
||||
|
||||
func apiKeyId(orgId string, name string) string {
|
||||
return fmt.Sprintf("%s<->%s", orgId, name)
|
||||
}
|
||||
|
||||
@@ -30,7 +30,6 @@ import (
|
||||
"gitlab.com/unboundsoftware/schemas/domain"
|
||||
"gitlab.com/unboundsoftware/schemas/graph"
|
||||
"gitlab.com/unboundsoftware/schemas/graph/generated"
|
||||
"gitlab.com/unboundsoftware/schemas/hash"
|
||||
"gitlab.com/unboundsoftware/schemas/logging"
|
||||
"gitlab.com/unboundsoftware/schemas/middleware"
|
||||
"gitlab.com/unboundsoftware/schemas/monitoring"
|
||||
@@ -217,8 +216,8 @@ func start(closeEvents chan error, logger *slog.Logger, connectToAmqpFunc func(u
|
||||
logger.Info("WebSocket connection with API key", "has_key", true)
|
||||
ctx = context.WithValue(ctx, middleware.ApiKey, apiKey)
|
||||
|
||||
// Look up organization by API key (same logic as auth middleware)
|
||||
if organization := serviceCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil {
|
||||
// Look up organization by API key (cache handles hash comparison)
|
||||
if organization := serviceCache.OrganizationByAPIKey(apiKey); organization != nil {
|
||||
logger.Info("WebSocket: Organization found for API key", "org_id", organization.ID.String())
|
||||
ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization)
|
||||
} else {
|
||||
|
||||
+47
-21
@@ -17,11 +17,18 @@ import (
|
||||
|
||||
// MockCache is a mock implementation for testing
|
||||
type MockCache struct {
|
||||
organizations map[string]*domain.Organization
|
||||
organizations map[string]*domain.Organization // keyed by orgId-name composite
|
||||
apiKeys map[string]string // maps orgId-name to hashed key
|
||||
}
|
||||
|
||||
func (m *MockCache) OrganizationByAPIKey(apiKey string) *domain.Organization {
|
||||
return m.organizations[apiKey]
|
||||
func (m *MockCache) OrganizationByAPIKey(plainKey string) *domain.Organization {
|
||||
// Find organization by comparing plaintext key with stored hash
|
||||
for compositeKey, hashedKey := range m.apiKeys {
|
||||
if hash.CompareAPIKey(hashedKey, plainKey) {
|
||||
return m.organizations[compositeKey]
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestWebSocketInitFunc_WithValidAPIKey(t *testing.T) {
|
||||
@@ -35,11 +42,17 @@ func TestWebSocketInitFunc_WithValidAPIKey(t *testing.T) {
|
||||
}
|
||||
|
||||
apiKey := "test-api-key-123"
|
||||
hashedKey := hash.String(apiKey)
|
||||
hashedKey, err := hash.APIKey(apiKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
compositeKey := orgID.String() + "-test-key"
|
||||
|
||||
mockCache := &MockCache{
|
||||
organizations: map[string]*domain.Organization{
|
||||
hashedKey: org,
|
||||
compositeKey: org,
|
||||
},
|
||||
apiKeys: map[string]string{
|
||||
compositeKey: hashedKey,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -49,8 +62,8 @@ func TestWebSocketInitFunc_WithValidAPIKey(t *testing.T) {
|
||||
if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" {
|
||||
ctx = context.WithValue(ctx, middleware.ApiKey, apiKey)
|
||||
|
||||
// Look up organization by API key
|
||||
if organization := mockCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil {
|
||||
// Look up organization by API key (cache handles hash comparison)
|
||||
if organization := mockCache.OrganizationByAPIKey(apiKey); organization != nil {
|
||||
ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization)
|
||||
}
|
||||
}
|
||||
@@ -91,6 +104,7 @@ func TestWebSocketInitFunc_WithInvalidAPIKey(t *testing.T) {
|
||||
// Setup
|
||||
mockCache := &MockCache{
|
||||
organizations: map[string]*domain.Organization{},
|
||||
apiKeys: map[string]string{},
|
||||
}
|
||||
|
||||
apiKey := "invalid-api-key"
|
||||
@@ -101,8 +115,8 @@ func TestWebSocketInitFunc_WithInvalidAPIKey(t *testing.T) {
|
||||
if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" {
|
||||
ctx = context.WithValue(ctx, middleware.ApiKey, apiKey)
|
||||
|
||||
// Look up organization by API key
|
||||
if organization := mockCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil {
|
||||
// Look up organization by API key (cache handles hash comparison)
|
||||
if organization := mockCache.OrganizationByAPIKey(apiKey); organization != nil {
|
||||
ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization)
|
||||
}
|
||||
}
|
||||
@@ -137,6 +151,7 @@ func TestWebSocketInitFunc_WithoutAPIKey(t *testing.T) {
|
||||
// Setup
|
||||
mockCache := &MockCache{
|
||||
organizations: map[string]*domain.Organization{},
|
||||
apiKeys: map[string]string{},
|
||||
}
|
||||
|
||||
// Create InitFunc
|
||||
@@ -145,8 +160,8 @@ func TestWebSocketInitFunc_WithoutAPIKey(t *testing.T) {
|
||||
if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" {
|
||||
ctx = context.WithValue(ctx, middleware.ApiKey, apiKey)
|
||||
|
||||
// Look up organization by API key
|
||||
if organization := mockCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil {
|
||||
// Look up organization by API key (cache handles hash comparison)
|
||||
if organization := mockCache.OrganizationByAPIKey(apiKey); organization != nil {
|
||||
ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization)
|
||||
}
|
||||
}
|
||||
@@ -176,6 +191,7 @@ func TestWebSocketInitFunc_WithEmptyAPIKey(t *testing.T) {
|
||||
// Setup
|
||||
mockCache := &MockCache{
|
||||
organizations: map[string]*domain.Organization{},
|
||||
apiKeys: map[string]string{},
|
||||
}
|
||||
|
||||
// Create InitFunc
|
||||
@@ -184,8 +200,8 @@ func TestWebSocketInitFunc_WithEmptyAPIKey(t *testing.T) {
|
||||
if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" {
|
||||
ctx = context.WithValue(ctx, middleware.ApiKey, apiKey)
|
||||
|
||||
// Look up organization by API key
|
||||
if organization := mockCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil {
|
||||
// Look up organization by API key (cache handles hash comparison)
|
||||
if organization := mockCache.OrganizationByAPIKey(apiKey); organization != nil {
|
||||
ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization)
|
||||
}
|
||||
}
|
||||
@@ -217,6 +233,7 @@ func TestWebSocketInitFunc_WithWrongTypeAPIKey(t *testing.T) {
|
||||
// Setup
|
||||
mockCache := &MockCache{
|
||||
organizations: map[string]*domain.Organization{},
|
||||
apiKeys: map[string]string{},
|
||||
}
|
||||
|
||||
// Create InitFunc
|
||||
@@ -225,8 +242,8 @@ func TestWebSocketInitFunc_WithWrongTypeAPIKey(t *testing.T) {
|
||||
if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" {
|
||||
ctx = context.WithValue(ctx, middleware.ApiKey, apiKey)
|
||||
|
||||
// Look up organization by API key
|
||||
if organization := mockCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil {
|
||||
// Look up organization by API key (cache handles hash comparison)
|
||||
if organization := mockCache.OrganizationByAPIKey(apiKey); organization != nil {
|
||||
ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization)
|
||||
}
|
||||
}
|
||||
@@ -274,13 +291,22 @@ func TestWebSocketInitFunc_WithMultipleOrganizations(t *testing.T) {
|
||||
|
||||
apiKey1 := "api-key-org-1"
|
||||
apiKey2 := "api-key-org-2"
|
||||
hashedKey1 := hash.String(apiKey1)
|
||||
hashedKey2 := hash.String(apiKey2)
|
||||
hashedKey1, err := hash.APIKey(apiKey1)
|
||||
require.NoError(t, err)
|
||||
hashedKey2, err := hash.APIKey(apiKey2)
|
||||
require.NoError(t, err)
|
||||
|
||||
compositeKey1 := org1ID.String() + "-key1"
|
||||
compositeKey2 := org2ID.String() + "-key2"
|
||||
|
||||
mockCache := &MockCache{
|
||||
organizations: map[string]*domain.Organization{
|
||||
hashedKey1: org1,
|
||||
hashedKey2: org2,
|
||||
compositeKey1: org1,
|
||||
compositeKey2: org2,
|
||||
},
|
||||
apiKeys: map[string]string{
|
||||
compositeKey1: hashedKey1,
|
||||
compositeKey2: hashedKey2,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -290,8 +316,8 @@ func TestWebSocketInitFunc_WithMultipleOrganizations(t *testing.T) {
|
||||
if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" {
|
||||
ctx = context.WithValue(ctx, middleware.ApiKey, apiKey)
|
||||
|
||||
// Look up organization by API key
|
||||
if organization := mockCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil {
|
||||
// Look up organization by API key (cache handles hash comparison)
|
||||
if organization := mockCache.OrganizationByAPIKey(apiKey); organization != nil {
|
||||
ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization)
|
||||
}
|
||||
}
|
||||
|
||||
+12
-1
@@ -56,9 +56,20 @@ func (a AddAPIKey) Validate(_ context.Context, aggregate eventsourced.Aggregate)
|
||||
}
|
||||
|
||||
func (a AddAPIKey) Event(context.Context) eventsourced.Event {
|
||||
// Hash the API key using bcrypt for secure storage
|
||||
// Note: We can't return an error here, but bcrypt errors are extremely rare
|
||||
// (only if system runs out of memory or bcrypt cost is invalid)
|
||||
// We use a fixed cost of 12 which is always valid
|
||||
hashedKey, err := hash.APIKey(a.Key)
|
||||
if err != nil {
|
||||
// This should never happen with bcrypt cost 12, but if it does,
|
||||
// we'll store an empty hash which will fail validation later
|
||||
hashedKey = ""
|
||||
}
|
||||
|
||||
return &APIKeyAdded{
|
||||
Name: a.Name,
|
||||
Key: hash.String(a.Key),
|
||||
Key: hashedKey,
|
||||
Refs: a.Refs,
|
||||
Read: a.Read,
|
||||
Publish: a.Publish,
|
||||
|
||||
+24
-11
@@ -2,10 +2,13 @@ package domain
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"gitlab.com/unboundsoftware/eventsourced/eventsourced"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"gitlab.com/unboundsoftware/schemas/hash"
|
||||
)
|
||||
|
||||
func TestAddAPIKey_Event(t *testing.T) {
|
||||
@@ -24,7 +27,6 @@ func TestAddAPIKey_Event(t *testing.T) {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
want eventsourced.Event
|
||||
}{
|
||||
{
|
||||
name: "event",
|
||||
@@ -37,14 +39,6 @@ func TestAddAPIKey_Event(t *testing.T) {
|
||||
Initiator: "jim@example.org",
|
||||
},
|
||||
args: args{},
|
||||
want: &APIKeyAdded{
|
||||
Name: "test",
|
||||
Key: "dXNfYWtfMTIzNDU2Nzg5MDEyMzQ1NuOwxEKY/BwUmvv0yJlvuSQnrkHkZJuTTKSVmRt4UrhV",
|
||||
Refs: []string{"Example@dev"},
|
||||
Read: true,
|
||||
Publish: true,
|
||||
Initiator: "jim@example.org",
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
@@ -57,7 +51,26 @@ func TestAddAPIKey_Event(t *testing.T) {
|
||||
Publish: tt.fields.Publish,
|
||||
Initiator: tt.fields.Initiator,
|
||||
}
|
||||
assert.Equalf(t, tt.want, a.Event(tt.args.in0), "Event(%v)", tt.args.in0)
|
||||
event := a.Event(tt.args.in0)
|
||||
require.NotNil(t, event)
|
||||
|
||||
// Cast to APIKeyAdded to verify fields
|
||||
apiKeyEvent, ok := event.(*APIKeyAdded)
|
||||
require.True(t, ok, "Event should be *APIKeyAdded")
|
||||
|
||||
// Verify non-key fields match exactly
|
||||
assert.Equal(t, tt.fields.Name, apiKeyEvent.Name)
|
||||
assert.Equal(t, tt.fields.Refs, apiKeyEvent.Refs)
|
||||
assert.Equal(t, tt.fields.Read, apiKeyEvent.Read)
|
||||
assert.Equal(t, tt.fields.Publish, apiKeyEvent.Publish)
|
||||
assert.Equal(t, tt.fields.Initiator, apiKeyEvent.Initiator)
|
||||
|
||||
// Verify the key is hashed correctly (bcrypt format)
|
||||
assert.True(t, strings.HasPrefix(apiKeyEvent.Key, "$2"), "Key should be bcrypt hashed")
|
||||
assert.NotEqual(t, tt.fields.Key, apiKeyEvent.Key, "Key should be hashed, not plaintext")
|
||||
|
||||
// Verify the hash matches the original key
|
||||
assert.True(t, hash.CompareAPIKey(apiKeyEvent.Key, tt.fields.Key), "Hashed key should match original")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -31,6 +31,7 @@ require (
|
||||
go.opentelemetry.io/otel/sdk/log v0.14.0
|
||||
go.opentelemetry.io/otel/sdk/metric v1.38.0
|
||||
go.opentelemetry.io/otel/trace v1.38.0
|
||||
golang.org/x/crypto v0.43.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
)
|
||||
|
||||
|
||||
+1
-1
@@ -38,7 +38,7 @@ func ToGqlAPIKeys(keys []domain.APIKey) []*model.APIKey {
|
||||
result[i] = &model.APIKey{
|
||||
ID: apiKeyId(k.OrganizationId, k.Name),
|
||||
Name: k.Name,
|
||||
Key: &k.Key,
|
||||
Key: nil, // Never return the hashed key - only return plaintext on creation
|
||||
Organization: nil,
|
||||
Refs: k.Refs,
|
||||
Read: k.Read,
|
||||
|
||||
@@ -3,9 +3,72 @@ package hash
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
// String creates a SHA256 hash of a string (legacy, for non-sensitive data)
|
||||
func String(s string) string {
|
||||
encoded := sha256.New().Sum([]byte(s))
|
||||
return base64.StdEncoding.EncodeToString(encoded)
|
||||
}
|
||||
|
||||
// APIKey hashes an API key using bcrypt for secure storage
|
||||
// Cost of 12 provides a good balance between security and performance
|
||||
func APIKey(key string) (string, error) {
|
||||
hash, err := bcrypt.GenerateFromPassword([]byte(key), 12)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(hash), nil
|
||||
}
|
||||
|
||||
// CompareAPIKey compares a plaintext API key with a hash
|
||||
// Supports both bcrypt (new) and SHA256 (legacy) hashes for backwards compatibility
|
||||
// Returns true if they match, false otherwise
|
||||
//
|
||||
// Migration Strategy:
|
||||
// Old API keys stored with SHA256 will continue to work. To upgrade them to bcrypt:
|
||||
// 1. Keys are automatically upgraded when users re-authenticate (if implemented)
|
||||
// 2. Or, run a one-time migration using MigrateAPIKeyHash when convenient
|
||||
func CompareAPIKey(hashedKey, plainKey string) bool {
|
||||
// Bcrypt hashes start with $2a$, $2b$, or $2y$
|
||||
// If the hash starts with $2, it's a bcrypt hash
|
||||
if len(hashedKey) > 2 && hashedKey[0] == '$' && hashedKey[1] == '2' {
|
||||
// New bcrypt hash
|
||||
err := bcrypt.CompareHashAndPassword([]byte(hashedKey), []byte(plainKey))
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// Legacy SHA256 hash - compare using the old method
|
||||
legacyHash := String(plainKey)
|
||||
return hashedKey == legacyHash
|
||||
}
|
||||
|
||||
// IsLegacyHash returns true if the hash is a legacy SHA256 hash (not bcrypt)
|
||||
func IsLegacyHash(hashedKey string) bool {
|
||||
return len(hashedKey) <= 2 || hashedKey[0] != '$' || hashedKey[1] != '2'
|
||||
}
|
||||
|
||||
// MigrateAPIKeyHash can be used to upgrade a legacy SHA256 hash to bcrypt
|
||||
// This is useful for one-time migrations of existing keys
|
||||
// Returns the new bcrypt hash if the key is legacy, otherwise returns the original
|
||||
func MigrateAPIKeyHash(currentHash, plainKey string) (string, bool, error) {
|
||||
// If already bcrypt, no migration needed
|
||||
if !IsLegacyHash(currentHash) {
|
||||
return currentHash, false, nil
|
||||
}
|
||||
|
||||
// Verify the legacy hash is correct before migrating
|
||||
if !CompareAPIKey(currentHash, plainKey) {
|
||||
return "", false, nil // Invalid key, don't migrate
|
||||
}
|
||||
|
||||
// Generate new bcrypt hash
|
||||
newHash, err := APIKey(plainKey)
|
||||
if err != nil {
|
||||
return "", false, err
|
||||
}
|
||||
|
||||
return newHash, true, nil
|
||||
}
|
||||
|
||||
@@ -0,0 +1,169 @@
|
||||
package hash
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAPIKey(t *testing.T) {
|
||||
key := "test_api_key_12345" // gitleaks:allow
|
||||
|
||||
hash1, err := APIKey(key)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, hash1)
|
||||
assert.NotEqual(t, key, hash1, "Hash should not equal plaintext")
|
||||
|
||||
// Bcrypt hashes should start with $2
|
||||
assert.True(t, strings.HasPrefix(hash1, "$2"), "Should be a bcrypt hash")
|
||||
|
||||
// Same key should produce different hashes (due to salt)
|
||||
hash2, err := APIKey(key)
|
||||
require.NoError(t, err)
|
||||
assert.NotEqual(t, hash1, hash2, "Bcrypt should produce different hashes with different salts")
|
||||
}
|
||||
|
||||
func TestCompareAPIKey_Bcrypt(t *testing.T) {
|
||||
key := "test_api_key_12345" // gitleaks:allow
|
||||
|
||||
hash, err := APIKey(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Correct key should match
|
||||
assert.True(t, CompareAPIKey(hash, key))
|
||||
|
||||
// Wrong key should not match
|
||||
assert.False(t, CompareAPIKey(hash, "wrong_key"))
|
||||
}
|
||||
|
||||
func TestCompareAPIKey_Legacy(t *testing.T) {
|
||||
key := "test_api_key_12345" // gitleaks:allow
|
||||
|
||||
// Create a legacy SHA256 hash
|
||||
legacyHash := String(key)
|
||||
|
||||
// Should still work with legacy hashes
|
||||
assert.True(t, CompareAPIKey(legacyHash, key))
|
||||
|
||||
// Wrong key should not match
|
||||
assert.False(t, CompareAPIKey(legacyHash, "wrong_key"))
|
||||
}
|
||||
|
||||
func TestCompareAPIKey_BackwardCompatibility(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
hashFunc func(string) string
|
||||
expectOK bool
|
||||
}{
|
||||
{
|
||||
name: "bcrypt hash",
|
||||
hashFunc: func(k string) string {
|
||||
h, _ := APIKey(k)
|
||||
return h
|
||||
},
|
||||
expectOK: true,
|
||||
},
|
||||
{
|
||||
name: "legacy SHA256 hash",
|
||||
hashFunc: func(k string) string {
|
||||
return String(k)
|
||||
},
|
||||
expectOK: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
key := "test_key_123"
|
||||
hash := tt.hashFunc(key)
|
||||
|
||||
result := CompareAPIKey(hash, key)
|
||||
assert.Equal(t, tt.expectOK, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestString(t *testing.T) {
|
||||
// Test that String function still works (for non-sensitive data)
|
||||
input := "test_string"
|
||||
hash1 := String(input)
|
||||
hash2 := String(input)
|
||||
|
||||
// SHA256 should be deterministic
|
||||
assert.Equal(t, hash1, hash2)
|
||||
assert.NotEmpty(t, hash1)
|
||||
assert.NotEqual(t, input, hash1)
|
||||
}
|
||||
|
||||
func TestIsLegacyHash(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
hash string
|
||||
isLegacy bool
|
||||
}{
|
||||
{
|
||||
name: "bcrypt hash",
|
||||
hash: "$2a$12$abcdefghijklmnopqrstuv",
|
||||
isLegacy: false,
|
||||
},
|
||||
{
|
||||
name: "SHA256 hash",
|
||||
hash: "dXNfYWtfMTIzNDU2Nzg5MDEyMzQ1NuOwxEKY",
|
||||
isLegacy: true,
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
hash: "",
|
||||
isLegacy: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert.Equal(t, tt.isLegacy, IsLegacyHash(tt.hash))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMigrateAPIKeyHash(t *testing.T) {
|
||||
plainKey := "test_api_key_123"
|
||||
|
||||
t.Run("migrate legacy hash", func(t *testing.T) {
|
||||
// Create a legacy SHA256 hash
|
||||
legacyHash := String(plainKey)
|
||||
|
||||
// Migrate it
|
||||
newHash, migrated, err := MigrateAPIKeyHash(legacyHash, plainKey)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, migrated, "Should indicate migration occurred")
|
||||
assert.NotEqual(t, legacyHash, newHash, "New hash should differ from legacy")
|
||||
assert.True(t, strings.HasPrefix(newHash, "$2"), "New hash should be bcrypt")
|
||||
|
||||
// Verify new hash works
|
||||
assert.True(t, CompareAPIKey(newHash, plainKey))
|
||||
})
|
||||
|
||||
t.Run("no migration needed for bcrypt", func(t *testing.T) {
|
||||
// Create a bcrypt hash
|
||||
bcryptHash, err := APIKey(plainKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to migrate it
|
||||
newHash, migrated, err := MigrateAPIKeyHash(bcryptHash, plainKey)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, migrated, "Should not migrate bcrypt hash")
|
||||
assert.Equal(t, bcryptHash, newHash, "Hash should remain unchanged")
|
||||
})
|
||||
|
||||
t.Run("invalid key does not migrate", func(t *testing.T) {
|
||||
legacyHash := String("correct_key")
|
||||
|
||||
// Try to migrate with wrong plaintext
|
||||
newHash, migrated, err := MigrateAPIKeyHash(legacyHash, "wrong_key")
|
||||
require.NoError(t, err)
|
||||
assert.False(t, migrated, "Should not migrate invalid key")
|
||||
assert.Empty(t, newHash, "Should return empty for invalid key")
|
||||
})
|
||||
}
|
||||
+2
-3
@@ -9,7 +9,6 @@ import (
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
|
||||
"gitlab.com/unboundsoftware/schemas/domain"
|
||||
"gitlab.com/unboundsoftware/schemas/hash"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -49,8 +48,8 @@ func (m *AuthMiddleware) Handler(next http.Handler) http.Handler {
|
||||
_, _ = w.Write([]byte("Invalid API Key format"))
|
||||
return
|
||||
}
|
||||
hashedKey := hash.String(apiKey)
|
||||
organization := m.cache.OrganizationByAPIKey(hashedKey)
|
||||
// Cache handles hash comparison internally
|
||||
organization := m.cache.OrganizationByAPIKey(apiKey)
|
||||
if organization != nil {
|
||||
ctx = context.WithValue(ctx, OrganizationKey, *organization)
|
||||
}
|
||||
|
||||
+10
-13
@@ -15,7 +15,6 @@ import (
|
||||
"gitlab.com/unboundsoftware/eventsourced/eventsourced"
|
||||
|
||||
"gitlab.com/unboundsoftware/schemas/domain"
|
||||
"gitlab.com/unboundsoftware/schemas/hash"
|
||||
)
|
||||
|
||||
// MockCache is a mock implementation of the Cache interface
|
||||
@@ -45,9 +44,9 @@ func TestAuthMiddleware_Handler_WithValidAPIKey(t *testing.T) {
|
||||
}
|
||||
|
||||
apiKey := "test-api-key-123"
|
||||
hashedKey := hash.String(apiKey)
|
||||
|
||||
mockCache.On("OrganizationByAPIKey", hashedKey).Return(expectedOrg)
|
||||
// Mock expects plaintext key (cache handles hashing internally)
|
||||
mockCache.On("OrganizationByAPIKey", apiKey).Return(expectedOrg)
|
||||
|
||||
// Create a test handler that checks the context
|
||||
var capturedOrg *domain.Organization
|
||||
@@ -84,9 +83,9 @@ func TestAuthMiddleware_Handler_WithInvalidAPIKey(t *testing.T) {
|
||||
authMiddleware := NewAuth(mockCache)
|
||||
|
||||
apiKey := "invalid-api-key"
|
||||
hashedKey := hash.String(apiKey)
|
||||
|
||||
mockCache.On("OrganizationByAPIKey", hashedKey).Return(nil)
|
||||
// Mock expects plaintext key (cache handles hashing internally)
|
||||
mockCache.On("OrganizationByAPIKey", apiKey).Return(nil)
|
||||
|
||||
// Create a test handler that checks the context
|
||||
var capturedOrg *domain.Organization
|
||||
@@ -120,9 +119,8 @@ func TestAuthMiddleware_Handler_WithoutAPIKey(t *testing.T) {
|
||||
mockCache := new(MockCache)
|
||||
authMiddleware := NewAuth(mockCache)
|
||||
|
||||
// The middleware always hashes the API key (even if empty) and calls the cache
|
||||
emptyKeyHash := hash.String("")
|
||||
mockCache.On("OrganizationByAPIKey", emptyKeyHash).Return(nil)
|
||||
// The middleware passes the plaintext API key (cache handles hashing)
|
||||
mockCache.On("OrganizationByAPIKey", "").Return(nil)
|
||||
|
||||
// Create a test handler that checks the context
|
||||
var capturedOrg *domain.Organization
|
||||
@@ -153,9 +151,8 @@ func TestAuthMiddleware_Handler_WithValidJWT(t *testing.T) {
|
||||
mockCache := new(MockCache)
|
||||
authMiddleware := NewAuth(mockCache)
|
||||
|
||||
// The middleware always hashes the API key (even if empty) and calls the cache
|
||||
emptyKeyHash := hash.String("")
|
||||
mockCache.On("OrganizationByAPIKey", emptyKeyHash).Return(nil)
|
||||
// The middleware passes the plaintext API key (cache handles hashing)
|
||||
mockCache.On("OrganizationByAPIKey", "").Return(nil)
|
||||
|
||||
userID := "user-123"
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
||||
@@ -251,13 +248,13 @@ func TestAuthMiddleware_Handler_BothJWTAndAPIKey(t *testing.T) {
|
||||
|
||||
userID := "user-123"
|
||||
apiKey := "test-api-key-123"
|
||||
hashedKey := hash.String(apiKey)
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
||||
"sub": userID,
|
||||
})
|
||||
|
||||
mockCache.On("OrganizationByAPIKey", hashedKey).Return(expectedOrg)
|
||||
// Mock expects plaintext key (cache handles hashing internally)
|
||||
mockCache.On("OrganizationByAPIKey", apiKey).Return(expectedOrg)
|
||||
|
||||
// Create a test handler that checks both user and organization in context
|
||||
var capturedUser string
|
||||
|
||||
Reference in New Issue
Block a user