package auth

import (
	"crypto/hmac"
	"crypto/sha256"
	"encoding/hex"
	"net/http"
	"net/http/httptest"
	"testing"

	"github.com/stretchr/testify/assert"
)

// sign returns the hex-encoded HMAC-SHA256 of header keyed by key. The tests
// put this value in the "user-signature" header to authenticate the "user"
// header.
func sign(key, header string) string {
	mac := hmac.New(sha256.New, []byte(key))
	// hash.Hash.Write is documented to never return an error.
	mac.Write([]byte(header))
	return hex.EncodeToString(mac.Sum(nil))
}

// TestUserMiddleware checks that UserMiddleware forwards a request only when
// its "user" header carries a matching "user-signature", rejects it with 401
// otherwise, and skips verification entirely when given an empty key.
func TestUserMiddleware(t *testing.T) {
	const (
		key    = "secret"
		header = `{"email":"jim@example.org","roles":["admin"]}`
	)

	// newRequest builds a POST /query carrying the given "user" header and,
	// when signature is non-empty, the given "user-signature" header.
	newRequest := func(user, signature string) *http.Request {
		req := httptest.NewRequest(http.MethodPost, "/query", nil)
		req.Header.Set("user", user)
		if signature != "" {
			req.Header.Set("user-signature", signature)
		}
		return req
	}

	// capture returns a handler that records that it was reached and checks
	// the user found in the request context. It takes the subtest's t so that
	// failures are attributed to the subtest that triggered them. When
	// requireUser is true a missing user is a failure; otherwise the user is
	// only checked if one is present.
	capture := func(t *testing.T, called *bool, requireUser bool) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			*called = true
			u := FromContext(r.Context())
			if u == nil {
				if requireUser {
					t.Error("no user injected into the request context")
				}
				return
			}
			assert.Equal(t, "jim@example.org", u.Email)
			assert.True(t, u.HasRole("admin"))
		})
	}

	t.Run("valid signature passes and injects user", func(t *testing.T) {
		called := false
		rw := httptest.NewRecorder()
		UserMiddleware([]byte(key))(capture(t, &called, true)).ServeHTTP(rw, newRequest(header, sign(key, header)))
		assert.True(t, called)
		assert.Equal(t, http.StatusOK, rw.Code)
	})

	// Every case below must be rejected before the wrapped handler runs.
	rejected := []struct {
		name, user, signature string
	}{
		{"invalid signature is rejected", header, "deadbeef"},
		{"missing signature when key set is rejected", header, ""},
		// A well-formed HMAC over a different payload must not authenticate
		// this one (guards against checks that only look at length/format).
		{"signature over a different payload is rejected", `{"email":"mallory@example.org","roles":["admin"]}`, sign(key, header)},
		{"signature made with a different key is rejected", header, sign("not-"+key, header)},
	}
	for _, tc := range rejected {
		t.Run(tc.name, func(t *testing.T) {
			called := false
			rw := httptest.NewRecorder()
			UserMiddleware([]byte(key))(capture(t, &called, false)).ServeHTTP(rw, newRequest(tc.user, tc.signature))
			assert.False(t, called, "handler must not run for a rejected request")
			assert.Equal(t, http.StatusUnauthorized, rw.Code)
		})
	}

	t.Run("empty key skips verification (dev only)", func(t *testing.T) {
		called := false
		rw := httptest.NewRecorder()
		UserMiddleware(nil)(capture(t, &called, false)).ServeHTTP(rw, newRequest(header, ""))
		assert.True(t, called)
		assert.Equal(t, http.StatusOK, rw.Code)
	})
}

// TestFromContextNil checks that FromContext returns nil for a request context
// that has no user attached.
func TestFromContextNil(t *testing.T) {
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	assert.Nil(t, FromContext(req.Context()))
}