From 54ea252d3eca52bbb663ab5ac96e3e87ef0124a8 Mon Sep 17 00:00:00 2001 From: Pierre Chapuis Date: Wed, 28 May 2025 11:54:29 +0200 Subject: [PATCH] add support for oauth credentials (refresh token) --- finegrain/pyproject.toml | 1 + finegrain/requirements.lock | 14 +++-- finegrain/src/finegrain/__init__.py | 79 +++++++++++++++++++++++++++-- 3 files changed, 87 insertions(+), 7 deletions(-) diff --git a/finegrain/pyproject.toml b/finegrain/pyproject.toml index fbbea06..5e87df2 100644 --- a/finegrain/pyproject.toml +++ b/finegrain/pyproject.toml @@ -8,6 +8,7 @@ authors = [ dependencies = [ "httpx>=0.27.0", "httpx-sse>=0.4.0", + "pyjwt[crypto]>=2.10.1", ] readme = "README.md" requires-python = ">= 3.12, <3.13" diff --git a/finegrain/requirements.lock b/finegrain/requirements.lock index eb00112..256bb79 100644 --- a/finegrain/requirements.lock +++ b/finegrain/requirements.lock @@ -12,12 +12,16 @@ -e file:. anyio==4.9.0 # via httpx -certifi==2025.1.31 +certifi==2025.4.26 # via httpcore # via httpx -h11==0.14.0 +cffi==1.17.1 + # via cryptography +cryptography==45.0.3 + # via pyjwt +h11==0.16.0 # via httpcore -httpcore==1.0.8 +httpcore==1.0.9 # via httpx httpx==0.28.1 # via finegrain @@ -26,6 +30,10 @@ httpx-sse==0.4.0 idna==3.10 # via anyio # via httpx +pycparser==2.22 + # via cffi +pyjwt==2.10.1 + # via finegrain sniffio==1.3.1 # via anyio typing-extensions==4.13.2 diff --git a/finegrain/src/finegrain/__init__.py b/finegrain/src/finegrain/__init__.py index 54622d1..52c8483 100644 --- a/finegrain/src/finegrain/__init__.py +++ b/finegrain/src/finegrain/__init__.py @@ -10,6 +10,7 @@ import httpx import httpx_sse +import jwt from httpx._types import QueryParamTypes, RequestData, RequestFiles # pyright: ignore[reportPrivateImportUsage] logger = logging.getLogger(__name__) @@ -17,7 +18,7 @@ Priority = Literal["low", "standard", "high"] StateID = NewType("StateID", str) -VERSION = "0.2" +VERSION = "0.3" API_KEY_PATTERN = re.compile(r"^FGAPI(\-[A-Z0-9]{6}){4}$") EMAIL_PWD_PATTERN = re.compile(r"^\s*(?P[\S]+?@[\S]+?):(?P\S+)\s*$") @@ -293,7 +294,61 @@ def description(self) -> str: return f"API key {self.api_key[:13]}..." -type Credentials = LoginCredentials | ApiKeyCredentials +@dc.dataclass(kw_only=True) +class OAuthCredentials: + access_token: str + refresh_token: str + client_id: str + client_secret: str + account_url: str = "https://account.finegrain.ai" + account_verify: bool | str = True + + def __post_init__(self): + self.validate_tokens() + + def validate_tokens(self) -> None: + assert self.access_token, "access_token must not be empty" + assert self.refresh_token, "refresh_token must not be empty" + + decoded_access = jwt.decode(self.access_token, options={"verify_signature": False}) + decoded_refresh = jwt.decode(self.refresh_token, options={"verify_signature": False}) + + assert decoded_access["aud"] == "access" + assert decoded_refresh["aud"] == "refresh" + + sub = decoded_access.get("sub", "") + assert sub.startswith("FGUSR-") + + assert decoded_refresh["iss"] == self.client_id + assert decoded_refresh["sub"] == sub + + @property + def as_login_params(self) -> dict[str, str]: + raise ValueError("cannot login with OAuth credentials") + + @property + def description(self) -> str: + return f"OAuth client {self.client_id}" + + async def renew(self) -> None: + async with httpx.AsyncClient(verify=self.account_verify) as client: + response = await client.post( + url=f"{self.account_url}/oauth/token", + data={ + "grant_type": "refresh_token", + "client_id": self.client_id, + "client_secret": self.client_secret, + "refresh_token": self.refresh_token, + }, + ) + check_status(response) + r = response.json() + self.access_token = r["access_token"] + self.refresh_token = r["refresh_token"] + self.validate_tokens() + + +Credentials = LoginCredentials | ApiKeyCredentials | OAuthCredentials class EditorAPIContext: @@ -323,7 +378,7 @@ class EditorAPIContext: def __init__( self, *, - credentials: str | None = None, + credentials: Credentials | str | None = None, api_key: str | None = None, user: str | None = None, password: str | None = None, @@ -340,7 +395,9 @@ def __init__( self.default_timeout = default_timeout self.subscription_topic = subscription_topic - if credentials is not None: + if isinstance(credentials, Credentials): + self.credentials = credentials + elif credentials is not None: if (m := API_KEY_PATTERN.match(credentials)) is not None: self.credentials = ApiKeyCredentials(api_key=m[0]) elif (m := EMAIL_PWD_PATTERN.match(credentials)) is not None: @@ -367,6 +424,9 @@ def __init__( verify=self.verify, ) self.reset() + if isinstance(self.credentials, OAuthCredentials): + # Use token provided in credentials initially (avoids a useless refresh). + self.token = self.credentials.access_token def reset(self) -> None: self.token = None @@ -404,6 +464,16 @@ def auth_headers(self) -> dict[str, str]: return {"Authorization": f"Bearer {self.token}"} async def login(self) -> None: + if isinstance(self.credentials, OAuthCredentials): + if self.token is None: + await self.credentials.renew() + self.token = self.credentials.access_token + # If the token is set but invalid, `me` will fail with 401. + # The token will be unset and `login` will be called again. + r = await self.me() + self.credits = r["credits"] + self.logger.debug(f"logged in as {self.credentials.description} - {r['username']}") + return async with self as client: response = await client.post( f"{self.base_url}/auth/login", @@ -447,6 +517,7 @@ async def _q() -> httpx.Response: r = await _q() if r.status_code == 401: self.logger.debug("renewing token") + self.token = None await self.login() r = await _q()