From bb0c08be06478e486d25a2849db71563cc044fdf Mon Sep 17 00:00:00 2001 From: Joakim Olsson Date: Thu, 20 Nov 2025 08:09:00 +0100 Subject: [PATCH] fix: enhance API key handling and logging in middleware Refactor API key processing to improve clarity and reduce code duplication. Introduce detailed logging for schema updates and initializations, capturing relevant context information. Use background context for async operations to avoid blocking. Implement organization lookup logic in the WebSocket init function for consistent API key handling across connections. --- cmd/service/service.go | 19 ++++++++++++++ graph/schema.resolvers.go | 52 +++++++++++++++++++++++++++++++++++---- middleware/auth.go | 4 ++- 3 files changed, 69 insertions(+), 6 deletions(-) diff --git a/cmd/service/service.go b/cmd/service/service.go index dbefcf2..147d612 100644 --- a/cmd/service/service.go +++ b/cmd/service/service.go @@ -30,6 +30,7 @@ import ( "gitlab.com/unboundsoftware/schemas/domain" "gitlab.com/unboundsoftware/schemas/graph" "gitlab.com/unboundsoftware/schemas/graph/generated" + "gitlab.com/unboundsoftware/schemas/hash" "gitlab.com/unboundsoftware/schemas/logging" "gitlab.com/unboundsoftware/schemas/middleware" "gitlab.com/unboundsoftware/schemas/monitoring" @@ -210,6 +211,24 @@ func start(closeEvents chan error, logger *slog.Logger, connectToAmqpFunc func(u srv.AddTransport(transport.Websocket{ KeepAlivePingInterval: 10 * time.Second, + InitFunc: func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) { + // Extract API key from WebSocket connection_init payload + if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" { + logger.Info("WebSocket connection with API key", "has_key", true) + ctx = context.WithValue(ctx, middleware.ApiKey, apiKey) + + // Look up organization by API key (same logic as auth middleware) + if organization := serviceCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil { + logger.Info("WebSocket: Organization found for API key", "org_id", organization.ID.String()) + ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization) + } else { + logger.Warn("WebSocket: No organization found for API key") + } + } else { + logger.Info("WebSocket connection without API key") + } + return ctx, &initPayload, nil + }, }) srv.AddTransport(transport.Options{}) srv.AddTransport(transport.GET{}) diff --git a/graph/schema.resolvers.go b/graph/schema.resolvers.go index 6c68229..00f6202 100644 --- a/graph/schema.resolvers.go +++ b/graph/schema.resolvers.go @@ -123,6 +123,13 @@ func (r *mutationResolver) UpdateSubGraph(ctx context.Context, input model.Input // Publish schema update to subscribers go func() { services, lastUpdate := r.Cache.Services(orgId, input.Ref, "") + r.Logger.Info("Publishing schema update after subgraph change", + "ref", input.Ref, + "orgId", orgId, + "lastUpdate", lastUpdate, + "servicesCount", len(services), + ) + subGraphs := make([]*model.SubGraph, len(services)) for i, id := range services { sg, err := r.fetchSubGraph(context.Background(), id) @@ -149,12 +156,21 @@ func (r *mutationResolver) UpdateSubGraph(ctx context.Context, input model.Input } // Publish to all subscribers of this ref - r.PubSub.Publish(input.Ref, &model.SchemaUpdate{ + update := &model.SchemaUpdate{ Ref: input.Ref, ID: lastUpdate, SubGraphs: subGraphs, CosmoRouterConfig: &cosmoConfig, - }) + } + + r.Logger.Info("Publishing schema update to subscribers", + "ref", update.Ref, + "id", update.ID, + "subGraphsCount", len(update.SubGraphs), + "cosmoConfigLength", len(cosmoConfig), + ) + + r.PubSub.Publish(input.Ref, update) }() return r.toGqlSubGraph(subGraph), nil @@ -225,8 +241,15 @@ func (r *queryResolver) Supergraph(ctx context.Context, ref string, isAfter *str // SchemaUpdates is the resolver for the schemaUpdates field. func (r *subscriptionResolver) SchemaUpdates(ctx context.Context, ref string) (<-chan *model.SchemaUpdate, error) { orgId := middleware.OrganizationFromContext(ctx) + + r.Logger.Info("SchemaUpdates subscription started", + "ref", ref, + "orgId", orgId, + ) + _, err := r.apiKeyCanAccessRef(ctx, ref, false) if err != nil { + r.Logger.Error("API key cannot access ref", "error", err, "ref", ref) return nil, err } @@ -235,12 +258,22 @@ func (r *subscriptionResolver) SchemaUpdates(ctx context.Context, ref string) (< // Send initial state immediately go func() { + // Use background context for async operation + bgCtx := context.Background() + services, lastUpdate := r.Cache.Services(orgId, ref, "") + r.Logger.Info("Preparing initial schema update", + "ref", ref, + "orgId", orgId, + "lastUpdate", lastUpdate, + "servicesCount", len(services), + ) + subGraphs := make([]*model.SubGraph, len(services)) for i, id := range services { - sg, err := r.fetchSubGraph(ctx, id) + sg, err := r.fetchSubGraph(bgCtx, id) if err != nil { - r.Logger.Error("fetch subgraph for initial update", "error", err) + r.Logger.Error("fetch subgraph for initial update", "error", err, "id", id) continue } subGraphs[i] = &model.SubGraph{ @@ -262,12 +295,21 @@ func (r *subscriptionResolver) SchemaUpdates(ctx context.Context, ref string) (< } // Send initial update - ch <- &model.SchemaUpdate{ + update := &model.SchemaUpdate{ Ref: ref, ID: lastUpdate, SubGraphs: subGraphs, CosmoRouterConfig: &cosmoConfig, } + + r.Logger.Info("Sending initial schema update", + "ref", update.Ref, + "id", update.ID, + "subGraphsCount", len(update.SubGraphs), + "cosmoConfigLength", len(cosmoConfig), + ) + + ch <- update }() // Clean up subscription when context is done diff --git a/middleware/auth.go b/middleware/auth.go index 6704946..bbe3a9e 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -49,7 +49,9 @@ func (m *AuthMiddleware) Handler(next http.Handler) http.Handler { _, _ = w.Write([]byte("Invalid API Key format")) return } - if organization := m.cache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil { + hashedKey := hash.String(apiKey) + organization := m.cache.OrganizationByAPIKey(hashedKey) + if organization != nil { ctx = context.WithValue(ctx, OrganizationKey, *organization) }