package main import ( "context" "testing" "github.com/99designs/gqlgen/graphql/handler/transport" "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "gitlab.com/unboundsoftware/eventsourced/eventsourced" "gitlab.com/unboundsoftware/schemas/domain" "gitlab.com/unboundsoftware/schemas/hash" "gitlab.com/unboundsoftware/schemas/middleware" ) // MockCache is a mock implementation for testing type MockCache struct { organizations map[string]*domain.Organization } func (m *MockCache) OrganizationByAPIKey(apiKey string) *domain.Organization { return m.organizations[apiKey] } func TestWebSocketInitFunc_WithValidAPIKey(t *testing.T) { // Setup orgID := uuid.New() org := &domain.Organization{ BaseAggregate: eventsourced.BaseAggregate{ ID: eventsourced.IdFromString(orgID.String()), }, Name: "Test Organization", } apiKey := "test-api-key-123" hashedKey := hash.String(apiKey) mockCache := &MockCache{ organizations: map[string]*domain.Organization{ hashedKey: org, }, } // Create InitFunc (simulating the WebSocket InitFunc logic) initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) { // Extract API key from WebSocket connection_init payload if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" { ctx = context.WithValue(ctx, middleware.ApiKey, apiKey) // Look up organization by API key if organization := mockCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil { ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization) } } return ctx, &initPayload, nil } // Test ctx := context.Background() initPayload := transport.InitPayload{ "X-Api-Key": apiKey, } resultCtx, resultPayload, err := initFunc(ctx, initPayload) // Assert require.NoError(t, err) require.NotNil(t, resultPayload) // Check API key is in context if value := resultCtx.Value(middleware.ApiKey); value != nil { assert.Equal(t, apiKey, value.(string)) } else { t.Fatal("API key not found in context") } // Check organization is in context if value := resultCtx.Value(middleware.OrganizationKey); value != nil { capturedOrg, ok := value.(domain.Organization) require.True(t, ok, "Organization should be of correct type") assert.Equal(t, org.Name, capturedOrg.Name) assert.Equal(t, org.ID.String(), capturedOrg.ID.String()) } else { t.Fatal("Organization not found in context") } } func TestWebSocketInitFunc_WithInvalidAPIKey(t *testing.T) { // Setup mockCache := &MockCache{ organizations: map[string]*domain.Organization{}, } apiKey := "invalid-api-key" // Create InitFunc initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) { // Extract API key from WebSocket connection_init payload if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" { ctx = context.WithValue(ctx, middleware.ApiKey, apiKey) // Look up organization by API key if organization := mockCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil { ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization) } } return ctx, &initPayload, nil } // Test ctx := context.Background() initPayload := transport.InitPayload{ "X-Api-Key": apiKey, } resultCtx, resultPayload, err := initFunc(ctx, initPayload) // Assert require.NoError(t, err) require.NotNil(t, resultPayload) // Check API key is in context if value := resultCtx.Value(middleware.ApiKey); value != nil { assert.Equal(t, apiKey, value.(string)) } else { t.Fatal("API key not found in context") } // Check organization is NOT in context (since API key is invalid) value := resultCtx.Value(middleware.OrganizationKey) assert.Nil(t, value, "Organization should not be set for invalid API key") } func TestWebSocketInitFunc_WithoutAPIKey(t *testing.T) { // Setup mockCache := &MockCache{ organizations: map[string]*domain.Organization{}, } // Create InitFunc initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) { // Extract API key from WebSocket connection_init payload if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" { ctx = context.WithValue(ctx, middleware.ApiKey, apiKey) // Look up organization by API key if organization := mockCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil { ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization) } } return ctx, &initPayload, nil } // Test ctx := context.Background() initPayload := transport.InitPayload{} resultCtx, resultPayload, err := initFunc(ctx, initPayload) // Assert require.NoError(t, err) require.NotNil(t, resultPayload) // Check API key is NOT in context value := resultCtx.Value(middleware.ApiKey) assert.Nil(t, value, "API key should not be set when not provided") // Check organization is NOT in context value = resultCtx.Value(middleware.OrganizationKey) assert.Nil(t, value, "Organization should not be set when API key is not provided") } func TestWebSocketInitFunc_WithEmptyAPIKey(t *testing.T) { // Setup mockCache := &MockCache{ organizations: map[string]*domain.Organization{}, } // Create InitFunc initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) { // Extract API key from WebSocket connection_init payload if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" { ctx = context.WithValue(ctx, middleware.ApiKey, apiKey) // Look up organization by API key if organization := mockCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil { ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization) } } return ctx, &initPayload, nil } // Test ctx := context.Background() initPayload := transport.InitPayload{ "X-Api-Key": "", // Empty string } resultCtx, resultPayload, err := initFunc(ctx, initPayload) // Assert require.NoError(t, err) require.NotNil(t, resultPayload) // Check API key is NOT in context (because empty string fails the condition) value := resultCtx.Value(middleware.ApiKey) assert.Nil(t, value, "API key should not be set when empty") // Check organization is NOT in context value = resultCtx.Value(middleware.OrganizationKey) assert.Nil(t, value, "Organization should not be set when API key is empty") } func TestWebSocketInitFunc_WithWrongTypeAPIKey(t *testing.T) { // Setup mockCache := &MockCache{ organizations: map[string]*domain.Organization{}, } // Create InitFunc initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) { // Extract API key from WebSocket connection_init payload if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" { ctx = context.WithValue(ctx, middleware.ApiKey, apiKey) // Look up organization by API key if organization := mockCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil { ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization) } } return ctx, &initPayload, nil } // Test ctx := context.Background() initPayload := transport.InitPayload{ "X-Api-Key": 12345, // Wrong type (int instead of string) } resultCtx, resultPayload, err := initFunc(ctx, initPayload) // Assert require.NoError(t, err) require.NotNil(t, resultPayload) // Check API key is NOT in context (type assertion fails) value := resultCtx.Value(middleware.ApiKey) assert.Nil(t, value, "API key should not be set when wrong type") // Check organization is NOT in context value = resultCtx.Value(middleware.OrganizationKey) assert.Nil(t, value, "Organization should not be set when API key has wrong type") } func TestWebSocketInitFunc_WithMultipleOrganizations(t *testing.T) { // Setup - create multiple organizations org1ID := uuid.New() org1 := &domain.Organization{ BaseAggregate: eventsourced.BaseAggregate{ ID: eventsourced.IdFromString(org1ID.String()), }, Name: "Organization 1", } org2ID := uuid.New() org2 := &domain.Organization{ BaseAggregate: eventsourced.BaseAggregate{ ID: eventsourced.IdFromString(org2ID.String()), }, Name: "Organization 2", } apiKey1 := "api-key-org-1" apiKey2 := "api-key-org-2" hashedKey1 := hash.String(apiKey1) hashedKey2 := hash.String(apiKey2) mockCache := &MockCache{ organizations: map[string]*domain.Organization{ hashedKey1: org1, hashedKey2: org2, }, } // Create InitFunc initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) { // Extract API key from WebSocket connection_init payload if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" { ctx = context.WithValue(ctx, middleware.ApiKey, apiKey) // Look up organization by API key if organization := mockCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil { ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization) } } return ctx, &initPayload, nil } // Test with first API key ctx1 := context.Background() initPayload1 := transport.InitPayload{ "X-Api-Key": apiKey1, } resultCtx1, _, err := initFunc(ctx1, initPayload1) require.NoError(t, err) if value := resultCtx1.Value(middleware.OrganizationKey); value != nil { capturedOrg, ok := value.(domain.Organization) require.True(t, ok) assert.Equal(t, org1.Name, capturedOrg.Name) assert.Equal(t, org1.ID.String(), capturedOrg.ID.String()) } else { t.Fatal("Organization 1 not found in context") } // Test with second API key ctx2 := context.Background() initPayload2 := transport.InitPayload{ "X-Api-Key": apiKey2, } resultCtx2, _, err := initFunc(ctx2, initPayload2) require.NoError(t, err) if value := resultCtx2.Value(middleware.OrganizationKey); value != nil { capturedOrg, ok := value.(domain.Organization) require.True(t, ok) assert.Equal(t, org2.Name, capturedOrg.Name) assert.Equal(t, org2.ID.String(), capturedOrg.ID.String()) } else { t.Fatal("Organization 2 not found in context") } }