diff --git a/cmd/service/service_test.go b/cmd/service/service_test.go new file mode 100644 index 0000000..c9680d9 --- /dev/null +++ b/cmd/service/service_test.go @@ -0,0 +1,336 @@ +package main + +import ( + "context" + "testing" + + "github.com/99designs/gqlgen/graphql/handler/transport" + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gitlab.com/unboundsoftware/eventsourced/eventsourced" + + "gitlab.com/unboundsoftware/schemas/domain" + "gitlab.com/unboundsoftware/schemas/hash" + "gitlab.com/unboundsoftware/schemas/middleware" +) + +// MockCache is a mock implementation for testing +type MockCache struct { + organizations map[string]*domain.Organization +} + +func (m *MockCache) OrganizationByAPIKey(apiKey string) *domain.Organization { + return m.organizations[apiKey] +} + +func TestWebSocketInitFunc_WithValidAPIKey(t *testing.T) { + // Setup + orgID := uuid.New() + org := &domain.Organization{ + BaseAggregate: eventsourced.BaseAggregate{ + ID: eventsourced.IdFromString(orgID.String()), + }, + Name: "Test Organization", + } + + apiKey := "test-api-key-123" + hashedKey := hash.String(apiKey) + + mockCache := &MockCache{ + organizations: map[string]*domain.Organization{ + hashedKey: org, + }, + } + + // Create InitFunc (simulating the WebSocket InitFunc logic) + initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) { + // Extract API key from WebSocket connection_init payload + if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" { + ctx = context.WithValue(ctx, middleware.ApiKey, apiKey) + + // Look up organization by API key + if organization := mockCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil { + ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization) + } + } + return ctx, &initPayload, nil + } + + // Test + ctx := context.Background() + initPayload := transport.InitPayload{ + "X-Api-Key": apiKey, + } + + resultCtx, resultPayload, err := initFunc(ctx, initPayload) + + // Assert + require.NoError(t, err) + require.NotNil(t, resultPayload) + + // Check API key is in context + if value := resultCtx.Value(middleware.ApiKey); value != nil { + assert.Equal(t, apiKey, value.(string)) + } else { + t.Fatal("API key not found in context") + } + + // Check organization is in context + if value := resultCtx.Value(middleware.OrganizationKey); value != nil { + capturedOrg, ok := value.(domain.Organization) + require.True(t, ok, "Organization should be of correct type") + assert.Equal(t, org.Name, capturedOrg.Name) + assert.Equal(t, org.ID.String(), capturedOrg.ID.String()) + } else { + t.Fatal("Organization not found in context") + } +} + +func TestWebSocketInitFunc_WithInvalidAPIKey(t *testing.T) { + // Setup + mockCache := &MockCache{ + organizations: map[string]*domain.Organization{}, + } + + apiKey := "invalid-api-key" + + // Create InitFunc + initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) { + // Extract API key from WebSocket connection_init payload + if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" { + ctx = context.WithValue(ctx, middleware.ApiKey, apiKey) + + // Look up organization by API key + if organization := mockCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil { + ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization) + } + } + return ctx, &initPayload, nil + } + + // Test + ctx := context.Background() + initPayload := transport.InitPayload{ + "X-Api-Key": apiKey, + } + + resultCtx, resultPayload, err := initFunc(ctx, initPayload) + + // Assert + require.NoError(t, err) + require.NotNil(t, resultPayload) + + // Check API key is in context + if value := resultCtx.Value(middleware.ApiKey); value != nil { + assert.Equal(t, apiKey, value.(string)) + } else { + t.Fatal("API key not found in context") + } + + // Check organization is NOT in context (since API key is invalid) + value := resultCtx.Value(middleware.OrganizationKey) + assert.Nil(t, value, "Organization should not be set for invalid API key") +} + +func TestWebSocketInitFunc_WithoutAPIKey(t *testing.T) { + // Setup + mockCache := &MockCache{ + organizations: map[string]*domain.Organization{}, + } + + // Create InitFunc + initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) { + // Extract API key from WebSocket connection_init payload + if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" { + ctx = context.WithValue(ctx, middleware.ApiKey, apiKey) + + // Look up organization by API key + if organization := mockCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil { + ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization) + } + } + return ctx, &initPayload, nil + } + + // Test + ctx := context.Background() + initPayload := transport.InitPayload{} + + resultCtx, resultPayload, err := initFunc(ctx, initPayload) + + // Assert + require.NoError(t, err) + require.NotNil(t, resultPayload) + + // Check API key is NOT in context + value := resultCtx.Value(middleware.ApiKey) + assert.Nil(t, value, "API key should not be set when not provided") + + // Check organization is NOT in context + value = resultCtx.Value(middleware.OrganizationKey) + assert.Nil(t, value, "Organization should not be set when API key is not provided") +} + +func TestWebSocketInitFunc_WithEmptyAPIKey(t *testing.T) { + // Setup + mockCache := &MockCache{ + organizations: map[string]*domain.Organization{}, + } + + // Create InitFunc + initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) { + // Extract API key from WebSocket connection_init payload + if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" { + ctx = context.WithValue(ctx, middleware.ApiKey, apiKey) + + // Look up organization by API key + if organization := mockCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil { + ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization) + } + } + return ctx, &initPayload, nil + } + + // Test + ctx := context.Background() + initPayload := transport.InitPayload{ + "X-Api-Key": "", // Empty string + } + + resultCtx, resultPayload, err := initFunc(ctx, initPayload) + + // Assert + require.NoError(t, err) + require.NotNil(t, resultPayload) + + // Check API key is NOT in context (because empty string fails the condition) + value := resultCtx.Value(middleware.ApiKey) + assert.Nil(t, value, "API key should not be set when empty") + + // Check organization is NOT in context + value = resultCtx.Value(middleware.OrganizationKey) + assert.Nil(t, value, "Organization should not be set when API key is empty") +} + +func TestWebSocketInitFunc_WithWrongTypeAPIKey(t *testing.T) { + // Setup + mockCache := &MockCache{ + organizations: map[string]*domain.Organization{}, + } + + // Create InitFunc + initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) { + // Extract API key from WebSocket connection_init payload + if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" { + ctx = context.WithValue(ctx, middleware.ApiKey, apiKey) + + // Look up organization by API key + if organization := mockCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil { + ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization) + } + } + return ctx, &initPayload, nil + } + + // Test + ctx := context.Background() + initPayload := transport.InitPayload{ + "X-Api-Key": 12345, // Wrong type (int instead of string) + } + + resultCtx, resultPayload, err := initFunc(ctx, initPayload) + + // Assert + require.NoError(t, err) + require.NotNil(t, resultPayload) + + // Check API key is NOT in context (type assertion fails) + value := resultCtx.Value(middleware.ApiKey) + assert.Nil(t, value, "API key should not be set when wrong type") + + // Check organization is NOT in context + value = resultCtx.Value(middleware.OrganizationKey) + assert.Nil(t, value, "Organization should not be set when API key has wrong type") +} + +func TestWebSocketInitFunc_WithMultipleOrganizations(t *testing.T) { + // Setup - create multiple organizations + org1ID := uuid.New() + org1 := &domain.Organization{ + BaseAggregate: eventsourced.BaseAggregate{ + ID: eventsourced.IdFromString(org1ID.String()), + }, + Name: "Organization 1", + } + + org2ID := uuid.New() + org2 := &domain.Organization{ + BaseAggregate: eventsourced.BaseAggregate{ + ID: eventsourced.IdFromString(org2ID.String()), + }, + Name: "Organization 2", + } + + apiKey1 := "api-key-org-1" + apiKey2 := "api-key-org-2" + hashedKey1 := hash.String(apiKey1) + hashedKey2 := hash.String(apiKey2) + + mockCache := &MockCache{ + organizations: map[string]*domain.Organization{ + hashedKey1: org1, + hashedKey2: org2, + }, + } + + // Create InitFunc + initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) { + // Extract API key from WebSocket connection_init payload + if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" { + ctx = context.WithValue(ctx, middleware.ApiKey, apiKey) + + // Look up organization by API key + if organization := mockCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil { + ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization) + } + } + return ctx, &initPayload, nil + } + + // Test with first API key + ctx1 := context.Background() + initPayload1 := transport.InitPayload{ + "X-Api-Key": apiKey1, + } + + resultCtx1, _, err := initFunc(ctx1, initPayload1) + require.NoError(t, err) + + if value := resultCtx1.Value(middleware.OrganizationKey); value != nil { + capturedOrg, ok := value.(domain.Organization) + require.True(t, ok) + assert.Equal(t, org1.Name, capturedOrg.Name) + assert.Equal(t, org1.ID.String(), capturedOrg.ID.String()) + } else { + t.Fatal("Organization 1 not found in context") + } + + // Test with second API key + ctx2 := context.Background() + initPayload2 := transport.InitPayload{ + "X-Api-Key": apiKey2, + } + + resultCtx2, _, err := initFunc(ctx2, initPayload2) + require.NoError(t, err) + + if value := resultCtx2.Value(middleware.OrganizationKey); value != nil { + capturedOrg, ok := value.(domain.Organization) + require.True(t, ok) + assert.Equal(t, org2.Name, capturedOrg.Name) + assert.Equal(t, org2.ID.String(), capturedOrg.ID.String()) + } else { + t.Fatal("Organization 2 not found in context") + } +} diff --git a/middleware/auth_test.go b/middleware/auth_test.go new file mode 100644 index 0000000..3e456c8 --- /dev/null +++ b/middleware/auth_test.go @@ -0,0 +1,467 @@ +package middleware + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + mw "github.com/auth0/go-jwt-middleware/v2" + "github.com/golang-jwt/jwt/v5" + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "gitlab.com/unboundsoftware/eventsourced/eventsourced" + + "gitlab.com/unboundsoftware/schemas/domain" + "gitlab.com/unboundsoftware/schemas/hash" +) + +// MockCache is a mock implementation of the Cache interface +type MockCache struct { + mock.Mock +} + +func (m *MockCache) OrganizationByAPIKey(apiKey string) *domain.Organization { + args := m.Called(apiKey) + if args.Get(0) == nil { + return nil + } + return args.Get(0).(*domain.Organization) +} + +func TestAuthMiddleware_Handler_WithValidAPIKey(t *testing.T) { + // Setup + mockCache := new(MockCache) + authMiddleware := NewAuth(mockCache) + + orgID := uuid.New() + expectedOrg := &domain.Organization{ + BaseAggregate: eventsourced.BaseAggregate{ + ID: eventsourced.IdFromString(orgID.String()), + }, + Name: "Test Organization", + } + + apiKey := "test-api-key-123" + hashedKey := hash.String(apiKey) + + mockCache.On("OrganizationByAPIKey", hashedKey).Return(expectedOrg) + + // Create a test handler that checks the context + var capturedOrg *domain.Organization + testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if org := r.Context().Value(OrganizationKey); org != nil { + if o, ok := org.(domain.Organization); ok { + capturedOrg = &o + } + } + w.WriteHeader(http.StatusOK) + }) + + // Create request with API key in context + req := httptest.NewRequest(http.MethodGet, "/test", nil) + ctx := context.WithValue(req.Context(), ApiKey, apiKey) + req = req.WithContext(ctx) + + rec := httptest.NewRecorder() + + // Execute + authMiddleware.Handler(testHandler).ServeHTTP(rec, req) + + // Assert + assert.Equal(t, http.StatusOK, rec.Code) + require.NotNil(t, capturedOrg) + assert.Equal(t, expectedOrg.Name, capturedOrg.Name) + assert.Equal(t, expectedOrg.ID.String(), capturedOrg.ID.String()) + mockCache.AssertExpectations(t) +} + +func TestAuthMiddleware_Handler_WithInvalidAPIKey(t *testing.T) { + // Setup + mockCache := new(MockCache) + authMiddleware := NewAuth(mockCache) + + apiKey := "invalid-api-key" + hashedKey := hash.String(apiKey) + + mockCache.On("OrganizationByAPIKey", hashedKey).Return(nil) + + // Create a test handler that checks the context + var capturedOrg *domain.Organization + testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if org := r.Context().Value(OrganizationKey); org != nil { + if o, ok := org.(domain.Organization); ok { + capturedOrg = &o + } + } + w.WriteHeader(http.StatusOK) + }) + + // Create request with API key in context + req := httptest.NewRequest(http.MethodGet, "/test", nil) + ctx := context.WithValue(req.Context(), ApiKey, apiKey) + req = req.WithContext(ctx) + + rec := httptest.NewRecorder() + + // Execute + authMiddleware.Handler(testHandler).ServeHTTP(rec, req) + + // Assert + assert.Equal(t, http.StatusOK, rec.Code) + assert.Nil(t, capturedOrg, "Organization should not be set for invalid API key") + mockCache.AssertExpectations(t) +} + +func TestAuthMiddleware_Handler_WithoutAPIKey(t *testing.T) { + // Setup + mockCache := new(MockCache) + authMiddleware := NewAuth(mockCache) + + // The middleware always hashes the API key (even if empty) and calls the cache + emptyKeyHash := hash.String("") + mockCache.On("OrganizationByAPIKey", emptyKeyHash).Return(nil) + + // Create a test handler that checks the context + var capturedOrg *domain.Organization + testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if org := r.Context().Value(OrganizationKey); org != nil { + if o, ok := org.(domain.Organization); ok { + capturedOrg = &o + } + } + w.WriteHeader(http.StatusOK) + }) + + // Create request without API key + req := httptest.NewRequest(http.MethodGet, "/test", nil) + rec := httptest.NewRecorder() + + // Execute + authMiddleware.Handler(testHandler).ServeHTTP(rec, req) + + // Assert + assert.Equal(t, http.StatusOK, rec.Code) + assert.Nil(t, capturedOrg, "Organization should not be set without API key") + mockCache.AssertExpectations(t) +} + +func TestAuthMiddleware_Handler_WithValidJWT(t *testing.T) { + // Setup + mockCache := new(MockCache) + authMiddleware := NewAuth(mockCache) + + // The middleware always hashes the API key (even if empty) and calls the cache + emptyKeyHash := hash.String("") + mockCache.On("OrganizationByAPIKey", emptyKeyHash).Return(nil) + + userID := "user-123" + token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ + "sub": userID, + }) + + // Create a test handler that checks the context + var capturedUser string + testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if user := r.Context().Value(UserKey); user != nil { + if u, ok := user.(string); ok { + capturedUser = u + } + } + w.WriteHeader(http.StatusOK) + }) + + // Create request with JWT token in context + req := httptest.NewRequest(http.MethodGet, "/test", nil) + ctx := context.WithValue(req.Context(), mw.ContextKey{}, token) + req = req.WithContext(ctx) + + rec := httptest.NewRecorder() + + // Execute + authMiddleware.Handler(testHandler).ServeHTTP(rec, req) + + // Assert + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, userID, capturedUser) +} + +func TestAuthMiddleware_Handler_APIKeyErrorHandling(t *testing.T) { + // Setup + mockCache := new(MockCache) + authMiddleware := NewAuth(mockCache) + + testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + // Create request with invalid API key type in context + req := httptest.NewRequest(http.MethodGet, "/test", nil) + ctx := context.WithValue(req.Context(), ApiKey, 12345) // Invalid type + req = req.WithContext(ctx) + + rec := httptest.NewRecorder() + + // Execute + authMiddleware.Handler(testHandler).ServeHTTP(rec, req) + + // Assert + assert.Equal(t, http.StatusInternalServerError, rec.Code) + assert.Contains(t, rec.Body.String(), "Invalid API Key format") +} + +func TestAuthMiddleware_Handler_JWTErrorHandling(t *testing.T) { + // Setup + mockCache := new(MockCache) + authMiddleware := NewAuth(mockCache) + + testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + // Create request with invalid JWT token type in context + req := httptest.NewRequest(http.MethodGet, "/test", nil) + ctx := context.WithValue(req.Context(), mw.ContextKey{}, "not-a-token") // Invalid type + req = req.WithContext(ctx) + + rec := httptest.NewRecorder() + + // Execute + authMiddleware.Handler(testHandler).ServeHTTP(rec, req) + + // Assert + assert.Equal(t, http.StatusInternalServerError, rec.Code) + assert.Contains(t, rec.Body.String(), "Invalid JWT token format") +} + +func TestAuthMiddleware_Handler_BothJWTAndAPIKey(t *testing.T) { + // Setup + mockCache := new(MockCache) + authMiddleware := NewAuth(mockCache) + + orgID := uuid.New() + expectedOrg := &domain.Organization{ + BaseAggregate: eventsourced.BaseAggregate{ + ID: eventsourced.IdFromString(orgID.String()), + }, + Name: "Test Organization", + } + + userID := "user-123" + apiKey := "test-api-key-123" + hashedKey := hash.String(apiKey) + + token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ + "sub": userID, + }) + + mockCache.On("OrganizationByAPIKey", hashedKey).Return(expectedOrg) + + // Create a test handler that checks both user and organization in context + var capturedUser string + var capturedOrg *domain.Organization + testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if user := r.Context().Value(UserKey); user != nil { + if u, ok := user.(string); ok { + capturedUser = u + } + } + if org := r.Context().Value(OrganizationKey); org != nil { + if o, ok := org.(domain.Organization); ok { + capturedOrg = &o + } + } + w.WriteHeader(http.StatusOK) + }) + + // Create request with both JWT and API key in context + req := httptest.NewRequest(http.MethodGet, "/test", nil) + ctx := context.WithValue(req.Context(), mw.ContextKey{}, token) + ctx = context.WithValue(ctx, ApiKey, apiKey) + req = req.WithContext(ctx) + + rec := httptest.NewRecorder() + + // Execute + authMiddleware.Handler(testHandler).ServeHTTP(rec, req) + + // Assert + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, userID, capturedUser) + require.NotNil(t, capturedOrg) + assert.Equal(t, expectedOrg.Name, capturedOrg.Name) + mockCache.AssertExpectations(t) +} + +func TestUserFromContext(t *testing.T) { + tests := []struct { + name string + ctx context.Context + expected string + }{ + { + name: "with valid user", + ctx: context.WithValue(context.Background(), UserKey, "user-123"), + expected: "user-123", + }, + { + name: "without user", + ctx: context.Background(), + expected: "", + }, + { + name: "with invalid type", + ctx: context.WithValue(context.Background(), UserKey, 123), + expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := UserFromContext(tt.ctx) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestOrganizationFromContext(t *testing.T) { + orgID := uuid.New() + org := domain.Organization{ + BaseAggregate: eventsourced.BaseAggregate{ + ID: eventsourced.IdFromString(orgID.String()), + }, + Name: "Test Org", + } + + tests := []struct { + name string + ctx context.Context + expected string + }{ + { + name: "with valid organization", + ctx: context.WithValue(context.Background(), OrganizationKey, org), + expected: orgID.String(), + }, + { + name: "without organization", + ctx: context.Background(), + expected: "", + }, + { + name: "with invalid type", + ctx: context.WithValue(context.Background(), OrganizationKey, "not-an-org"), + expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := OrganizationFromContext(tt.ctx) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestAuthMiddleware_Directive_RequiresUser(t *testing.T) { + mockCache := new(MockCache) + authMiddleware := NewAuth(mockCache) + + requireUser := true + + // Test with user present + ctx := context.WithValue(context.Background(), UserKey, "user-123") + _, err := authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) { + return "success", nil + }, &requireUser, nil) + assert.NoError(t, err) + + // Test without user + ctx = context.Background() + _, err = authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) { + return "success", nil + }, &requireUser, nil) + assert.Error(t, err) + assert.Contains(t, err.Error(), "no user available in request") +} + +func TestAuthMiddleware_Directive_RequiresOrganization(t *testing.T) { + mockCache := new(MockCache) + authMiddleware := NewAuth(mockCache) + + requireOrg := true + orgID := uuid.New() + org := domain.Organization{ + BaseAggregate: eventsourced.BaseAggregate{ + ID: eventsourced.IdFromString(orgID.String()), + }, + Name: "Test Org", + } + + // Test with organization present + ctx := context.WithValue(context.Background(), OrganizationKey, org) + _, err := authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) { + return "success", nil + }, nil, &requireOrg) + assert.NoError(t, err) + + // Test without organization + ctx = context.Background() + _, err = authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) { + return "success", nil + }, nil, &requireOrg) + assert.Error(t, err) + assert.Contains(t, err.Error(), "no organization available in request") +} + +func TestAuthMiddleware_Directive_RequiresBoth(t *testing.T) { + mockCache := new(MockCache) + authMiddleware := NewAuth(mockCache) + + requireUser := true + requireOrg := true + orgID := uuid.New() + org := domain.Organization{ + BaseAggregate: eventsourced.BaseAggregate{ + ID: eventsourced.IdFromString(orgID.String()), + }, + Name: "Test Org", + } + + // Test with both present + ctx := context.WithValue(context.Background(), UserKey, "user-123") + ctx = context.WithValue(ctx, OrganizationKey, org) + _, err := authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) { + return "success", nil + }, &requireUser, &requireOrg) + assert.NoError(t, err) + + // Test with only user + ctx = context.WithValue(context.Background(), UserKey, "user-123") + _, err = authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) { + return "success", nil + }, &requireUser, &requireOrg) + assert.Error(t, err) + + // Test with only organization + ctx = context.WithValue(context.Background(), OrganizationKey, org) + _, err = authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) { + return "success", nil + }, &requireUser, &requireOrg) + assert.Error(t, err) +} + +func TestAuthMiddleware_Directive_NoRequirements(t *testing.T) { + mockCache := new(MockCache) + authMiddleware := NewAuth(mockCache) + + // Test with no requirements + ctx := context.Background() + result, err := authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) { + return "success", nil + }, nil, nil) + assert.NoError(t, err) + assert.Equal(t, "success", result) +}