From 27a14e83edcceaad77b538efb824905a2e664952 Mon Sep 17 00:00:00 2001 From: Daniel Francesconi Date: Fri, 15 Nov 2024 14:13:04 +0100 Subject: [PATCH 1/4] feat: add oidc middleware --- go.mod | 4 ++ go.sum | 8 +++ middlewarex/oidc.go | 104 +++++++++++++++++++++++++++++++++++++++ middlewarex/oidc_test.go | 81 ++++++++++++++++++++++++++++++ 4 files changed, 197 insertions(+) create mode 100644 middlewarex/oidc.go create mode 100644 middlewarex/oidc_test.go diff --git a/go.mod b/go.mod index f44cf23..cf740e9 100644 --- a/go.mod +++ b/go.mod @@ -3,14 +3,18 @@ module github.com/HGV/x go 1.22 require ( + github.com/coreos/go-oidc/v3 v3.11.0 github.com/jackc/pgx/v5 v5.7.2 github.com/stretchr/testify v1.10.0 ) require ( github.com/davecgh/go-spew v1.1.1 // indirect + github.com/go-jose/go-jose/v4 v4.0.2 // indirect github.com/kr/text v0.2.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rogpeppe/go-internal v1.13.1 // indirect + golang.org/x/crypto v0.31.0 // indirect + golang.org/x/oauth2 v0.21.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index bb3efd2..f2a3b1c 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,12 @@ +github.com/coreos/go-oidc/v3 v3.11.0 h1:Ia3MxdwpSw702YW0xgfmP1GVCMA9aEFWu12XUZ3/OtI= +github.com/coreos/go-oidc/v3 v3.11.0/go.mod h1:gE3LgjOgFoHi9a4ce4/tJczr0Ai2/BoDhf0r5lltWI0= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-jose/go-jose/v4 v4.0.2 h1:R3l3kkBds16bO7ZFAEEcofK0MkrAJt3jlJznWZG0nvk= +github.com/go-jose/go-jose/v4 v4.0.2/go.mod h1:WVf9LFMHh/QVrmqrOfqun0C45tMe3RoiKJMPvgWwLfY= +github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= @@ -21,6 +27,8 @@ github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOf github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U= golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= +golang.org/x/oauth2 v0.21.0 h1:tsimM75w1tF/uws5rbeHzIWxEqElMehnc+iW793zsZs= +golang.org/x/oauth2 v0.21.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= diff --git a/middlewarex/oidc.go b/middlewarex/oidc.go new file mode 100644 index 0000000..80547e7 --- /dev/null +++ b/middlewarex/oidc.go @@ -0,0 +1,104 @@ +package middlewarex + +import ( + "context" + "errors" + "net/http" + "strings" + + "github.com/coreos/go-oidc/v3/oidc" +) + +type ( + OIDCMiddlewareOption func(*oidcMiddleware) + + oidcMiddleware struct { + authFailedHandler func(error) http.HandlerFunc + config *oidc.Config + verifier *oidc.IDTokenVerifier + } + oidcContextKey int +) + +const ( + idTokenContextKey oidcContextKey = iota +) + +func OIDC(ctx context.Context, issuer string, opts ...OIDCMiddlewareOption) func(next http.Handler) http.Handler { + provider, err := oidc.NewProvider(ctx, issuer) + if err != nil { + panic(err) + } + + mw := oidcMiddleware{ + authFailedHandler: oidcAuthFailed, + } + + for _, opt := range opts { + opt(&mw) + } + + if mw.config == nil { + mw.verifier = provider.Verifier(&oidc.Config{}) + } else { + mw.verifier = provider.Verifier(mw.config) + } + + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + authHeader := r.Header.Get("Authorization") + bearerToken, ok := validateAuthHeader(authHeader, "Bearer ") + if !ok { + mw.authFailedHandler(errors.New("bearer token is missing or invalid")).ServeHTTP(w, r) + return + } + + idToken, err := mw.verifier.Verify(r.Context(), bearerToken) + if err != nil { + mw.authFailedHandler(err).ServeHTTP(w, r) + return + } + + ctx := contextWithIDToken(r.Context(), idToken) + next.ServeHTTP(w, r.WithContext(ctx)) + }) + } +} + +func contextWithIDToken(ctx context.Context, idToken *oidc.IDToken) context.Context { + return context.WithValue(ctx, idTokenContextKey, idToken) +} + +func IDTokenFromContext(ctx context.Context) *oidc.IDToken { + if idToken, ok := ctx.Value(idTokenContextKey).(*oidc.IDToken); ok { + return idToken + } + return nil +} + +func validateAuthHeader(s, scheme string) (string, bool) { + if len(s) >= len(scheme) && strings.EqualFold(s[0:len(scheme)], scheme) { + return s[len(scheme):], true + } + return s, false +} + +func oidcAuthFailed(err error) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + } +} + +func WithAuthFailedHandler(h func(error) http.HandlerFunc) OIDCMiddlewareOption { + return func(opt *oidcMiddleware) { + if h != nil { + opt.authFailedHandler = h + } + } +} + +func WithOIDCConfig(c oidc.Config) OIDCMiddlewareOption { + return func(opt *oidcMiddleware) { + opt.config = &c + } +} diff --git a/middlewarex/oidc_test.go b/middlewarex/oidc_test.go new file mode 100644 index 0000000..03b7e10 --- /dev/null +++ b/middlewarex/oidc_test.go @@ -0,0 +1,81 @@ +package middlewarex + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/coreos/go-oidc/v3/oidc" + "github.com/stretchr/testify/assert" +) + +func TestValidateAuthHeader(t *testing.T) { + tests := []struct { + authHeader string + scheme string + expectedToken string + expectedOk bool + }{ + { + authHeader: "", scheme: "bearer", + expectedToken: "", expectedOk: false, + }, + { + authHeader: "bearer token", scheme: "bearer ", + expectedToken: "token", expectedOk: true, + }, + { + authHeader: "BEARER token", scheme: "bearer ", + expectedToken: "token", expectedOk: true, + }, + } + + for _, tt := range tests { + t.Run("", func(t *testing.T) { + token, ok := validateAuthHeader(tt.authHeader, tt.scheme) + assert.Equal(t, tt.expectedOk, ok) + assert.Equal(t, tt.expectedToken, token) + }) + } +} + +func TestHandler(t *testing.T) { + issuer := "https://api.accounts.hgv.it" + next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.NotNil(t, IDTokenFromContext(r.Context())) + w.WriteHeader(http.StatusTeapot) + }) + + t.Run("unauthorized", func(t *testing.T) { + h := OIDC(context.Background(), issuer)(next) + r := httptest.NewRequest(http.MethodGet, "/", nil) + w := httptest.NewRecorder() + h.ServeHTTP(w, r) + assert.Equal(t, http.StatusUnauthorized, w.Code) + }) + + t.Run("overwrite default error handler", func(t *testing.T) { + h := OIDC(context.Background(), issuer, WithAuthFailedHandler(func(err error) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusForbidden) + } + }))(next) + r := httptest.NewRequest(http.MethodGet, "/", nil) + w := httptest.NewRecorder() + h.ServeHTTP(w, r) + assert.Equal(t, http.StatusForbidden, w.Code) + }) + + t.Run("valid expired token", func(t *testing.T) { + h := OIDC(context.Background(), issuer, WithOIDCConfig(oidc.Config{ + SkipClientIDCheck: true, + SkipExpiryCheck: true, + }))(next) + r := httptest.NewRequest(http.MethodGet, "/", nil) + r.Header.Add("Authorization", "Bearer eyJhbGciOiJSUzI1NiIsImtpZCI6IjJlNTc0NjE3LTJlYzYtNGNhNy1hYTE2LThiYTYyMWRlMGI3YSIsInR5cCI6IkpXVCJ9.eyJhdWQiOltdLCJjbGllbnRfaWQiOiJjMDI2NTZiZC00NzZkLTQ1MGYtOWMwZC0zN2ZiMDhiYTI3MjEiLCJleHAiOjE3MzE2Njk3NDMsImV4dCI6e30sImlhdCI6MTczMTY2NjE0MywiaXNzIjoiaHR0cHM6Ly9hcGkuYWNjb3VudHMuaGd2Lml0IiwianRpIjoiZjk3YWE1ODAtZjZmNC00ZGQ3LTlkMDgtMjM1YTM5ZGU4ZWZlIiwibmJmIjoxNzMxNjY2MTQzLCJzY3AiOltdLCJzdWIiOiJjMDI2NTZiZC00NzZkLTQ1MGYtOWMwZC0zN2ZiMDhiYTI3MjEifQ.IeIc2EWCYjH8EaYClYpaTpYz-DDRbpu4vRuzirmBXZy28r7OazSrJdRSEa2a_G9Yq0UzmJXeBtPAouvsQdwmHX1PdBFzwwqLPT4kXcxMmlX6RvnTy-95wVfXnJJP-cGU5U4sMKKFGnsecAQotesEsYk19Dxylr5RMA-DsgwwpN8GQuf4KdLJk4IDJx8Z-FlfAG4XMODGM2S3sqGCwc6b5nQUXa_cUTIMqJCyUdb3Kd3OcQHKEK0o0esG1CBgqj3RrRE98BejeEjR5LOYiQpY1aAklmxa_3UOtEi9Bej1PRyybRxV7QbNE8_K0WVdj3CCedbtpK7DB0mNGCtas2bjiFxsr9MBHUtDcU3taXEoEkSqye7vIbLgd66SFm5gq78-PeJEvbwYqpt4LB7b7F-ZpyhCU-3T3SNkMPHY-q7hIBPauRbJbtWdK3w_xjjjCJdgjspk-CEyOUfhogjKmavxcuuXOGBphOeJ7WCRMTlmv9ira0DZqwBCQTGitkGGT98l4guaIYoB27Zsl-wdgxK2F0AwjvHFTYNUsG3Nf9NJ4ULjPMusBBA9hHBoO1UrlNWgXEpJWvr5YV_vt0Omlqvv-ci7M3Rx1-MjRyBYTQRxVRLhtDtGK4TbW4jCEIE38_k5IDqH6WxaUsgxTxFu8rx5xWhpRlKuIQRrDyWA1ylMo_U") + w := httptest.NewRecorder() + h.ServeHTTP(w, r) + assert.Equal(t, http.StatusTeapot, w.Code) + }) +} From 928545e488123ac60d890a169a18be05d02e4fc2 Mon Sep 17 00:00:00 2001 From: Daniel Francesconi Date: Wed, 20 Nov 2024 11:31:13 +0100 Subject: [PATCH 2/4] refactor: change return type of helper function --- middlewarex/oidc.go | 12 ++++-------- middlewarex/oidc_test.go | 14 ++++++++------ 2 files changed, 12 insertions(+), 14 deletions(-) diff --git a/middlewarex/oidc.go b/middlewarex/oidc.go index 80547e7..15a6bde 100644 --- a/middlewarex/oidc.go +++ b/middlewarex/oidc.go @@ -20,9 +20,7 @@ type ( oidcContextKey int ) -const ( - idTokenContextKey oidcContextKey = iota -) +const idTokenContextKey oidcContextKey = iota func OIDC(ctx context.Context, issuer string, opts ...OIDCMiddlewareOption) func(next http.Handler) http.Handler { provider, err := oidc.NewProvider(ctx, issuer) @@ -69,11 +67,9 @@ func contextWithIDToken(ctx context.Context, idToken *oidc.IDToken) context.Cont return context.WithValue(ctx, idTokenContextKey, idToken) } -func IDTokenFromContext(ctx context.Context) *oidc.IDToken { - if idToken, ok := ctx.Value(idTokenContextKey).(*oidc.IDToken); ok { - return idToken - } - return nil +func IDTokenFromContext(ctx context.Context) (idToken *oidc.IDToken, ok bool) { + idToken, ok = ctx.Value(idTokenContextKey).(*oidc.IDToken) + return } func validateAuthHeader(s, scheme string) (string, bool) { diff --git a/middlewarex/oidc_test.go b/middlewarex/oidc_test.go index 03b7e10..2e07844 100644 --- a/middlewarex/oidc_test.go +++ b/middlewarex/oidc_test.go @@ -15,26 +15,26 @@ func TestValidateAuthHeader(t *testing.T) { authHeader string scheme string expectedToken string - expectedOk bool + expectedOK bool }{ { authHeader: "", scheme: "bearer", - expectedToken: "", expectedOk: false, + expectedToken: "", expectedOK: false, }, { authHeader: "bearer token", scheme: "bearer ", - expectedToken: "token", expectedOk: true, + expectedToken: "token", expectedOK: true, }, { authHeader: "BEARER token", scheme: "bearer ", - expectedToken: "token", expectedOk: true, + expectedToken: "token", expectedOK: true, }, } for _, tt := range tests { t.Run("", func(t *testing.T) { token, ok := validateAuthHeader(tt.authHeader, tt.scheme) - assert.Equal(t, tt.expectedOk, ok) + assert.Equal(t, tt.expectedOK, ok) assert.Equal(t, tt.expectedToken, token) }) } @@ -43,7 +43,9 @@ func TestValidateAuthHeader(t *testing.T) { func TestHandler(t *testing.T) { issuer := "https://api.accounts.hgv.it" next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - assert.NotNil(t, IDTokenFromContext(r.Context())) + idToken, ok := IDTokenFromContext(r.Context()) + assert.True(t, ok) + assert.NotNil(t, idToken) w.WriteHeader(http.StatusTeapot) }) From bb524b7a85aa9f824992931ba952b17d6d316f4a Mon Sep 17 00:00:00 2001 From: Daniel Francesconi Date: Fri, 20 Jun 2025 15:23:04 +0200 Subject: [PATCH 3/4] feat: move OIDC middleware into its own package and add email check --- middlewarex/oidc.go | 100 -------------------------- middlewarex/oidc_test.go | 83 --------------------- oidcx/middleware.go | 152 +++++++++++++++++++++++++++++++++++++++ oidcx/middleware_test.go | 119 ++++++++++++++++++++++++++++++ 4 files changed, 271 insertions(+), 183 deletions(-) delete mode 100644 middlewarex/oidc.go delete mode 100644 middlewarex/oidc_test.go create mode 100644 oidcx/middleware.go create mode 100644 oidcx/middleware_test.go diff --git a/middlewarex/oidc.go b/middlewarex/oidc.go deleted file mode 100644 index 15a6bde..0000000 --- a/middlewarex/oidc.go +++ /dev/null @@ -1,100 +0,0 @@ -package middlewarex - -import ( - "context" - "errors" - "net/http" - "strings" - - "github.com/coreos/go-oidc/v3/oidc" -) - -type ( - OIDCMiddlewareOption func(*oidcMiddleware) - - oidcMiddleware struct { - authFailedHandler func(error) http.HandlerFunc - config *oidc.Config - verifier *oidc.IDTokenVerifier - } - oidcContextKey int -) - -const idTokenContextKey oidcContextKey = iota - -func OIDC(ctx context.Context, issuer string, opts ...OIDCMiddlewareOption) func(next http.Handler) http.Handler { - provider, err := oidc.NewProvider(ctx, issuer) - if err != nil { - panic(err) - } - - mw := oidcMiddleware{ - authFailedHandler: oidcAuthFailed, - } - - for _, opt := range opts { - opt(&mw) - } - - if mw.config == nil { - mw.verifier = provider.Verifier(&oidc.Config{}) - } else { - mw.verifier = provider.Verifier(mw.config) - } - - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - authHeader := r.Header.Get("Authorization") - bearerToken, ok := validateAuthHeader(authHeader, "Bearer ") - if !ok { - mw.authFailedHandler(errors.New("bearer token is missing or invalid")).ServeHTTP(w, r) - return - } - - idToken, err := mw.verifier.Verify(r.Context(), bearerToken) - if err != nil { - mw.authFailedHandler(err).ServeHTTP(w, r) - return - } - - ctx := contextWithIDToken(r.Context(), idToken) - next.ServeHTTP(w, r.WithContext(ctx)) - }) - } -} - -func contextWithIDToken(ctx context.Context, idToken *oidc.IDToken) context.Context { - return context.WithValue(ctx, idTokenContextKey, idToken) -} - -func IDTokenFromContext(ctx context.Context) (idToken *oidc.IDToken, ok bool) { - idToken, ok = ctx.Value(idTokenContextKey).(*oidc.IDToken) - return -} - -func validateAuthHeader(s, scheme string) (string, bool) { - if len(s) >= len(scheme) && strings.EqualFold(s[0:len(scheme)], scheme) { - return s[len(scheme):], true - } - return s, false -} - -func oidcAuthFailed(err error) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusUnauthorized) - } -} - -func WithAuthFailedHandler(h func(error) http.HandlerFunc) OIDCMiddlewareOption { - return func(opt *oidcMiddleware) { - if h != nil { - opt.authFailedHandler = h - } - } -} - -func WithOIDCConfig(c oidc.Config) OIDCMiddlewareOption { - return func(opt *oidcMiddleware) { - opt.config = &c - } -} diff --git a/middlewarex/oidc_test.go b/middlewarex/oidc_test.go deleted file mode 100644 index 2e07844..0000000 --- a/middlewarex/oidc_test.go +++ /dev/null @@ -1,83 +0,0 @@ -package middlewarex - -import ( - "context" - "net/http" - "net/http/httptest" - "testing" - - "github.com/coreos/go-oidc/v3/oidc" - "github.com/stretchr/testify/assert" -) - -func TestValidateAuthHeader(t *testing.T) { - tests := []struct { - authHeader string - scheme string - expectedToken string - expectedOK bool - }{ - { - authHeader: "", scheme: "bearer", - expectedToken: "", expectedOK: false, - }, - { - authHeader: "bearer token", scheme: "bearer ", - expectedToken: "token", expectedOK: true, - }, - { - authHeader: "BEARER token", scheme: "bearer ", - expectedToken: "token", expectedOK: true, - }, - } - - for _, tt := range tests { - t.Run("", func(t *testing.T) { - token, ok := validateAuthHeader(tt.authHeader, tt.scheme) - assert.Equal(t, tt.expectedOK, ok) - assert.Equal(t, tt.expectedToken, token) - }) - } -} - -func TestHandler(t *testing.T) { - issuer := "https://api.accounts.hgv.it" - next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - idToken, ok := IDTokenFromContext(r.Context()) - assert.True(t, ok) - assert.NotNil(t, idToken) - w.WriteHeader(http.StatusTeapot) - }) - - t.Run("unauthorized", func(t *testing.T) { - h := OIDC(context.Background(), issuer)(next) - r := httptest.NewRequest(http.MethodGet, "/", nil) - w := httptest.NewRecorder() - h.ServeHTTP(w, r) - assert.Equal(t, http.StatusUnauthorized, w.Code) - }) - - t.Run("overwrite default error handler", func(t *testing.T) { - h := OIDC(context.Background(), issuer, WithAuthFailedHandler(func(err error) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusForbidden) - } - }))(next) - r := httptest.NewRequest(http.MethodGet, "/", nil) - w := httptest.NewRecorder() - h.ServeHTTP(w, r) - assert.Equal(t, http.StatusForbidden, w.Code) - }) - - t.Run("valid expired token", func(t *testing.T) { - h := OIDC(context.Background(), issuer, WithOIDCConfig(oidc.Config{ - SkipClientIDCheck: true, - SkipExpiryCheck: true, - }))(next) - r := httptest.NewRequest(http.MethodGet, "/", nil) - r.Header.Add("Authorization", "Bearer eyJhbGciOiJSUzI1NiIsImtpZCI6IjJlNTc0NjE3LTJlYzYtNGNhNy1hYTE2LThiYTYyMWRlMGI3YSIsInR5cCI6IkpXVCJ9.eyJhdWQiOltdLCJjbGllbnRfaWQiOiJjMDI2NTZiZC00NzZkLTQ1MGYtOWMwZC0zN2ZiMDhiYTI3MjEiLCJleHAiOjE3MzE2Njk3NDMsImV4dCI6e30sImlhdCI6MTczMTY2NjE0MywiaXNzIjoiaHR0cHM6Ly9hcGkuYWNjb3VudHMuaGd2Lml0IiwianRpIjoiZjk3YWE1ODAtZjZmNC00ZGQ3LTlkMDgtMjM1YTM5ZGU4ZWZlIiwibmJmIjoxNzMxNjY2MTQzLCJzY3AiOltdLCJzdWIiOiJjMDI2NTZiZC00NzZkLTQ1MGYtOWMwZC0zN2ZiMDhiYTI3MjEifQ.IeIc2EWCYjH8EaYClYpaTpYz-DDRbpu4vRuzirmBXZy28r7OazSrJdRSEa2a_G9Yq0UzmJXeBtPAouvsQdwmHX1PdBFzwwqLPT4kXcxMmlX6RvnTy-95wVfXnJJP-cGU5U4sMKKFGnsecAQotesEsYk19Dxylr5RMA-DsgwwpN8GQuf4KdLJk4IDJx8Z-FlfAG4XMODGM2S3sqGCwc6b5nQUXa_cUTIMqJCyUdb3Kd3OcQHKEK0o0esG1CBgqj3RrRE98BejeEjR5LOYiQpY1aAklmxa_3UOtEi9Bej1PRyybRxV7QbNE8_K0WVdj3CCedbtpK7DB0mNGCtas2bjiFxsr9MBHUtDcU3taXEoEkSqye7vIbLgd66SFm5gq78-PeJEvbwYqpt4LB7b7F-ZpyhCU-3T3SNkMPHY-q7hIBPauRbJbtWdK3w_xjjjCJdgjspk-CEyOUfhogjKmavxcuuXOGBphOeJ7WCRMTlmv9ira0DZqwBCQTGitkGGT98l4guaIYoB27Zsl-wdgxK2F0AwjvHFTYNUsG3Nf9NJ4ULjPMusBBA9hHBoO1UrlNWgXEpJWvr5YV_vt0Omlqvv-ci7M3Rx1-MjRyBYTQRxVRLhtDtGK4TbW4jCEIE38_k5IDqH6WxaUsgxTxFu8rx5xWhpRlKuIQRrDyWA1ylMo_U") - w := httptest.NewRecorder() - h.ServeHTTP(w, r) - assert.Equal(t, http.StatusTeapot, w.Code) - }) -} diff --git a/oidcx/middleware.go b/oidcx/middleware.go new file mode 100644 index 0000000..d0e2d0a --- /dev/null +++ b/oidcx/middleware.go @@ -0,0 +1,152 @@ +package oidcx + +import ( + "context" + "errors" + "fmt" + "net/http" + "strings" + + "github.com/coreos/go-oidc/v3/oidc" +) + +type Middleware struct { + o *middlewareOptions + v *oidc.IDTokenVerifier +} + +type middlewareOptions struct { + ClientID string + SkipClientIDCheck bool + Email string + SkipEmailCheck bool + InsecureSkipSignatureCheck bool + AuthFailedHandler func(error) http.HandlerFunc +} + +type MiddlewareOption func(*middlewareOptions) + +type idTokenContextKey struct{} + +func NewMiddleware(ctx context.Context, issuer string, opts ...MiddlewareOption) *Middleware { + provider, err := oidc.NewProvider(ctx, issuer) + if err != nil { + panic(err) + } + + o := &middlewareOptions{ + AuthFailedHandler: defaultAuthFailedHandler, + } + + for _, opt := range opts { + opt(o) + } + + mw := Middleware{ + o: o, + v: provider.VerifierContext(ctx, &oidc.Config{ + ClientID: o.ClientID, + SkipClientIDCheck: o.SkipClientIDCheck, + InsecureSkipSignatureCheck: o.InsecureSkipSignatureCheck, + }), + } + + return &mw +} + +func (mw *Middleware) Handler(next http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + bearerToken, ok := validateAuthHeader(r.Header.Get("Authorization"), "Bearer ") + if !ok { + mw.o.AuthFailedHandler(errors.New("bearer token is missing or invalid")).ServeHTTP(w, r) + return + } + + idToken, err := mw.v.Verify(ctx, bearerToken) + if err != nil { + mw.o.AuthFailedHandler(err).ServeHTTP(w, r) + return + } + + if !mw.o.SkipEmailCheck { + if mw.o.Email == "" { + mw.o.AuthFailedHandler(errors.New("invalid configuration, Email must be provided or SkipEmailCheck must be set")).ServeHTTP(w, r) + return + } + + var claims struct { + Email string `json:"email"` + } + if err = idToken.Claims(&claims); err != nil { + mw.o.AuthFailedHandler(err).ServeHTTP(w, r) + return + } + if !strings.EqualFold(mw.o.Email, claims.Email) { + mw.o.AuthFailedHandler(fmt.Errorf("expected email %q got %q", mw.o.Email, claims.Email)).ServeHTTP(w, r) + return + } + } + + ctx = context.WithValue(ctx, idTokenContextKey{}, idToken) + next.ServeHTTP(w, r.WithContext(ctx)) + } + return http.HandlerFunc(fn) +} + +func IDTokenFromContext(ctx context.Context) (*oidc.IDToken, bool) { + octx, ok := ctx.Value(idTokenContextKey{}).(*oidc.IDToken) + return octx, ok +} + +func WithAuthFailedHandler(h func(error) http.HandlerFunc) MiddlewareOption { + return func(o *middlewareOptions) { + if h != nil { + o.AuthFailedHandler = h + } + } +} + +func WithClientID(clientID string) MiddlewareOption { + return func(o *middlewareOptions) { + o.ClientID = clientID + } +} + +func WithSkipClientIDCheck() MiddlewareOption { + return func(o *middlewareOptions) { + o.SkipClientIDCheck = true + } +} + +func WithEmail(email string) MiddlewareOption { + return func(o *middlewareOptions) { + o.Email = email + } +} + +func WithSkipEmailCheck() MiddlewareOption { + return func(o *middlewareOptions) { + o.SkipEmailCheck = true + } +} + +func withInsecureSkipSignatureCheck() MiddlewareOption { + return func(o *middlewareOptions) { + o.InsecureSkipSignatureCheck = true + } +} + +func validateAuthHeader(s, scheme string) (string, bool) { + if len(s) >= len(scheme) && strings.EqualFold(s[0:len(scheme)], scheme) { + return s[len(scheme):], true + } + return s, false +} + +func defaultAuthFailedHandler(err error) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + } +} diff --git a/oidcx/middleware_test.go b/oidcx/middleware_test.go new file mode 100644 index 0000000..fb54d1d --- /dev/null +++ b/oidcx/middleware_test.go @@ -0,0 +1,119 @@ +package oidcx + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestValidateAuthHeader(t *testing.T) { + tests := []struct { + authHeader string + scheme string + expectedToken string + expectedOK bool + }{ + { + authHeader: "", scheme: "bearer", + expectedToken: "", expectedOK: false, + }, + { + authHeader: "bearer token", scheme: "bearer ", + expectedToken: "token", expectedOK: true, + }, + { + authHeader: "BEARER token", scheme: "bearer ", + expectedToken: "token", expectedOK: true, + }, + } + + for _, tt := range tests { + t.Run("", func(t *testing.T) { + token, ok := validateAuthHeader(tt.authHeader, tt.scheme) + assert.Equal(t, tt.expectedOK, ok) + assert.Equal(t, tt.expectedToken, token) + }) + } +} + +func TestHandler(t *testing.T) { + issuer := "https://api.accounts.hgv.it" + + next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + idToken, ok := IDTokenFromContext(r.Context()) + assert.True(t, ok) + assert.NotNil(t, idToken) + w.WriteHeader(http.StatusTeapot) + }) + + makeRequest := func(h http.Handler, token string) *httptest.ResponseRecorder { + r := httptest.NewRequest(http.MethodGet, "/", nil) + r.Header.Add("Authorization", "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJ0ZXN0LXVzZXIiLCJhdWQiOlsidGVzdC1jbGllbnQiXSwiaXNzIjoiaHR0cHM6Ly9hcGkuYWNjb3VudHMuaGd2Lml0IiwiaWF0IjoxNjAwMDAwMDAwLCJleHAiOjIwMDAwMDAwMDB9.hJREizNgcJpnEEyZ5lE5VC9tPY45JIFJoxm9ZlIPgTI") + w := httptest.NewRecorder() + h.ServeHTTP(w, r) + return w + } + + t.Run("unauthorized with no token", func(t *testing.T) { + h := NewMiddleware(context.Background(), issuer).Handler(next) + w := makeRequest(h, "") + assert.Equal(t, http.StatusUnauthorized, w.Code) + }) + + t.Run("custom error handler returns 403", func(t *testing.T) { + h := NewMiddleware(context.Background(), issuer, + WithAuthFailedHandler(func(err error) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusForbidden) + } + }), + ).Handler(next) + w := makeRequest(h, "") + assert.Equal(t, http.StatusForbidden, w.Code) + }) + + t.Run("email config required but missing", func(t *testing.T) { + h := NewMiddleware(context.Background(), issuer, + WithAuthFailedHandler(func(err error) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(err.Error())) + } + }), + WithSkipClientIDCheck(), + withInsecureSkipSignatureCheck(), + ).Handler(next) + w := makeRequest(h, "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJ0ZXN0LXVzZXIiLCJhdWQiOlsidGVzdC1jbGllbnQiXSwiaXNzIjoiaHR0cHM6Ly9hcGkuYWNjb3VudHMuaGd2Lml0IiwiaWF0IjoxNjAwMDAwMDAwLCJleHAiOjIwMDAwMDAwMDB9.hJREizNgcJpnEEyZ5lE5VC9tPY45JIFJoxm9ZlIPgTI") + b, _ := io.ReadAll(w.Body) + assert.Equal(t, "invalid configuration, Email must be provided or SkipEmailCheck must be set", string(b)) + }) + + t.Run("email mismatch", func(t *testing.T) { + h := NewMiddleware(context.Background(), issuer, + WithAuthFailedHandler(func(err error) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(err.Error())) + } + }), + WithSkipClientIDCheck(), + WithEmail("test@hgv.it"), + withInsecureSkipSignatureCheck(), + ).Handler(next) + w := makeRequest(h, "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJ0ZXN0LXVzZXIiLCJhdWQiOlsidGVzdC1jbGllbnQiXSwiaXNzIjoiaHR0cHM6Ly9hcGkuYWNjb3VudHMuaGd2Lml0IiwiaWF0IjoxNjAwMDAwMDAwLCJleHAiOjIwMDAwMDAwMDB9.hJREizNgcJpnEEyZ5lE5VC9tPY45JIFJoxm9ZlIPgTI") + b, _ := io.ReadAll(w.Body) + assert.Equal(t, "expected email \"test@hgv.it\" got \"\"", string(b)) + }) + + t.Run("valid expired token without email check", func(t *testing.T) { + h := NewMiddleware(context.Background(), issuer, + WithSkipClientIDCheck(), + WithSkipEmailCheck(), + withInsecureSkipSignatureCheck(), + ).Handler(next) + w := makeRequest(h, "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJ0ZXN0LXVzZXIiLCJhdWQiOlsidGVzdC1jbGllbnQiXSwiaXNzIjoiaHR0cHM6Ly9hcGkuYWNjb3VudHMuaGd2Lml0IiwiaWF0IjoxNjAwMDAwMDAwLCJleHAiOjIwMDAwMDAwMDB9.hJREizNgcJpnEEyZ5lE5VC9tPY45JIFJoxm9ZlIPgTI") + assert.Equal(t, http.StatusTeapot, w.Code) + }) +} From 17206eeef8b8bffce1c8f7fde54cbd4d90c9145b Mon Sep 17 00:00:00 2001 From: Daniel Francesconi Date: Mon, 23 Jun 2025 14:03:06 +0200 Subject: [PATCH 4/4] refactor: inline Middleware allocation --- oidcx/middleware.go | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/oidcx/middleware.go b/oidcx/middleware.go index d0e2d0a..43521dc 100644 --- a/oidcx/middleware.go +++ b/oidcx/middleware.go @@ -42,7 +42,7 @@ func NewMiddleware(ctx context.Context, issuer string, opts ...MiddlewareOption) opt(o) } - mw := Middleware{ + return &Middleware{ o: o, v: provider.VerifierContext(ctx, &oidc.Config{ ClientID: o.ClientID, @@ -50,8 +50,6 @@ func NewMiddleware(ctx context.Context, issuer string, opts ...MiddlewareOption) InsecureSkipSignatureCheck: o.InsecureSkipSignatureCheck, }), } - - return &mw } func (mw *Middleware) Handler(next http.Handler) http.Handler {