fix: migrate to go-jwt-middleware v3 API
schemas / check-release (pull_request) Successful in 1m57s
schemas / vulnerabilities (pull_request) Successful in 2m48s
schemas / check (pull_request) Successful in 8m17s
pre-commit / pre-commit (pull_request) Successful in 11m38s
schemas / build (pull_request) Successful in 5m31s
schemas / deploy-prod (pull_request) Has been skipped
schemas / check-release (pull_request) Successful in 1m57s
schemas / vulnerabilities (pull_request) Successful in 2m48s
schemas / check (pull_request) Successful in 8m17s
pre-commit / pre-commit (pull_request) Successful in 11m38s
schemas / build (pull_request) Successful in 5m31s
schemas / deploy-prod (pull_request) Has been skipped
- Use validator and jwks packages for JWT validation - Replace manual JWKS caching with jwks.NewCachingProvider - Add CustomClaims struct for https://unbound.se/roles claim - Rename TokenFromContext to ClaimsFromContext - Update middleware/auth.go to use new claims structure - Update tests to use core.SetClaims and validator.ValidatedClaims Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
+50
-141
@@ -2,39 +2,34 @@ package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
"log"
|
||||
"net/url"
|
||||
|
||||
mw "github.com/auth0/go-jwt-middleware/v3"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/pkg/errors"
|
||||
jwtmiddleware "github.com/auth0/go-jwt-middleware/v3"
|
||||
"github.com/auth0/go-jwt-middleware/v3/jwks"
|
||||
"github.com/auth0/go-jwt-middleware/v3/validator"
|
||||
)
|
||||
|
||||
// CustomClaims contains custom claims from the JWT token.
|
||||
type CustomClaims struct {
|
||||
Roles []string `json:"https://unbound.se/roles"`
|
||||
}
|
||||
|
||||
// Validate implements the validator.CustomClaims interface.
|
||||
func (c CustomClaims) Validate(_ context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type Auth0 struct {
|
||||
domain string
|
||||
audience string
|
||||
client *http.Client
|
||||
cache JwksCache
|
||||
}
|
||||
|
||||
func NewAuth0(audience, domain string, strictSsl bool) *Auth0 {
|
||||
customTransport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
customTransport.TLSClientConfig = &tls.Config{InsecureSkipVerify: !strictSsl}
|
||||
client := &http.Client{Transport: customTransport}
|
||||
|
||||
func NewAuth0(audience, domain string, _ bool) *Auth0 {
|
||||
return &Auth0{
|
||||
domain: domain,
|
||||
audience: audience,
|
||||
client: client,
|
||||
cache: JwksCache{
|
||||
RWMutex: &sync.RWMutex{},
|
||||
cache: make(map[string]cacheItem),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -42,133 +37,47 @@ type Response struct {
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
type Jwks struct {
|
||||
Keys []JSONWebKeys `json:"keys"`
|
||||
}
|
||||
|
||||
type JSONWebKeys struct {
|
||||
Kty string `json:"kty"`
|
||||
Kid string `json:"kid"`
|
||||
Use string `json:"use"`
|
||||
N string `json:"n"`
|
||||
E string `json:"e"`
|
||||
X5c []string `json:"x5c"`
|
||||
}
|
||||
|
||||
func (a *Auth0) ValidationKeyGetter() func(token *jwt.Token) (interface{}, error) {
|
||||
return func(token *jwt.Token) (interface{}, error) {
|
||||
// Verify 'aud' claim
|
||||
|
||||
cert, err := a.getPemCert(token)
|
||||
if err != nil {
|
||||
panic(err.Error())
|
||||
}
|
||||
|
||||
result, _ := jwt.ParseRSAPublicKeyFromPEM([]byte(cert))
|
||||
return result, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Auth0) Middleware() *mw.JWTMiddleware {
|
||||
func (a *Auth0) Middleware() *jwtmiddleware.JWTMiddleware {
|
||||
issuer := fmt.Sprintf("https://%s/", a.domain)
|
||||
jwtMiddleware := mw.New(func(ctx context.Context, token string) (interface{}, error) {
|
||||
jwtToken, err := jwt.Parse(token, a.ValidationKeyGetter(), jwt.WithAudience(a.audience), jwt.WithIssuer(issuer))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if _, ok := jwtToken.Method.(*jwt.SigningMethodRSA); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", jwtToken.Header["alg"])
|
||||
}
|
||||
return jwtToken, nil
|
||||
},
|
||||
mw.WithTokenExtractor(func(r *http.Request) (string, error) {
|
||||
token := r.Header.Get("Authorization")
|
||||
if strings.HasPrefix(token, "Bearer ") {
|
||||
return token[7:], nil
|
||||
}
|
||||
return "", nil
|
||||
|
||||
issuerURL, err := url.Parse(issuer)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to parse issuer URL: %v", err)
|
||||
}
|
||||
|
||||
provider, err := jwks.NewCachingProvider(jwks.WithIssuerURL(issuerURL))
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create JWKS provider: %v", err)
|
||||
}
|
||||
|
||||
jwtValidator, err := validator.New(
|
||||
validator.WithKeyFunc(provider.KeyFunc),
|
||||
validator.WithAlgorithm(validator.RS256),
|
||||
validator.WithIssuer(issuer),
|
||||
validator.WithAudience(a.audience),
|
||||
validator.WithCustomClaims(func() validator.CustomClaims {
|
||||
return &CustomClaims{}
|
||||
}),
|
||||
mw.WithCredentialsOptional(true),
|
||||
)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create JWT validator: %v", err)
|
||||
}
|
||||
|
||||
jwtMiddleware, err := jwtmiddleware.New(
|
||||
jwtmiddleware.WithValidator(jwtValidator),
|
||||
jwtmiddleware.WithCredentialsOptional(true),
|
||||
)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create JWT middleware: %v", err)
|
||||
}
|
||||
|
||||
return jwtMiddleware
|
||||
}
|
||||
|
||||
func TokenFromContext(ctx context.Context) (*jwt.Token, error) {
|
||||
if value := ctx.Value(mw.ContextKey{}); value != nil {
|
||||
if u, ok := value.(*jwt.Token); ok {
|
||||
return u, nil
|
||||
}
|
||||
return nil, fmt.Errorf("token is in wrong format")
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Auth0) cacheGetWellknown(url string) (*Jwks, error) {
|
||||
if value := a.cache.get(url); value != nil {
|
||||
return value, nil
|
||||
}
|
||||
jwks := &Jwks{}
|
||||
resp, err := a.client.Get(url)
|
||||
func ClaimsFromContext(ctx context.Context) *validator.ValidatedClaims {
|
||||
claims, err := jwtmiddleware.GetClaims[*validator.ValidatedClaims](ctx)
|
||||
if err != nil {
|
||||
return jwks, err
|
||||
}
|
||||
defer func() {
|
||||
_ = resp.Body.Close()
|
||||
}()
|
||||
err = json.NewDecoder(resp.Body).Decode(jwks)
|
||||
if err == nil && jwks != nil {
|
||||
a.cache.put(url, jwks)
|
||||
}
|
||||
return jwks, err
|
||||
}
|
||||
|
||||
func (a *Auth0) getPemCert(token *jwt.Token) (string, error) {
|
||||
jwks, err := a.cacheGetWellknown(fmt.Sprintf("https://%s/.well-known/jwks.json", a.domain))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
var cert string
|
||||
for k := range jwks.Keys {
|
||||
if token.Header["kid"] == jwks.Keys[k].Kid {
|
||||
cert = "-----BEGIN CERTIFICATE-----\n" + jwks.Keys[k].X5c[0] + "\n-----END CERTIFICATE-----"
|
||||
}
|
||||
}
|
||||
|
||||
if cert == "" {
|
||||
err := errors.New("Unable to find appropriate key.")
|
||||
return cert, err
|
||||
}
|
||||
|
||||
return cert, nil
|
||||
}
|
||||
|
||||
type JwksCache struct {
|
||||
*sync.RWMutex
|
||||
cache map[string]cacheItem
|
||||
}
|
||||
type cacheItem struct {
|
||||
data *Jwks
|
||||
expiration time.Time
|
||||
}
|
||||
|
||||
func (c *JwksCache) get(url string) *Jwks {
|
||||
c.RLock()
|
||||
defer c.RUnlock()
|
||||
if value, ok := c.cache[url]; ok {
|
||||
if time.Now().After(value.expiration) {
|
||||
return nil
|
||||
}
|
||||
return value.data
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *JwksCache) put(url string, jwks *Jwks) {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
c.cache[url] = cacheItem{
|
||||
data: jwks,
|
||||
expiration: time.Now().Add(time.Minute * 60),
|
||||
return nil
|
||||
}
|
||||
return claims
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user