diff --git a/api/client/client.go b/api/client/client.go index 65275cd..1c53af3 100644 --- a/api/client/client.go +++ b/api/client/client.go @@ -292,8 +292,8 @@ func (c *Client) AuthorizedRequestURLParams(r *http.Request) (string, bool, erro return "", false, ErrMissingAuthURLParams } // get the strToken and email from the request headers - strToken, _ := url.QueryUnescape(params.Get(DefaultAuthTokenURLParam)) - userEmail, _ := url.QueryUnescape(params.Get(DefaultAuthEmailURLParam)) + strToken := params.Get(DefaultAuthTokenURLParam) + userEmail := params.Get(DefaultAuthEmailURLParam) // check if the token and email are valid if strToken == "" || userEmail == "" { return "", false, ErrMissingAuthURLParams diff --git a/api/client/client_test.go b/api/client/client_test.go index e331402..5d90b91 100644 --- a/api/client/client_test.go +++ b/api/client/client_test.go @@ -6,6 +6,7 @@ import ( "fmt" "log" "math" + "math/big" "net/http" "net/http/httptest" "net/url" @@ -24,8 +25,8 @@ import ( var ( // server testServer = "127.0.0.1" - testPort = 8082 - testEmailPort = 2526 + testPort = randomPort() + testEmailPort = randomPort() testEndpoint = fmt.Sprintf("http://%s:%d", testServer, testPort) testServerSecret = "serversecret" testServerEmail = "server@email.com" @@ -591,8 +592,8 @@ func TestAuthorizedRequestURLParams(t *testing.T) { } reqURL, _ := url.Parse("http://example.com/test") reqURL.RawQuery = fmt.Sprintf("%s=%s&%s=%s", - DefaultAuthEmailURLParam, url.QueryEscape(testEmail), - DefaultAuthTokenURLParam, url.QueryEscape(testToken.String())) + DefaultAuthEmailURLParam, testEmail, + DefaultAuthTokenURLParam, testToken.String()) req, err := http.NewRequest(http.MethodGet, reqURL.String(), nil) if err != nil { @@ -643,8 +644,8 @@ func TestAuthorizedRequestURLParams(t *testing.T) { invalidEmail := "noemail.com" reqURL, _ := url.Parse("http://example.com/test") reqURL.RawQuery = fmt.Sprintf("%s=%s&%s=%s", - DefaultAuthEmailURLParam, url.QueryEscape(invalidEmail), - DefaultAuthTokenURLParam, url.QueryEscape(testToken.String())) + DefaultAuthEmailURLParam, invalidEmail, + DefaultAuthTokenURLParam, testToken.String()) req, err := http.NewRequest(http.MethodGet, reqURL.String(), nil) if err != nil { @@ -913,3 +914,11 @@ func randomEmail() string { randStr := strings.ToLower(rand.Text()[0:8]) return fmt.Sprintf("%s@example.com", randStr) } + +func randomPort() int { + port, err := rand.Int(rand.Reader, big.NewInt(65535-1024)) + if err != nil { + panic(err) + } + return int(port.Int64()) + 1024 +} diff --git a/internal/base64url/base64url.go b/internal/base64url/base64url.go new file mode 100644 index 0000000..9d0468d --- /dev/null +++ b/internal/base64url/base64url.go @@ -0,0 +1,36 @@ +// base64url package provides a way to encode and decode data using a modified +// base64 encoding scheme that is safe for URLs. +package base64url + +import ( + "bytes" + "encoding/base64" +) + +// RawEncode function encodes the input byte slice into a base64 URL-safe +// encoded byte slice. It encodes the input using the raw standard base64 +// encoding and then replaces the URL-unsafe characters. +func RawEncode(src []byte) []byte { + rawBase64 := make([]byte, base64.RawStdEncoding.EncodedLen(len(src))) + base64.RawStdEncoding.Encode(rawBase64, src) + // replace the characters that are not URL safe (+, /) + rawBase64 = bytes.ReplaceAll(rawBase64, []byte("+"), []byte("-")) + rawBase64 = bytes.ReplaceAll(rawBase64, []byte("/"), []byte("_")) + return rawBase64 +} + +// RawDecode function decodes the input byte slice from a base64 URL-safe +// encoded byte slice. It replaces the URL-unsafe characters with their +// standard base64 counterparts and then decodes the result using the raw +// standard base64 decoding. +func RawDecode(data []byte) ([]byte, error) { + // recover the characters that are not URL safe (+, /) + rawBase64 := bytes.ReplaceAll(data, []byte("-"), []byte("+")) + rawBase64 = bytes.ReplaceAll(rawBase64, []byte("_"), []byte("/")) + // return the decoded string from basic base64 + res := make([]byte, base64.RawStdEncoding.DecodedLen(len(data))) + if _, err := base64.RawStdEncoding.Decode(res, rawBase64); err != nil { + return nil, err + } + return res, nil +} diff --git a/internal/base64url/base64url_test.go b/internal/base64url/base64url_test.go new file mode 100644 index 0000000..4dce4a8 --- /dev/null +++ b/internal/base64url/base64url_test.go @@ -0,0 +1,96 @@ +package base64url + +import ( + "bytes" + "testing" +) + +func TestRawEncode(t *testing.T) { + inputs := [][]byte{ + // normal + []byte("Hello, World!"), + // some input that produce "+" + {0xFF, 0xFF, 0xFF}, + // some input that produce "/" + {0xFB, 0xEF, 0xBF}, + } + expectedOutputs := [][]byte{ + []byte("SGVsbG8sIFdvcmxkIQ"), + []byte("____"), + []byte("---_"), + } + for i, input := range inputs { + encoded := RawEncode(input) + if !bytes.Equal(encoded, expectedOutputs[i]) { + t.Errorf("Unexpected output for input %q: got %q, want %q", + input, encoded, expectedOutputs[i]) + } + } +} + +func TestRawDecode(t *testing.T) { + t.Run("success", func(t *testing.T) { + + inputs := [][]byte{ + // normal + []byte("SGVsbG8sIFdvcmxkIQ"), + // some input that produce "+" + []byte("____"), + // some input that produce "/" + []byte("---_"), + } + expectedOutputs := [][]byte{ + []byte("Hello, World!"), + {0xFF, 0xFF, 0xFF}, + {0xFB, 0xEF, 0xBF}, + } + for i, input := range inputs { + decoded, err := RawDecode(input) + if err != nil { + t.Errorf("Unexpected error for input %q: %v", + input, err) + continue + } + if !bytes.Equal(decoded, expectedOutputs[i]) { + t.Errorf("Unexpected output for input %q: got %q, want %q", + input, decoded, expectedOutputs[i]) + } + } + }) + t.Run("decode error", func(t *testing.T) { + invalidInputs := [][]byte{ + []byte("SGVsbG8sIFdvcmxkIQ==="), + } + for _, input := range invalidInputs { + decoded, err := RawDecode(input) + if err == nil { + t.Errorf("Expected error for input %q, got %q", + input, decoded) + } + } + }) +} + +func TestRawEncodeDecode(t *testing.T) { + inputs := [][]byte{ + // normal + []byte("Hello, World!"), + // some input that produce "+" + {0xFF, 0xFF, 0xFF}, + // some input that produce "/" + {0xFB, 0xEF, 0xBF}, + } + for _, input := range inputs { + encoded := RawEncode(input) + decoded, err := RawDecode(encoded) + if err != nil { + t.Errorf("Unexpected error for input %q: %v", + input, err) + continue + } + if !bytes.Equal(decoded, input) { + t.Errorf("Unexpected output for input %q: got %q, want %q", + input, decoded, input) + } + } +} diff --git a/token/app.go b/token/app.go index a952a14..ab967ac 100644 --- a/token/app.go +++ b/token/app.go @@ -2,10 +2,11 @@ package token import ( "bytes" - "encoding/base64" "encoding/hex" "strings" "time" + + "github.com/simpleauthlink/authapi/internal/base64url" ) // App represents an application that can request tokens. It has a name, a @@ -123,20 +124,17 @@ func (app *App) Marshal() []byte { if !app.Valid(nil) { return nil } - bApp := app.Bytes() - b := make([]byte, base64.RawStdEncoding.EncodedLen(len(bApp))) - base64.RawStdEncoding.Encode(b, bApp) - return b + return base64url.RawEncode(app.Bytes()) } // Unmarshal method sets the app from a base64-encoded byte slice. It is used // to extract the app from the app ID. func (app *App) Unmarshal(data []byte) *App { - b := make([]byte, base64.RawStdEncoding.DecodedLen(len(data))) - if _, err := base64.RawStdEncoding.Decode(b, data); err != nil { + bApp, err := base64url.RawDecode(data) + if err != nil { return nil } - return app.SetBytes(b) + return app.SetBytes(bApp) } // ID method returns the app ID of the app. The app ID is a self-contained diff --git a/token/expiration.go b/token/expiration.go index bbcb490..302d1c2 100644 --- a/token/expiration.go +++ b/token/expiration.go @@ -1,8 +1,9 @@ package token import ( - "encoding/base64" "time" + + "github.com/simpleauthlink/authapi/internal/base64url" ) // Expiration represents a time when a token expires. It is a wrapper around @@ -120,17 +121,15 @@ func (exp *Expiration) Marshal() []byte { if len(bExp) == 0 || bExp[0] == 0 { return nil } - b := make([]byte, base64.RawStdEncoding.EncodedLen(len(bExp))) - base64.RawStdEncoding.Encode(b, bExp) - return b + return base64url.RawEncode(bExp) } // Unmarshal method sets the expiration time from a base64 encoded byte slice. It // is useful for decoding the expiration time. If the expiration is nil or // invalid, nil is returned. func (exp *Expiration) Unmarshal(data []byte) *Expiration { - b := make([]byte, base64.RawStdEncoding.DecodedLen(len(data))) - if _, err := base64.RawStdEncoding.Decode(b, data); err != nil { + b, err := base64url.RawDecode(data) + if err != nil { return nil } return exp.SetBytes(b) diff --git a/token/id.go b/token/id.go index e7f0699..3bab249 100644 --- a/token/id.go +++ b/token/id.go @@ -4,7 +4,8 @@ import ( "crypto/ed25519" "crypto/hmac" "crypto/sha256" - "encoding/base64" + + "github.com/simpleauthlink/authapi/internal/base64url" ) // AppID represents an application ID that is used to generate and verify @@ -101,9 +102,7 @@ func (id *AppID) Sign(secret Secret, msg []byte) []byte { // sign the data with the private key rawSign := ed25519.Sign(privKey, data[:]) // encode the signature to base64 and return it - sign := make([]byte, base64.RawStdEncoding.EncodedLen(len(rawSign))) - base64.RawStdEncoding.Encode(sign, rawSign) - return sign + return base64url.RawEncode(rawSign) } // Verify method returns true if the signature of the message is valid for @@ -126,8 +125,8 @@ func (id *AppID) Verify(secret Secret, msg, sig []byte) bool { return false } // decode sign from base64 - rawSign := make([]byte, base64.RawStdEncoding.DecodedLen(len(sig))) - if _, err := base64.RawStdEncoding.Decode(rawSign, sig); err != nil { + rawSign, err := base64url.RawDecode(sig) + if err != nil { return false } // recover the data with the nonce and the message