fix: migrate to go-jwt-middleware v3 API
schemas / check-release (pull_request) Successful in 1m57s
schemas / vulnerabilities (pull_request) Successful in 2m48s
schemas / check (pull_request) Successful in 8m17s
pre-commit / pre-commit (pull_request) Successful in 11m38s
schemas / build (pull_request) Successful in 5m31s
schemas / deploy-prod (pull_request) Has been skipped

- Use validator and jwks packages for JWT validation
- Replace manual JWKS caching with jwks.NewCachingProvider
- Add CustomClaims struct for https://unbound.se/roles claim
- Rename TokenFromContext to ClaimsFromContext
- Update middleware/auth.go to use new claims structure
- Update tests to use core.SetClaims and validator.ValidatedClaims

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
2026-01-19 20:31:45 +01:00
parent e2c1803683
commit 817927cb7d
5 changed files with 133 additions and 242 deletions
+50 -141
View File
@@ -2,39 +2,34 @@ package middleware
import (
"context"
"crypto/tls"
"encoding/json"
"fmt"
"net/http"
"strings"
"sync"
"time"
"log"
"net/url"
mw "github.com/auth0/go-jwt-middleware/v3"
"github.com/golang-jwt/jwt/v5"
"github.com/pkg/errors"
jwtmiddleware "github.com/auth0/go-jwt-middleware/v3"
"github.com/auth0/go-jwt-middleware/v3/jwks"
"github.com/auth0/go-jwt-middleware/v3/validator"
)
// CustomClaims contains custom claims from the JWT token.
type CustomClaims struct {
Roles []string `json:"https://unbound.se/roles"`
}
// Validate implements the validator.CustomClaims interface.
func (c CustomClaims) Validate(_ context.Context) error {
return nil
}
type Auth0 struct {
domain string
audience string
client *http.Client
cache JwksCache
}
func NewAuth0(audience, domain string, strictSsl bool) *Auth0 {
customTransport := http.DefaultTransport.(*http.Transport).Clone()
customTransport.TLSClientConfig = &tls.Config{InsecureSkipVerify: !strictSsl}
client := &http.Client{Transport: customTransport}
func NewAuth0(audience, domain string, _ bool) *Auth0 {
return &Auth0{
domain: domain,
audience: audience,
client: client,
cache: JwksCache{
RWMutex: &sync.RWMutex{},
cache: make(map[string]cacheItem),
},
}
}
@@ -42,133 +37,47 @@ type Response struct {
Message string `json:"message"`
}
type Jwks struct {
Keys []JSONWebKeys `json:"keys"`
}
type JSONWebKeys struct {
Kty string `json:"kty"`
Kid string `json:"kid"`
Use string `json:"use"`
N string `json:"n"`
E string `json:"e"`
X5c []string `json:"x5c"`
}
func (a *Auth0) ValidationKeyGetter() func(token *jwt.Token) (interface{}, error) {
return func(token *jwt.Token) (interface{}, error) {
// Verify 'aud' claim
cert, err := a.getPemCert(token)
if err != nil {
panic(err.Error())
}
result, _ := jwt.ParseRSAPublicKeyFromPEM([]byte(cert))
return result, nil
}
}
func (a *Auth0) Middleware() *mw.JWTMiddleware {
func (a *Auth0) Middleware() *jwtmiddleware.JWTMiddleware {
issuer := fmt.Sprintf("https://%s/", a.domain)
jwtMiddleware := mw.New(func(ctx context.Context, token string) (interface{}, error) {
jwtToken, err := jwt.Parse(token, a.ValidationKeyGetter(), jwt.WithAudience(a.audience), jwt.WithIssuer(issuer))
if err != nil {
return nil, err
}
if _, ok := jwtToken.Method.(*jwt.SigningMethodRSA); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", jwtToken.Header["alg"])
}
return jwtToken, nil
},
mw.WithTokenExtractor(func(r *http.Request) (string, error) {
token := r.Header.Get("Authorization")
if strings.HasPrefix(token, "Bearer ") {
return token[7:], nil
}
return "", nil
issuerURL, err := url.Parse(issuer)
if err != nil {
log.Fatalf("failed to parse issuer URL: %v", err)
}
provider, err := jwks.NewCachingProvider(jwks.WithIssuerURL(issuerURL))
if err != nil {
log.Fatalf("failed to create JWKS provider: %v", err)
}
jwtValidator, err := validator.New(
validator.WithKeyFunc(provider.KeyFunc),
validator.WithAlgorithm(validator.RS256),
validator.WithIssuer(issuer),
validator.WithAudience(a.audience),
validator.WithCustomClaims(func() validator.CustomClaims {
return &CustomClaims{}
}),
mw.WithCredentialsOptional(true),
)
if err != nil {
log.Fatalf("failed to create JWT validator: %v", err)
}
jwtMiddleware, err := jwtmiddleware.New(
jwtmiddleware.WithValidator(jwtValidator),
jwtmiddleware.WithCredentialsOptional(true),
)
if err != nil {
log.Fatalf("failed to create JWT middleware: %v", err)
}
return jwtMiddleware
}
func TokenFromContext(ctx context.Context) (*jwt.Token, error) {
if value := ctx.Value(mw.ContextKey{}); value != nil {
if u, ok := value.(*jwt.Token); ok {
return u, nil
}
return nil, fmt.Errorf("token is in wrong format")
}
return nil, nil
}
func (a *Auth0) cacheGetWellknown(url string) (*Jwks, error) {
if value := a.cache.get(url); value != nil {
return value, nil
}
jwks := &Jwks{}
resp, err := a.client.Get(url)
func ClaimsFromContext(ctx context.Context) *validator.ValidatedClaims {
claims, err := jwtmiddleware.GetClaims[*validator.ValidatedClaims](ctx)
if err != nil {
return jwks, err
}
defer func() {
_ = resp.Body.Close()
}()
err = json.NewDecoder(resp.Body).Decode(jwks)
if err == nil && jwks != nil {
a.cache.put(url, jwks)
}
return jwks, err
}
func (a *Auth0) getPemCert(token *jwt.Token) (string, error) {
jwks, err := a.cacheGetWellknown(fmt.Sprintf("https://%s/.well-known/jwks.json", a.domain))
if err != nil {
return "", err
}
var cert string
for k := range jwks.Keys {
if token.Header["kid"] == jwks.Keys[k].Kid {
cert = "-----BEGIN CERTIFICATE-----\n" + jwks.Keys[k].X5c[0] + "\n-----END CERTIFICATE-----"
}
}
if cert == "" {
err := errors.New("Unable to find appropriate key.")
return cert, err
}
return cert, nil
}
type JwksCache struct {
*sync.RWMutex
cache map[string]cacheItem
}
type cacheItem struct {
data *Jwks
expiration time.Time
}
func (c *JwksCache) get(url string) *Jwks {
c.RLock()
defer c.RUnlock()
if value, ok := c.cache[url]; ok {
if time.Now().After(value.expiration) {
return nil
}
return value.data
}
return nil
}
func (c *JwksCache) put(url string, jwks *Jwks) {
c.Lock()
defer c.Unlock()
c.cache[url] = cacheItem{
data: jwks,
expiration: time.Now().Add(time.Minute * 60),
return nil
}
return claims
}