package auth

import (
	"crypto/rand"
	"crypto/rsa"
	"encoding/json"
	"fmt"
	"time"

	"github.com/google/uuid"
	"github.com/lestrrat-go/jwx/v2/jwa"
	"github.com/lestrrat-go/jwx/v2/jwk"
	"github.com/lestrrat-go/jwx/v2/jws"
	"github.com/lestrrat-go/jwx/v2/jwt"
)

const (
	// TokenExpiry is the default token expiration time, measured from the
	// moment a token is signed.
	TokenExpiry = 2 * time.Hour
)

// JWTService handles JWT signing and JWKS generation.
//
// It owns one RSA key pair generated by NewJWTService. The private half signs
// tokens and the public half is published through GetJWKS.
//
// Fields:
//   - privateKey: the raw RSA private key.
//   - signingKey: privateKey wrapped as a JWK tagged with the same key ID as
//     the public key in jwkSet. It is built once so SignToken does not rebuild
//     it on every call.
//   - jwkSet: the public JWK set served by GetJWKS.
//   - issuer, audience, adminClaim, emailClaim: configuration from NewJWTService.
type JWTService struct {
	privateKey *rsa.PrivateKey
	signingKey jwk.Key
	jwkSet     jwk.Set
	issuer     string
	audience   string
	adminClaim string
	emailClaim string
}

// NewJWTService creates a new JWT service with a freshly generated RSA key pair.
//
// issuer is written as the "iss" claim of every token. audience is used as the
// "aud" claim of access tokens. adminClaim and emailClaim are custom claim
// names. emailClaim is populated with the user's email by SignAccessToken and
// SignIDToken; adminClaim is only exposed through AdminClaim.
//
// Every call generates a new key and a new random key ID. Tokens signed by
// one instance therefore will not match the JWKS of another.
func NewJWTService(issuer, audience, adminClaim, emailClaim string) (*JWTService, error) {
	// Generate RSA 2048-bit key pair.
	privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		return nil, fmt.Errorf("generate RSA key: %w", err)
	}

	// Wrap the private key as a JWK so key metadata can be attached before the
	// public half is derived from it.
	key, err := jwk.FromRaw(privateKey)
	if err != nil {
		return nil, fmt.Errorf("create JWK from private key: %w", err)
	}

	// Set key metadata.
	keyID := uuid.New().String()
	if err := key.Set(jwk.KeyIDKey, keyID); err != nil {
		return nil, fmt.Errorf("set key ID: %w", err)
	}
	if err := key.Set(jwk.AlgorithmKey, jwa.RS256); err != nil {
		return nil, fmt.Errorf("set algorithm: %w", err)
	}
	if err := key.Set(jwk.KeyUsageKey, "sig"); err != nil {
		return nil, fmt.Errorf("set key usage: %w", err)
	}

	// Derive the public key (carrying the metadata above) for the JWKS.
	publicKey, err := key.PublicKey()
	if err != nil {
		return nil, fmt.Errorf("get public key: %w", err)
	}

	jwkSet := jwk.NewSet()
	if err := jwkSet.AddKey(publicKey); err != nil {
		return nil, fmt.Errorf("add key to set: %w", err)
	}

	// Build the signing key once. It carries only the key ID, exactly like
	// the key SignToken used to rebuild on every call.
	signingKey, err := newSigningKey(privateKey, keyID)
	if err != nil {
		return nil, err
	}

	return &JWTService{
		privateKey: privateKey,
		signingKey: signingKey,
		jwkSet:     jwkSet,
		issuer:     issuer,
		audience:   audience,
		adminClaim: adminClaim,
		emailClaim: emailClaim,
	}, nil
}

// newSigningKey wraps privateKey as a JWK tagged with keyID. keyID is expected
// to match the key ID of the public key published in the JWKS.
func newSigningKey(privateKey *rsa.PrivateKey, keyID string) (jwk.Key, error) {
	key, err := jwk.FromRaw(privateKey)
	if err != nil {
		return nil, fmt.Errorf("create signing key: %w", err)
	}
	if err := key.Set(jwk.KeyIDKey, keyID); err != nil {
		return nil, fmt.Errorf("set key ID: %w", err)
	}
	return key, nil
}

// signingJWK returns the JWK used to sign tokens.
//
// It prefers the key cached by NewJWTService. For an instance built some
// other way, it derives the key from privateKey and the first key in jwkSet.
// It returns an error, rather than panicking, when neither is available.
func (s *JWTService) signingJWK() (jwk.Key, error) {
	if s.signingKey != nil {
		return s.signingKey, nil
	}
	if s.privateKey == nil || s.jwkSet == nil {
		return nil, fmt.Errorf("signing key not initialized: construct JWTService with NewJWTService")
	}
	pubKey, ok := s.jwkSet.Key(0)
	if !ok {
		return nil, fmt.Errorf("get key ID: JWKS contains no keys")
	}
	return newSigningKey(s.privateKey, pubKey.KeyID())
}

// SignToken creates an RS256-signed JWT with the given claims.
//
// The builder first gets "iss" (the service issuer), "iat" (now) and "exp"
// (now + TokenExpiry). Each entry of claims is added after those, so a
// same-named entry is applied last.
func (s *JWTService) SignToken(claims map[string]interface{}) (string, error) {
	now := time.Now()
	builder := jwt.NewBuilder().
		Issuer(s.issuer).
		IssuedAt(now).
		Expiration(now.Add(TokenExpiry))

	for name, value := range claims {
		builder.Claim(name, value)
	}

	token, err := builder.Build()
	if err != nil {
		return "", fmt.Errorf("build token: %w", err)
	}

	signingKey, err := s.signingJWK()
	if err != nil {
		return "", err
	}

	signed, err := jwt.Sign(token, jwt.WithKey(jwa.RS256, signingKey))
	if err != nil {
		return "", fmt.Errorf("sign token: %w", err)
	}
	return string(signed), nil
}

// mergeCustomClaims copies every entry of each map in customClaims into claims.
// Later maps win on key collisions. It then sets the configured email claim
// to email, so a custom claim cannot override it.
func (s *JWTService) mergeCustomClaims(claims map[string]interface{}, customClaims []map[string]interface{}, email string) {
	for _, cc := range customClaims {
		for k, v := range cc {
			claims[k] = v
		}
	}
	claims[s.emailClaim] = email
}

// SignAccessToken creates an access token for the given subject.
//
// The token carries "sub" = subject, "aud" = [service audience] and
// "azp" = clientID. The custom claims are merged in next and may override
// those. The configured email claim is set to email last.
func (s *JWTService) SignAccessToken(subject, clientID, email string, customClaims []map[string]interface{}) (string, error) {
	claims := map[string]interface{}{
		"sub": subject,
		"aud": []string{s.audience},
		"azp": clientID,
	}
	s.mergeCustomClaims(claims, customClaims, email)
	return s.SignToken(claims)
}

// SignIDToken creates an ID token for the given subject.
//
// The token carries:
//   - "sub" = subject
//   - "aud" = clientID (a single string, not a list)
//   - "azp" = clientID
//   - the profile claims (name, given_name, family_name, email, picture)
//   - "nonce", only when nonce is non-empty
//
// The custom claims are merged in next and may override those. The
// configured email claim is set to email last.
func (s *JWTService) SignIDToken(subject, clientID, nonce, email, name, givenName, familyName, picture string, customClaims []map[string]interface{}) (string, error) {
	claims := map[string]interface{}{
		"sub":         subject,
		"aud":         clientID,
		"azp":         clientID,
		"name":        name,
		"given_name":  givenName,
		"family_name": familyName,
		"email":       email,
		"picture":     picture,
	}
	if nonce != "" {
		claims["nonce"] = nonce
	}
	s.mergeCustomClaims(claims, customClaims, email)
	return s.SignToken(claims)
}

// GetJWKS returns the public JSON Web Key Set as JSON bytes.
func (s *JWTService) GetJWKS() ([]byte, error) {
	return json.Marshal(s.jwkSet)
}
// DecodeToken decodes a JWT without verifying the signature func (s *JWTService) DecodeToken(tokenString string) (map[string]interface{}, error) { // Parse without verification msg, err := jws.Parse([]byte(tokenString)) if err != nil { return nil, fmt.Errorf("parse token: %w", err) } var claims map[string]interface{} if err := json.Unmarshal(msg.Payload(), &claims); err != nil { return nil, fmt.Errorf("unmarshal claims: %w", err) } return claims, nil } // Issuer returns the issuer URL func (s *JWTService) Issuer() string { return s.issuer } // Audience returns the audience func (s *JWTService) Audience() string { return s.audience } // AdminClaim returns the admin custom claim key func (s *JWTService) AdminClaim() string { return s.adminClaim } // EmailClaim returns the email custom claim key func (s *JWTService) EmailClaim() string { return s.emailClaim }