fix: enhance API key handling and logging in middleware #627
@@ -0,0 +1,336 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/99designs/gqlgen/graphql/handler/transport"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"gitlab.com/unboundsoftware/eventsourced/eventsourced"
|
||||||
|
|
||||||
|
"gitlab.com/unboundsoftware/schemas/domain"
|
||||||
|
"gitlab.com/unboundsoftware/schemas/hash"
|
||||||
|
"gitlab.com/unboundsoftware/schemas/middleware"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MockCache is a mock implementation for testing
|
||||||
|
type MockCache struct {
|
||||||
|
organizations map[string]*domain.Organization
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockCache) OrganizationByAPIKey(apiKey string) *domain.Organization {
|
||||||
|
return m.organizations[apiKey]
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWebSocketInitFunc_WithValidAPIKey(t *testing.T) {
|
||||||
|
// Setup
|
||||||
|
orgID := uuid.New()
|
||||||
|
org := &domain.Organization{
|
||||||
|
BaseAggregate: eventsourced.BaseAggregate{
|
||||||
|
ID: eventsourced.IdFromString(orgID.String()),
|
||||||
|
},
|
||||||
|
Name: "Test Organization",
|
||||||
|
}
|
||||||
|
|
||||||
|
apiKey := "test-api-key-123"
|
||||||
|
hashedKey := hash.String(apiKey)
|
||||||
|
|
||||||
|
mockCache := &MockCache{
|
||||||
|
organizations: map[string]*domain.Organization{
|
||||||
|
hashedKey: org,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create InitFunc (simulating the WebSocket InitFunc logic)
|
||||||
|
initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
|
||||||
|
// Extract API key from WebSocket connection_init payload
|
||||||
|
if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" {
|
||||||
|
ctx = context.WithValue(ctx, middleware.ApiKey, apiKey)
|
||||||
|
|
||||||
|
// Look up organization by API key
|
||||||
|
if organization := mockCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil {
|
||||||
|
ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ctx, &initPayload, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test
|
||||||
|
ctx := context.Background()
|
||||||
|
initPayload := transport.InitPayload{
|
||||||
|
"X-Api-Key": apiKey,
|
||||||
|
}
|
||||||
|
|
||||||
|
resultCtx, resultPayload, err := initFunc(ctx, initPayload)
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, resultPayload)
|
||||||
|
|
||||||
|
// Check API key is in context
|
||||||
|
if value := resultCtx.Value(middleware.ApiKey); value != nil {
|
||||||
|
assert.Equal(t, apiKey, value.(string))
|
||||||
|
} else {
|
||||||
|
t.Fatal("API key not found in context")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check organization is in context
|
||||||
|
if value := resultCtx.Value(middleware.OrganizationKey); value != nil {
|
||||||
|
capturedOrg, ok := value.(domain.Organization)
|
||||||
|
require.True(t, ok, "Organization should be of correct type")
|
||||||
|
assert.Equal(t, org.Name, capturedOrg.Name)
|
||||||
|
assert.Equal(t, org.ID.String(), capturedOrg.ID.String())
|
||||||
|
} else {
|
||||||
|
t.Fatal("Organization not found in context")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWebSocketInitFunc_WithInvalidAPIKey(t *testing.T) {
|
||||||
|
// Setup
|
||||||
|
mockCache := &MockCache{
|
||||||
|
organizations: map[string]*domain.Organization{},
|
||||||
|
}
|
||||||
|
|
||||||
|
apiKey := "invalid-api-key"
|
||||||
|
|
||||||
|
// Create InitFunc
|
||||||
|
initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
|
||||||
|
// Extract API key from WebSocket connection_init payload
|
||||||
|
if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" {
|
||||||
|
ctx = context.WithValue(ctx, middleware.ApiKey, apiKey)
|
||||||
|
|
||||||
|
// Look up organization by API key
|
||||||
|
if organization := mockCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil {
|
||||||
|
ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ctx, &initPayload, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test
|
||||||
|
ctx := context.Background()
|
||||||
|
initPayload := transport.InitPayload{
|
||||||
|
"X-Api-Key": apiKey,
|
||||||
|
}
|
||||||
|
|
||||||
|
resultCtx, resultPayload, err := initFunc(ctx, initPayload)
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, resultPayload)
|
||||||
|
|
||||||
|
// Check API key is in context
|
||||||
|
if value := resultCtx.Value(middleware.ApiKey); value != nil {
|
||||||
|
assert.Equal(t, apiKey, value.(string))
|
||||||
|
} else {
|
||||||
|
t.Fatal("API key not found in context")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check organization is NOT in context (since API key is invalid)
|
||||||
|
value := resultCtx.Value(middleware.OrganizationKey)
|
||||||
|
assert.Nil(t, value, "Organization should not be set for invalid API key")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWebSocketInitFunc_WithoutAPIKey(t *testing.T) {
|
||||||
|
// Setup
|
||||||
|
mockCache := &MockCache{
|
||||||
|
organizations: map[string]*domain.Organization{},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create InitFunc
|
||||||
|
initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
|
||||||
|
// Extract API key from WebSocket connection_init payload
|
||||||
|
if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" {
|
||||||
|
ctx = context.WithValue(ctx, middleware.ApiKey, apiKey)
|
||||||
|
|
||||||
|
// Look up organization by API key
|
||||||
|
if organization := mockCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil {
|
||||||
|
ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ctx, &initPayload, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test
|
||||||
|
ctx := context.Background()
|
||||||
|
initPayload := transport.InitPayload{}
|
||||||
|
|
||||||
|
resultCtx, resultPayload, err := initFunc(ctx, initPayload)
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, resultPayload)
|
||||||
|
|
||||||
|
// Check API key is NOT in context
|
||||||
|
value := resultCtx.Value(middleware.ApiKey)
|
||||||
|
assert.Nil(t, value, "API key should not be set when not provided")
|
||||||
|
|
||||||
|
// Check organization is NOT in context
|
||||||
|
value = resultCtx.Value(middleware.OrganizationKey)
|
||||||
|
assert.Nil(t, value, "Organization should not be set when API key is not provided")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWebSocketInitFunc_WithEmptyAPIKey(t *testing.T) {
|
||||||
|
// Setup
|
||||||
|
mockCache := &MockCache{
|
||||||
|
organizations: map[string]*domain.Organization{},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create InitFunc
|
||||||
|
initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
|
||||||
|
// Extract API key from WebSocket connection_init payload
|
||||||
|
if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" {
|
||||||
|
ctx = context.WithValue(ctx, middleware.ApiKey, apiKey)
|
||||||
|
|
||||||
|
// Look up organization by API key
|
||||||
|
if organization := mockCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil {
|
||||||
|
ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ctx, &initPayload, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test
|
||||||
|
ctx := context.Background()
|
||||||
|
initPayload := transport.InitPayload{
|
||||||
|
"X-Api-Key": "", // Empty string
|
||||||
|
}
|
||||||
|
|
||||||
|
resultCtx, resultPayload, err := initFunc(ctx, initPayload)
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, resultPayload)
|
||||||
|
|
||||||
|
// Check API key is NOT in context (because empty string fails the condition)
|
||||||
|
value := resultCtx.Value(middleware.ApiKey)
|
||||||
|
assert.Nil(t, value, "API key should not be set when empty")
|
||||||
|
|
||||||
|
// Check organization is NOT in context
|
||||||
|
value = resultCtx.Value(middleware.OrganizationKey)
|
||||||
|
assert.Nil(t, value, "Organization should not be set when API key is empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWebSocketInitFunc_WithWrongTypeAPIKey(t *testing.T) {
|
||||||
|
// Setup
|
||||||
|
mockCache := &MockCache{
|
||||||
|
organizations: map[string]*domain.Organization{},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create InitFunc
|
||||||
|
initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
|
||||||
|
// Extract API key from WebSocket connection_init payload
|
||||||
|
if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" {
|
||||||
|
ctx = context.WithValue(ctx, middleware.ApiKey, apiKey)
|
||||||
|
|
||||||
|
// Look up organization by API key
|
||||||
|
if organization := mockCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil {
|
||||||
|
ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ctx, &initPayload, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test
|
||||||
|
ctx := context.Background()
|
||||||
|
initPayload := transport.InitPayload{
|
||||||
|
"X-Api-Key": 12345, // Wrong type (int instead of string)
|
||||||
|
}
|
||||||
|
|
||||||
|
resultCtx, resultPayload, err := initFunc(ctx, initPayload)
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, resultPayload)
|
||||||
|
|
||||||
|
// Check API key is NOT in context (type assertion fails)
|
||||||
|
value := resultCtx.Value(middleware.ApiKey)
|
||||||
|
assert.Nil(t, value, "API key should not be set when wrong type")
|
||||||
|
|
||||||
|
// Check organization is NOT in context
|
||||||
|
value = resultCtx.Value(middleware.OrganizationKey)
|
||||||
|
assert.Nil(t, value, "Organization should not be set when API key has wrong type")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWebSocketInitFunc_WithMultipleOrganizations(t *testing.T) {
|
||||||
|
// Setup - create multiple organizations
|
||||||
|
org1ID := uuid.New()
|
||||||
|
org1 := &domain.Organization{
|
||||||
|
BaseAggregate: eventsourced.BaseAggregate{
|
||||||
|
ID: eventsourced.IdFromString(org1ID.String()),
|
||||||
|
},
|
||||||
|
Name: "Organization 1",
|
||||||
|
}
|
||||||
|
|
||||||
|
org2ID := uuid.New()
|
||||||
|
org2 := &domain.Organization{
|
||||||
|
BaseAggregate: eventsourced.BaseAggregate{
|
||||||
|
ID: eventsourced.IdFromString(org2ID.String()),
|
||||||
|
},
|
||||||
|
Name: "Organization 2",
|
||||||
|
}
|
||||||
|
|
||||||
|
apiKey1 := "api-key-org-1"
|
||||||
|
apiKey2 := "api-key-org-2"
|
||||||
|
hashedKey1 := hash.String(apiKey1)
|
||||||
|
hashedKey2 := hash.String(apiKey2)
|
||||||
|
|
||||||
|
mockCache := &MockCache{
|
||||||
|
organizations: map[string]*domain.Organization{
|
||||||
|
hashedKey1: org1,
|
||||||
|
hashedKey2: org2,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create InitFunc
|
||||||
|
initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
|
||||||
|
// Extract API key from WebSocket connection_init payload
|
||||||
|
if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" {
|
||||||
|
ctx = context.WithValue(ctx, middleware.ApiKey, apiKey)
|
||||||
|
|
||||||
|
// Look up organization by API key
|
||||||
|
if organization := mockCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil {
|
||||||
|
ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ctx, &initPayload, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test with first API key
|
||||||
|
ctx1 := context.Background()
|
||||||
|
initPayload1 := transport.InitPayload{
|
||||||
|
"X-Api-Key": apiKey1,
|
||||||
|
}
|
||||||
|
|
||||||
|
resultCtx1, _, err := initFunc(ctx1, initPayload1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
if value := resultCtx1.Value(middleware.OrganizationKey); value != nil {
|
||||||
|
capturedOrg, ok := value.(domain.Organization)
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, org1.Name, capturedOrg.Name)
|
||||||
|
assert.Equal(t, org1.ID.String(), capturedOrg.ID.String())
|
||||||
|
} else {
|
||||||
|
t.Fatal("Organization 1 not found in context")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test with second API key
|
||||||
|
ctx2 := context.Background()
|
||||||
|
initPayload2 := transport.InitPayload{
|
||||||
|
"X-Api-Key": apiKey2,
|
||||||
|
}
|
||||||
|
|
||||||
|
resultCtx2, _, err := initFunc(ctx2, initPayload2)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
if value := resultCtx2.Value(middleware.OrganizationKey); value != nil {
|
||||||
|
capturedOrg, ok := value.(domain.Organization)
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, org2.Name, capturedOrg.Name)
|
||||||
|
assert.Equal(t, org2.ID.String(), capturedOrg.ID.String())
|
||||||
|
} else {
|
||||||
|
t.Fatal("Organization 2 not found in context")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,467 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
mw "github.com/auth0/go-jwt-middleware/v2"
|
||||||
|
"github.com/golang-jwt/jwt/v5"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/mock"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"gitlab.com/unboundsoftware/eventsourced/eventsourced"
|
||||||
|
|
||||||
|
"gitlab.com/unboundsoftware/schemas/domain"
|
||||||
|
"gitlab.com/unboundsoftware/schemas/hash"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MockCache is a mock implementation of the Cache interface
|
||||||
|
type MockCache struct {
|
||||||
|
mock.Mock
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockCache) OrganizationByAPIKey(apiKey string) *domain.Organization {
|
||||||
|
args := m.Called(apiKey)
|
||||||
|
if args.Get(0) == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return args.Get(0).(*domain.Organization)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthMiddleware_Handler_WithValidAPIKey(t *testing.T) {
|
||||||
|
// Setup
|
||||||
|
mockCache := new(MockCache)
|
||||||
|
authMiddleware := NewAuth(mockCache)
|
||||||
|
|
||||||
|
orgID := uuid.New()
|
||||||
|
expectedOrg := &domain.Organization{
|
||||||
|
BaseAggregate: eventsourced.BaseAggregate{
|
||||||
|
ID: eventsourced.IdFromString(orgID.String()),
|
||||||
|
},
|
||||||
|
Name: "Test Organization",
|
||||||
|
}
|
||||||
|
|
||||||
|
apiKey := "test-api-key-123"
|
||||||
|
hashedKey := hash.String(apiKey)
|
||||||
|
|
||||||
|
mockCache.On("OrganizationByAPIKey", hashedKey).Return(expectedOrg)
|
||||||
|
|
||||||
|
// Create a test handler that checks the context
|
||||||
|
var capturedOrg *domain.Organization
|
||||||
|
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if org := r.Context().Value(OrganizationKey); org != nil {
|
||||||
|
if o, ok := org.(domain.Organization); ok {
|
||||||
|
capturedOrg = &o
|
||||||
|
}
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Create request with API key in context
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||||
|
ctx := context.WithValue(req.Context(), ApiKey, apiKey)
|
||||||
|
req = req.WithContext(ctx)
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
// Execute
|
||||||
|
authMiddleware.Handler(testHandler).ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
assert.Equal(t, http.StatusOK, rec.Code)
|
||||||
|
require.NotNil(t, capturedOrg)
|
||||||
|
assert.Equal(t, expectedOrg.Name, capturedOrg.Name)
|
||||||
|
assert.Equal(t, expectedOrg.ID.String(), capturedOrg.ID.String())
|
||||||
|
mockCache.AssertExpectations(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthMiddleware_Handler_WithInvalidAPIKey(t *testing.T) {
|
||||||
|
// Setup
|
||||||
|
mockCache := new(MockCache)
|
||||||
|
authMiddleware := NewAuth(mockCache)
|
||||||
|
|
||||||
|
apiKey := "invalid-api-key"
|
||||||
|
hashedKey := hash.String(apiKey)
|
||||||
|
|
||||||
|
mockCache.On("OrganizationByAPIKey", hashedKey).Return(nil)
|
||||||
|
|
||||||
|
// Create a test handler that checks the context
|
||||||
|
var capturedOrg *domain.Organization
|
||||||
|
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if org := r.Context().Value(OrganizationKey); org != nil {
|
||||||
|
if o, ok := org.(domain.Organization); ok {
|
||||||
|
capturedOrg = &o
|
||||||
|
}
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Create request with API key in context
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||||
|
ctx := context.WithValue(req.Context(), ApiKey, apiKey)
|
||||||
|
req = req.WithContext(ctx)
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
// Execute
|
||||||
|
authMiddleware.Handler(testHandler).ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
assert.Equal(t, http.StatusOK, rec.Code)
|
||||||
|
assert.Nil(t, capturedOrg, "Organization should not be set for invalid API key")
|
||||||
|
mockCache.AssertExpectations(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthMiddleware_Handler_WithoutAPIKey(t *testing.T) {
|
||||||
|
// Setup
|
||||||
|
mockCache := new(MockCache)
|
||||||
|
authMiddleware := NewAuth(mockCache)
|
||||||
|
|
||||||
|
// The middleware always hashes the API key (even if empty) and calls the cache
|
||||||
|
emptyKeyHash := hash.String("")
|
||||||
|
mockCache.On("OrganizationByAPIKey", emptyKeyHash).Return(nil)
|
||||||
|
|
||||||
|
// Create a test handler that checks the context
|
||||||
|
var capturedOrg *domain.Organization
|
||||||
|
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if org := r.Context().Value(OrganizationKey); org != nil {
|
||||||
|
if o, ok := org.(domain.Organization); ok {
|
||||||
|
capturedOrg = &o
|
||||||
|
}
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Create request without API key
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
// Execute
|
||||||
|
authMiddleware.Handler(testHandler).ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
assert.Equal(t, http.StatusOK, rec.Code)
|
||||||
|
assert.Nil(t, capturedOrg, "Organization should not be set without API key")
|
||||||
|
mockCache.AssertExpectations(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthMiddleware_Handler_WithValidJWT(t *testing.T) {
|
||||||
|
// Setup
|
||||||
|
mockCache := new(MockCache)
|
||||||
|
authMiddleware := NewAuth(mockCache)
|
||||||
|
|
||||||
|
// The middleware always hashes the API key (even if empty) and calls the cache
|
||||||
|
emptyKeyHash := hash.String("")
|
||||||
|
mockCache.On("OrganizationByAPIKey", emptyKeyHash).Return(nil)
|
||||||
|
|
||||||
|
userID := "user-123"
|
||||||
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
||||||
|
"sub": userID,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Create a test handler that checks the context
|
||||||
|
var capturedUser string
|
||||||
|
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if user := r.Context().Value(UserKey); user != nil {
|
||||||
|
if u, ok := user.(string); ok {
|
||||||
|
capturedUser = u
|
||||||
|
}
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Create request with JWT token in context
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||||
|
ctx := context.WithValue(req.Context(), mw.ContextKey{}, token)
|
||||||
|
req = req.WithContext(ctx)
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
// Execute
|
||||||
|
authMiddleware.Handler(testHandler).ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
assert.Equal(t, http.StatusOK, rec.Code)
|
||||||
|
assert.Equal(t, userID, capturedUser)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthMiddleware_Handler_APIKeyErrorHandling(t *testing.T) {
|
||||||
|
// Setup
|
||||||
|
mockCache := new(MockCache)
|
||||||
|
authMiddleware := NewAuth(mockCache)
|
||||||
|
|
||||||
|
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Create request with invalid API key type in context
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||||
|
ctx := context.WithValue(req.Context(), ApiKey, 12345) // Invalid type
|
||||||
|
req = req.WithContext(ctx)
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
// Execute
|
||||||
|
authMiddleware.Handler(testHandler).ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
assert.Equal(t, http.StatusInternalServerError, rec.Code)
|
||||||
|
assert.Contains(t, rec.Body.String(), "Invalid API Key format")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthMiddleware_Handler_JWTErrorHandling(t *testing.T) {
|
||||||
|
// Setup
|
||||||
|
mockCache := new(MockCache)
|
||||||
|
authMiddleware := NewAuth(mockCache)
|
||||||
|
|
||||||
|
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Create request with invalid JWT token type in context
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||||
|
ctx := context.WithValue(req.Context(), mw.ContextKey{}, "not-a-token") // Invalid type
|
||||||
|
req = req.WithContext(ctx)
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
// Execute
|
||||||
|
authMiddleware.Handler(testHandler).ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
assert.Equal(t, http.StatusInternalServerError, rec.Code)
|
||||||
|
assert.Contains(t, rec.Body.String(), "Invalid JWT token format")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthMiddleware_Handler_BothJWTAndAPIKey(t *testing.T) {
|
||||||
|
// Setup
|
||||||
|
mockCache := new(MockCache)
|
||||||
|
authMiddleware := NewAuth(mockCache)
|
||||||
|
|
||||||
|
orgID := uuid.New()
|
||||||
|
expectedOrg := &domain.Organization{
|
||||||
|
BaseAggregate: eventsourced.BaseAggregate{
|
||||||
|
ID: eventsourced.IdFromString(orgID.String()),
|
||||||
|
},
|
||||||
|
Name: "Test Organization",
|
||||||
|
}
|
||||||
|
|
||||||
|
userID := "user-123"
|
||||||
|
apiKey := "test-api-key-123"
|
||||||
|
hashedKey := hash.String(apiKey)
|
||||||
|
|
||||||
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
||||||
|
"sub": userID,
|
||||||
|
})
|
||||||
|
|
||||||
|
mockCache.On("OrganizationByAPIKey", hashedKey).Return(expectedOrg)
|
||||||
|
|
||||||
|
// Create a test handler that checks both user and organization in context
|
||||||
|
var capturedUser string
|
||||||
|
var capturedOrg *domain.Organization
|
||||||
|
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if user := r.Context().Value(UserKey); user != nil {
|
||||||
|
if u, ok := user.(string); ok {
|
||||||
|
capturedUser = u
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if org := r.Context().Value(OrganizationKey); org != nil {
|
||||||
|
if o, ok := org.(domain.Organization); ok {
|
||||||
|
capturedOrg = &o
|
||||||
|
}
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Create request with both JWT and API key in context
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||||
|
ctx := context.WithValue(req.Context(), mw.ContextKey{}, token)
|
||||||
|
ctx = context.WithValue(ctx, ApiKey, apiKey)
|
||||||
|
req = req.WithContext(ctx)
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
// Execute
|
||||||
|
authMiddleware.Handler(testHandler).ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
assert.Equal(t, http.StatusOK, rec.Code)
|
||||||
|
assert.Equal(t, userID, capturedUser)
|
||||||
|
require.NotNil(t, capturedOrg)
|
||||||
|
assert.Equal(t, expectedOrg.Name, capturedOrg.Name)
|
||||||
|
mockCache.AssertExpectations(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUserFromContext(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
ctx context.Context
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "with valid user",
|
||||||
|
ctx: context.WithValue(context.Background(), UserKey, "user-123"),
|
||||||
|
expected: "user-123",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "without user",
|
||||||
|
ctx: context.Background(),
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "with invalid type",
|
||||||
|
ctx: context.WithValue(context.Background(), UserKey, 123),
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := UserFromContext(tt.ctx)
|
||||||
|
assert.Equal(t, tt.expected, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOrganizationFromContext(t *testing.T) {
|
||||||
|
orgID := uuid.New()
|
||||||
|
org := domain.Organization{
|
||||||
|
BaseAggregate: eventsourced.BaseAggregate{
|
||||||
|
ID: eventsourced.IdFromString(orgID.String()),
|
||||||
|
},
|
||||||
|
Name: "Test Org",
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
ctx context.Context
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "with valid organization",
|
||||||
|
ctx: context.WithValue(context.Background(), OrganizationKey, org),
|
||||||
|
expected: orgID.String(),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "without organization",
|
||||||
|
ctx: context.Background(),
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "with invalid type",
|
||||||
|
ctx: context.WithValue(context.Background(), OrganizationKey, "not-an-org"),
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := OrganizationFromContext(tt.ctx)
|
||||||
|
assert.Equal(t, tt.expected, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthMiddleware_Directive_RequiresUser(t *testing.T) {
|
||||||
|
mockCache := new(MockCache)
|
||||||
|
authMiddleware := NewAuth(mockCache)
|
||||||
|
|
||||||
|
requireUser := true
|
||||||
|
|
||||||
|
// Test with user present
|
||||||
|
ctx := context.WithValue(context.Background(), UserKey, "user-123")
|
||||||
|
_, err := authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) {
|
||||||
|
return "success", nil
|
||||||
|
}, &requireUser, nil)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Test without user
|
||||||
|
ctx = context.Background()
|
||||||
|
_, err = authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) {
|
||||||
|
return "success", nil
|
||||||
|
}, &requireUser, nil)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "no user available in request")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthMiddleware_Directive_RequiresOrganization(t *testing.T) {
|
||||||
|
mockCache := new(MockCache)
|
||||||
|
authMiddleware := NewAuth(mockCache)
|
||||||
|
|
||||||
|
requireOrg := true
|
||||||
|
orgID := uuid.New()
|
||||||
|
org := domain.Organization{
|
||||||
|
BaseAggregate: eventsourced.BaseAggregate{
|
||||||
|
ID: eventsourced.IdFromString(orgID.String()),
|
||||||
|
},
|
||||||
|
Name: "Test Org",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test with organization present
|
||||||
|
ctx := context.WithValue(context.Background(), OrganizationKey, org)
|
||||||
|
_, err := authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) {
|
||||||
|
return "success", nil
|
||||||
|
}, nil, &requireOrg)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Test without organization
|
||||||
|
ctx = context.Background()
|
||||||
|
_, err = authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) {
|
||||||
|
return "success", nil
|
||||||
|
}, nil, &requireOrg)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "no organization available in request")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthMiddleware_Directive_RequiresBoth(t *testing.T) {
|
||||||
|
mockCache := new(MockCache)
|
||||||
|
authMiddleware := NewAuth(mockCache)
|
||||||
|
|
||||||
|
requireUser := true
|
||||||
|
requireOrg := true
|
||||||
|
orgID := uuid.New()
|
||||||
|
org := domain.Organization{
|
||||||
|
BaseAggregate: eventsourced.BaseAggregate{
|
||||||
|
ID: eventsourced.IdFromString(orgID.String()),
|
||||||
|
},
|
||||||
|
Name: "Test Org",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test with both present
|
||||||
|
ctx := context.WithValue(context.Background(), UserKey, "user-123")
|
||||||
|
ctx = context.WithValue(ctx, OrganizationKey, org)
|
||||||
|
_, err := authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) {
|
||||||
|
return "success", nil
|
||||||
|
}, &requireUser, &requireOrg)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Test with only user
|
||||||
|
ctx = context.WithValue(context.Background(), UserKey, "user-123")
|
||||||
|
_, err = authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) {
|
||||||
|
return "success", nil
|
||||||
|
}, &requireUser, &requireOrg)
|
||||||
|
assert.Error(t, err)
|
||||||
|
|
||||||
|
// Test with only organization
|
||||||
|
ctx = context.WithValue(context.Background(), OrganizationKey, org)
|
||||||
|
_, err = authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) {
|
||||||
|
return "success", nil
|
||||||
|
}, &requireUser, &requireOrg)
|
||||||
|
assert.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthMiddleware_Directive_NoRequirements(t *testing.T) {
|
||||||
|
mockCache := new(MockCache)
|
||||||
|
authMiddleware := NewAuth(mockCache)
|
||||||
|
|
||||||
|
// Test with no requirements
|
||||||
|
ctx := context.Background()
|
||||||
|
result, err := authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) {
|
||||||
|
return "success", nil
|
||||||
|
}, nil, nil)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "success", result)
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user