diff --git a/go.mod b/go.mod index f44cf23..cf740e9 100644 --- a/go.mod +++ b/go.mod @@ -3,14 +3,18 @@ module github.com/HGV/x go 1.22 require ( + github.com/coreos/go-oidc/v3 v3.11.0 github.com/jackc/pgx/v5 v5.7.2 github.com/stretchr/testify v1.10.0 ) require ( github.com/davecgh/go-spew v1.1.1 // indirect + github.com/go-jose/go-jose/v4 v4.0.2 // indirect github.com/kr/text v0.2.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rogpeppe/go-internal v1.13.1 // indirect + golang.org/x/crypto v0.31.0 // indirect + golang.org/x/oauth2 v0.21.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index bb3efd2..f2a3b1c 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,12 @@ +github.com/coreos/go-oidc/v3 v3.11.0 h1:Ia3MxdwpSw702YW0xgfmP1GVCMA9aEFWu12XUZ3/OtI= +github.com/coreos/go-oidc/v3 v3.11.0/go.mod h1:gE3LgjOgFoHi9a4ce4/tJczr0Ai2/BoDhf0r5lltWI0= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-jose/go-jose/v4 v4.0.2 h1:R3l3kkBds16bO7ZFAEEcofK0MkrAJt3jlJznWZG0nvk= +github.com/go-jose/go-jose/v4 v4.0.2/go.mod h1:WVf9LFMHh/QVrmqrOfqun0C45tMe3RoiKJMPvgWwLfY= +github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= @@ -21,6 +27,8 @@ github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOf github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U= golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= +golang.org/x/oauth2 v0.21.0 h1:tsimM75w1tF/uws5rbeHzIWxEqElMehnc+iW793zsZs= +golang.org/x/oauth2 v0.21.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= diff --git a/oidcx/middleware.go b/oidcx/middleware.go new file mode 100644 index 0000000..43521dc --- /dev/null +++ b/oidcx/middleware.go @@ -0,0 +1,150 @@ +package oidcx + +import ( + "context" + "errors" + "fmt" + "net/http" + "strings" + + "github.com/coreos/go-oidc/v3/oidc" +) + +type Middleware struct { + o *middlewareOptions + v *oidc.IDTokenVerifier +} + +type middlewareOptions struct { + ClientID string + SkipClientIDCheck bool + Email string + SkipEmailCheck bool + InsecureSkipSignatureCheck bool + AuthFailedHandler func(error) http.HandlerFunc +} + +type MiddlewareOption func(*middlewareOptions) + +type idTokenContextKey struct{} + +func NewMiddleware(ctx context.Context, issuer string, opts ...MiddlewareOption) *Middleware { + provider, err := oidc.NewProvider(ctx, issuer) + if err != nil { + panic(err) + } + + o := &middlewareOptions{ + AuthFailedHandler: defaultAuthFailedHandler, + } + + for _, opt := range opts { + opt(o) + } + + return &Middleware{ + o: o, + v: provider.VerifierContext(ctx, &oidc.Config{ + ClientID: o.ClientID, + SkipClientIDCheck: o.SkipClientIDCheck, + InsecureSkipSignatureCheck: o.InsecureSkipSignatureCheck, + }), + } +} + +func (mw *Middleware) Handler(next http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + bearerToken, ok := validateAuthHeader(r.Header.Get("Authorization"), "Bearer ") + if !ok { + mw.o.AuthFailedHandler(errors.New("bearer token is missing or invalid")).ServeHTTP(w, r) + return + } + + idToken, err := mw.v.Verify(ctx, bearerToken) + if err != nil { + mw.o.AuthFailedHandler(err).ServeHTTP(w, r) + return + } + + if !mw.o.SkipEmailCheck { + if mw.o.Email == "" { + mw.o.AuthFailedHandler(errors.New("invalid configuration, Email must be provided or SkipEmailCheck must be set")).ServeHTTP(w, r) + return + } + + var claims struct { + Email string `json:"email"` + } + if err = idToken.Claims(&claims); err != nil { + mw.o.AuthFailedHandler(err).ServeHTTP(w, r) + return + } + if !strings.EqualFold(mw.o.Email, claims.Email) { + mw.o.AuthFailedHandler(fmt.Errorf("expected email %q got %q", mw.o.Email, claims.Email)).ServeHTTP(w, r) + return + } + } + + ctx = context.WithValue(ctx, idTokenContextKey{}, idToken) + next.ServeHTTP(w, r.WithContext(ctx)) + } + return http.HandlerFunc(fn) +} + +func IDTokenFromContext(ctx context.Context) (*oidc.IDToken, bool) { + octx, ok := ctx.Value(idTokenContextKey{}).(*oidc.IDToken) + return octx, ok +} + +func WithAuthFailedHandler(h func(error) http.HandlerFunc) MiddlewareOption { + return func(o *middlewareOptions) { + if h != nil { + o.AuthFailedHandler = h + } + } +} + +func WithClientID(clientID string) MiddlewareOption { + return func(o *middlewareOptions) { + o.ClientID = clientID + } +} + +func WithSkipClientIDCheck() MiddlewareOption { + return func(o *middlewareOptions) { + o.SkipClientIDCheck = true + } +} + +func WithEmail(email string) MiddlewareOption { + return func(o *middlewareOptions) { + o.Email = email + } +} + +func WithSkipEmailCheck() MiddlewareOption { + return func(o *middlewareOptions) { + o.SkipEmailCheck = true + } +} + +func withInsecureSkipSignatureCheck() MiddlewareOption { + return func(o *middlewareOptions) { + o.InsecureSkipSignatureCheck = true + } +} + +func validateAuthHeader(s, scheme string) (string, bool) { + if len(s) >= len(scheme) && strings.EqualFold(s[0:len(scheme)], scheme) { + return s[len(scheme):], true + } + return s, false +} + +func defaultAuthFailedHandler(err error) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + } +} diff --git a/oidcx/middleware_test.go b/oidcx/middleware_test.go new file mode 100644 index 0000000..fb54d1d --- /dev/null +++ b/oidcx/middleware_test.go @@ -0,0 +1,119 @@ +package oidcx + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestValidateAuthHeader(t *testing.T) { + tests := []struct { + authHeader string + scheme string + expectedToken string + expectedOK bool + }{ + { + authHeader: "", scheme: "bearer", + expectedToken: "", expectedOK: false, + }, + { + authHeader: "bearer token", scheme: "bearer ", + expectedToken: "token", expectedOK: true, + }, + { + authHeader: "BEARER token", scheme: "bearer ", + expectedToken: "token", expectedOK: true, + }, + } + + for _, tt := range tests { + t.Run("", func(t *testing.T) { + token, ok := validateAuthHeader(tt.authHeader, tt.scheme) + assert.Equal(t, tt.expectedOK, ok) + assert.Equal(t, tt.expectedToken, token) + }) + } +} + +func TestHandler(t *testing.T) { + issuer := "https://api.accounts.hgv.it" + + next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + idToken, ok := IDTokenFromContext(r.Context()) + assert.True(t, ok) + assert.NotNil(t, idToken) + w.WriteHeader(http.StatusTeapot) + }) + + makeRequest := func(h http.Handler, token string) *httptest.ResponseRecorder { + r := httptest.NewRequest(http.MethodGet, "/", nil) + r.Header.Add("Authorization", "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJ0ZXN0LXVzZXIiLCJhdWQiOlsidGVzdC1jbGllbnQiXSwiaXNzIjoiaHR0cHM6Ly9hcGkuYWNjb3VudHMuaGd2Lml0IiwiaWF0IjoxNjAwMDAwMDAwLCJleHAiOjIwMDAwMDAwMDB9.hJREizNgcJpnEEyZ5lE5VC9tPY45JIFJoxm9ZlIPgTI") + w := httptest.NewRecorder() + h.ServeHTTP(w, r) + return w + } + + t.Run("unauthorized with no token", func(t *testing.T) { + h := NewMiddleware(context.Background(), issuer).Handler(next) + w := makeRequest(h, "") + assert.Equal(t, http.StatusUnauthorized, w.Code) + }) + + t.Run("custom error handler returns 403", func(t *testing.T) { + h := NewMiddleware(context.Background(), issuer, + WithAuthFailedHandler(func(err error) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusForbidden) + } + }), + ).Handler(next) + w := makeRequest(h, "") + assert.Equal(t, http.StatusForbidden, w.Code) + }) + + t.Run("email config required but missing", func(t *testing.T) { + h := NewMiddleware(context.Background(), issuer, + WithAuthFailedHandler(func(err error) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(err.Error())) + } + }), + WithSkipClientIDCheck(), + withInsecureSkipSignatureCheck(), + ).Handler(next) + w := makeRequest(h, "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJ0ZXN0LXVzZXIiLCJhdWQiOlsidGVzdC1jbGllbnQiXSwiaXNzIjoiaHR0cHM6Ly9hcGkuYWNjb3VudHMuaGd2Lml0IiwiaWF0IjoxNjAwMDAwMDAwLCJleHAiOjIwMDAwMDAwMDB9.hJREizNgcJpnEEyZ5lE5VC9tPY45JIFJoxm9ZlIPgTI") + b, _ := io.ReadAll(w.Body) + assert.Equal(t, "invalid configuration, Email must be provided or SkipEmailCheck must be set", string(b)) + }) + + t.Run("email mismatch", func(t *testing.T) { + h := NewMiddleware(context.Background(), issuer, + WithAuthFailedHandler(func(err error) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(err.Error())) + } + }), + WithSkipClientIDCheck(), + WithEmail("test@hgv.it"), + withInsecureSkipSignatureCheck(), + ).Handler(next) + w := makeRequest(h, "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJ0ZXN0LXVzZXIiLCJhdWQiOlsidGVzdC1jbGllbnQiXSwiaXNzIjoiaHR0cHM6Ly9hcGkuYWNjb3VudHMuaGd2Lml0IiwiaWF0IjoxNjAwMDAwMDAwLCJleHAiOjIwMDAwMDAwMDB9.hJREizNgcJpnEEyZ5lE5VC9tPY45JIFJoxm9ZlIPgTI") + b, _ := io.ReadAll(w.Body) + assert.Equal(t, "expected email \"test@hgv.it\" got \"\"", string(b)) + }) + + t.Run("valid expired token without email check", func(t *testing.T) { + h := NewMiddleware(context.Background(), issuer, + WithSkipClientIDCheck(), + WithSkipEmailCheck(), + withInsecureSkipSignatureCheck(), + ).Handler(next) + w := makeRequest(h, "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJ0ZXN0LXVzZXIiLCJhdWQiOlsidGVzdC1jbGllbnQiXSwiaXNzIjoiaHR0cHM6Ly9hcGkuYWNjb3VudHMuaGd2Lml0IiwiaWF0IjoxNjAwMDAwMDAwLCJleHAiOjIwMDAwMDAwMDB9.hJREizNgcJpnEEyZ5lE5VC9tPY45JIFJoxm9ZlIPgTI") + assert.Equal(t, http.StatusTeapot, w.Code) + }) +}