Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions api/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -292,8 +292,8 @@ func (c *Client) AuthorizedRequestURLParams(r *http.Request) (string, bool, erro
return "", false, ErrMissingAuthURLParams
}
// get the strToken and email from the request headers
strToken, _ := url.QueryUnescape(params.Get(DefaultAuthTokenURLParam))
userEmail, _ := url.QueryUnescape(params.Get(DefaultAuthEmailURLParam))
strToken := params.Get(DefaultAuthTokenURLParam)
userEmail := params.Get(DefaultAuthEmailURLParam)
// check if the token and email are valid
if strToken == "" || userEmail == "" {
return "", false, ErrMissingAuthURLParams
Expand Down
21 changes: 15 additions & 6 deletions api/client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"log"
"math"
"math/big"
"net/http"
"net/http/httptest"
"net/url"
Expand All @@ -24,8 +25,8 @@ import (
var (
// server
testServer = "127.0.0.1"
testPort = 8082
testEmailPort = 2526
testPort = randomPort()
testEmailPort = randomPort()
testEndpoint = fmt.Sprintf("http://%s:%d", testServer, testPort)
testServerSecret = "serversecret"
testServerEmail = "server@email.com"
Expand Down Expand Up @@ -591,8 +592,8 @@ func TestAuthorizedRequestURLParams(t *testing.T) {
}
reqURL, _ := url.Parse("http://example.com/test")
reqURL.RawQuery = fmt.Sprintf("%s=%s&%s=%s",
DefaultAuthEmailURLParam, url.QueryEscape(testEmail),
DefaultAuthTokenURLParam, url.QueryEscape(testToken.String()))
DefaultAuthEmailURLParam, testEmail,
DefaultAuthTokenURLParam, testToken.String())

req, err := http.NewRequest(http.MethodGet, reqURL.String(), nil)
if err != nil {
Expand Down Expand Up @@ -643,8 +644,8 @@ func TestAuthorizedRequestURLParams(t *testing.T) {
invalidEmail := "noemail.com"
reqURL, _ := url.Parse("http://example.com/test")
reqURL.RawQuery = fmt.Sprintf("%s=%s&%s=%s",
DefaultAuthEmailURLParam, url.QueryEscape(invalidEmail),
DefaultAuthTokenURLParam, url.QueryEscape(testToken.String()))
DefaultAuthEmailURLParam, invalidEmail,
DefaultAuthTokenURLParam, testToken.String())

req, err := http.NewRequest(http.MethodGet, reqURL.String(), nil)
if err != nil {
Expand Down Expand Up @@ -913,3 +914,11 @@ func randomEmail() string {
randStr := strings.ToLower(rand.Text()[0:8])
return fmt.Sprintf("%s@example.com", randStr)
}

func randomPort() int {
port, err := rand.Int(rand.Reader, big.NewInt(65535-1024))
if err != nil {
panic(err)
}
return int(port.Int64()) + 1024
}
36 changes: 36 additions & 0 deletions internal/base64url/base64url.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// base64url package provides a way to encode and decode data using a modified
// base64 encoding scheme that is safe for URLs.
package base64url

import (
"bytes"
"encoding/base64"
)

// RawEncode function encodes the input byte slice into a base64 URL-safe
// encoded byte slice. It encodes the input using the raw standard base64
// encoding and then replaces the URL-unsafe characters.
func RawEncode(src []byte) []byte {
rawBase64 := make([]byte, base64.RawStdEncoding.EncodedLen(len(src)))
base64.RawStdEncoding.Encode(rawBase64, src)
// replace the characters that are not URL safe (+, /)
rawBase64 = bytes.ReplaceAll(rawBase64, []byte("+"), []byte("-"))
rawBase64 = bytes.ReplaceAll(rawBase64, []byte("/"), []byte("_"))
return rawBase64
}

// RawDecode function decodes the input byte slice from a base64 URL-safe
// encoded byte slice. It replaces the URL-unsafe characters with their
// standard base64 counterparts and then decodes the result using the raw
// standard base64 decoding.
func RawDecode(data []byte) ([]byte, error) {
// recover the characters that are not URL safe (+, /)
rawBase64 := bytes.ReplaceAll(data, []byte("-"), []byte("+"))
rawBase64 = bytes.ReplaceAll(rawBase64, []byte("_"), []byte("/"))
// return the decoded string from basic base64
res := make([]byte, base64.RawStdEncoding.DecodedLen(len(data)))
if _, err := base64.RawStdEncoding.Decode(res, rawBase64); err != nil {
return nil, err
}
return res, nil
}
96 changes: 96 additions & 0 deletions internal/base64url/base64url_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
package base64url

import (
"bytes"
"testing"
)

func TestRawEncode(t *testing.T) {
inputs := [][]byte{
// normal
[]byte("Hello, World!"),
// some input that produce "+"
{0xFF, 0xFF, 0xFF},
// some input that produce "/"
{0xFB, 0xEF, 0xBF},
}
expectedOutputs := [][]byte{
[]byte("SGVsbG8sIFdvcmxkIQ"),
[]byte("____"),
[]byte("---_"),
}
for i, input := range inputs {
encoded := RawEncode(input)
if !bytes.Equal(encoded, expectedOutputs[i]) {
t.Errorf("Unexpected output for input %q: got %q, want %q",
input, encoded, expectedOutputs[i])
}
}
}

func TestRawDecode(t *testing.T) {
t.Run("success", func(t *testing.T) {

inputs := [][]byte{
// normal
[]byte("SGVsbG8sIFdvcmxkIQ"),
// some input that produce "+"
[]byte("____"),
// some input that produce "/"
[]byte("---_"),
}
expectedOutputs := [][]byte{
[]byte("Hello, World!"),
{0xFF, 0xFF, 0xFF},
{0xFB, 0xEF, 0xBF},
}
for i, input := range inputs {
decoded, err := RawDecode(input)
if err != nil {
t.Errorf("Unexpected error for input %q: %v",
input, err)
continue
}
if !bytes.Equal(decoded, expectedOutputs[i]) {
t.Errorf("Unexpected output for input %q: got %q, want %q",
input, decoded, expectedOutputs[i])
}
}
})
t.Run("decode error", func(t *testing.T) {
invalidInputs := [][]byte{
[]byte("SGVsbG8sIFdvcmxkIQ==="),
}
for _, input := range invalidInputs {
decoded, err := RawDecode(input)
if err == nil {
t.Errorf("Expected error for input %q, got %q",
input, decoded)
}
}
})
}

func TestRawEncodeDecode(t *testing.T) {
inputs := [][]byte{
// normal
[]byte("Hello, World!"),
// some input that produce "+"
{0xFF, 0xFF, 0xFF},
// some input that produce "/"
{0xFB, 0xEF, 0xBF},
}
for _, input := range inputs {
encoded := RawEncode(input)
decoded, err := RawDecode(encoded)
if err != nil {
t.Errorf("Unexpected error for input %q: %v",
input, err)
continue
}
if !bytes.Equal(decoded, input) {
t.Errorf("Unexpected output for input %q: got %q, want %q",
input, decoded, input)
}
}
}
14 changes: 6 additions & 8 deletions token/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@ package token

import (
"bytes"
"encoding/base64"
"encoding/hex"
"strings"
"time"

"github.com/simpleauthlink/authapi/internal/base64url"
)

// App represents an application that can request tokens. It has a name, a
Expand Down Expand Up @@ -123,20 +124,17 @@ func (app *App) Marshal() []byte {
if !app.Valid(nil) {
return nil
}
bApp := app.Bytes()
b := make([]byte, base64.RawStdEncoding.EncodedLen(len(bApp)))
base64.RawStdEncoding.Encode(b, bApp)
return b
return base64url.RawEncode(app.Bytes())
}

// Unmarshal method sets the app from a base64-encoded byte slice. It is used
// to extract the app from the app ID.
func (app *App) Unmarshal(data []byte) *App {
b := make([]byte, base64.RawStdEncoding.DecodedLen(len(data)))
if _, err := base64.RawStdEncoding.Decode(b, data); err != nil {
bApp, err := base64url.RawDecode(data)
if err != nil {
return nil
}
return app.SetBytes(b)
return app.SetBytes(bApp)
}

// ID method returns the app ID of the app. The app ID is a self-contained
Expand Down
11 changes: 5 additions & 6 deletions token/expiration.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
package token

import (
"encoding/base64"
"time"

"github.com/simpleauthlink/authapi/internal/base64url"
)

// Expiration represents a time when a token expires. It is a wrapper around
Expand Down Expand Up @@ -120,17 +121,15 @@ func (exp *Expiration) Marshal() []byte {
if len(bExp) == 0 || bExp[0] == 0 {
return nil
}
b := make([]byte, base64.RawStdEncoding.EncodedLen(len(bExp)))
base64.RawStdEncoding.Encode(b, bExp)
return b
return base64url.RawEncode(bExp)
}

// Unmarshal method sets the expiration time from a base64 encoded byte slice. It
// is useful for decoding the expiration time. If the expiration is nil or
// invalid, nil is returned.
func (exp *Expiration) Unmarshal(data []byte) *Expiration {
b := make([]byte, base64.RawStdEncoding.DecodedLen(len(data)))
if _, err := base64.RawStdEncoding.Decode(b, data); err != nil {
b, err := base64url.RawDecode(data)
if err != nil {
return nil
}
return exp.SetBytes(b)
Expand Down
11 changes: 5 additions & 6 deletions token/id.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ import (
"crypto/ed25519"
"crypto/hmac"
"crypto/sha256"
"encoding/base64"

"github.com/simpleauthlink/authapi/internal/base64url"
)

// AppID represents an application ID that is used to generate and verify
Expand Down Expand Up @@ -101,9 +102,7 @@ func (id *AppID) Sign(secret Secret, msg []byte) []byte {
// sign the data with the private key
rawSign := ed25519.Sign(privKey, data[:])
// encode the signature to base64 and return it
sign := make([]byte, base64.RawStdEncoding.EncodedLen(len(rawSign)))
base64.RawStdEncoding.Encode(sign, rawSign)
return sign
return base64url.RawEncode(rawSign)
}

// Verify method returns true if the signature of the message is valid for
Expand All @@ -126,8 +125,8 @@ func (id *AppID) Verify(secret Secret, msg, sig []byte) bool {
return false
}
// decode sign from base64
rawSign := make([]byte, base64.RawStdEncoding.DecodedLen(len(sig)))
if _, err := base64.RawStdEncoding.Decode(rawSign, sig); err != nil {
rawSign, err := base64url.RawDecode(sig)
if err != nil {
return false
}
// recover the data with the nonce and the message
Expand Down