From 5fdf41010295b716f2ceab4df6ce48c36757e9ea Mon Sep 17 00:00:00 2001 From: Pierre Chapuis Date: Wed, 30 Apr 2025 17:19:49 +0200 Subject: [PATCH] support subscription topics --- finegrain/src/finegrain/__init__.py | 24 +++++++++++- finegrain/tests/conftest.py | 21 ++++++++--- finegrain/tests/test_subscription_topic.py | 43 ++++++++++++++++++++++ 3 files changed, 82 insertions(+), 6 deletions(-) create mode 100644 finegrain/tests/test_subscription_topic.py diff --git a/finegrain/src/finegrain/__init__.py b/finegrain/src/finegrain/__init__.py index 9964508..54622d1 100644 --- a/finegrain/src/finegrain/__init__.py +++ b/finegrain/src/finegrain/__init__.py @@ -305,7 +305,12 @@ class EditorAPIContext: user_agent: str token: str | None + subscription_topic: str | None = None logger: logging.Logger + + # Note: this is always set when calling `login` or `me` and updated using SSE. + # If you use a subscription topic it is *not* updated, use the value + # from the response metadata or call `me`. credits: int | None = None _client: httpx.AsyncClient | None @@ -317,6 +322,7 @@ class EditorAPIContext: def __init__( self, + *, credentials: str | None = None, api_key: str | None = None, user: str | None = None, @@ -325,12 +331,14 @@ def __init__( priority: Priority = "standard", verify: bool | str = True, default_timeout: float = 60.0, + subscription_topic: str | None = None, user_agent: str | None = None, ) -> None: self.base_url = base_url or "https://api.finegrain.ai/editor" self.priority = priority self.verify = verify self.default_timeout = default_timeout + self.subscription_topic = subscription_topic if credentials is not None: if (m := API_KEY_PATTERN.match(credentials)) is not None: @@ -407,6 +415,12 @@ async def login(self) -> None: self.credits = r["user"]["credits"] self.token = r["token"] + async def me(self) -> dict[str, Any]: + response = await self.request("GET", "auth/me") + r = response.json() + self.credits = r["credits"] + return r + async def request( self, method: Literal["GET", "POST"], @@ -441,7 +455,11 @@ async def _q() -> httpx.Response: return r async def get_sub_url(self) -> str: - response = await self.request("POST", "sub-auth") + if self.subscription_topic is not None: + params = {"subscription_topic": self.subscription_topic} + else: + params = None + response = await self.request("POST", "sub-auth", json=params) jdata = response.json() sub_token = jdata["token"] self._ping_interval = float(jdata.get("ping_interval", 0.0)) @@ -557,6 +575,8 @@ async def call_skill( timeout = timeout or self.default_timeout user_timeout = max(int(timeout), 1) params = {"priority": self.priority, "user_timeout": user_timeout} | (params or {}) + if self.subscription_topic is not None: + params["subscription_topic"] = self.subscription_topic response = await self.request("POST", f"skills/{url}", json=params) state_id: StateID = response.json()["state"] status = await self.sse_await(state_id, timeout=timeout) @@ -926,6 +946,8 @@ async def _create_state( data: dict[str, str] = {} if file_url is not None: data["file_url"] = file_url + if self.ctx.subscription_topic is not None: + data["subscription_topic"] = self.ctx.subscription_topic if meta is not None: data["meta"] = json.dumps(meta) response = await self.ctx.request("POST", "state/create", files=files, data=data) diff --git a/finegrain/tests/conftest.py b/finegrain/tests/conftest.py index 9c9a96d..2c2fa3f 100644 --- a/finegrain/tests/conftest.py +++ b/finegrain/tests/conftest.py @@ -24,14 +24,25 @@ def event_loop(): @pytest.fixture(scope="session") -async def fgctx() -> AsyncGenerator[EditorAPIContext, None]: +def fx_base_url() -> str: with env.prefixed("FG_API_"): - credentials = env.str("CREDENTIALS", None) url = env.str("URL", "https://api.finegrain.ai/editor") - assert credentials and url, "set FG_API_CREDENTIALS" + return url + + +@pytest.fixture(scope="session") +def fx_credentials() -> str: + with env.prefixed("FG_API_"): + credentials = env.str("CREDENTIALS", None) + assert credentials, "set FG_API_CREDENTIALS" + return credentials + + +@pytest.fixture(scope="session") +async def fgctx(fx_base_url: str, fx_credentials: str) -> AsyncGenerator[EditorAPIContext, None]: ctx = EditorAPIContext( - credentials=credentials, - base_url=url, + base_url=fx_base_url, + credentials=fx_credentials, user_agent="finegrain-python-tests", ) diff --git a/finegrain/tests/test_subscription_topic.py b/finegrain/tests/test_subscription_topic.py new file mode 100644 index 0000000..5372f8d --- /dev/null +++ b/finegrain/tests/test_subscription_topic.py @@ -0,0 +1,43 @@ +import pytest + +from finegrain import EditorAPIContext, OKResult + + +@pytest.mark.parametrize("subscription_topic", ["fg-test-topic", None]) +async def test_subscription_topic( + fx_base_url: str, + fx_credentials: str, + sofa_cushion_bytes: bytes, + subscription_topic: str | None, +) -> None: + ctx = EditorAPIContext( + base_url=fx_base_url, + credentials=fx_credentials, + user_agent="finegrain-python-tests", + subscription_topic=subscription_topic, + ) + + await ctx.login() + await ctx.sse_start() + + create_r = await ctx.call_async.create_state( + file=sofa_cushion_bytes, + meta={"test-key": "test-value"}, + ) + assert isinstance(create_r, OKResult) + assert create_r.meta["test-key"] == "test-value" + + # to test credits update mechanism + assert isinstance(ctx.credits, int) + ctx.credits = None + + infer_ms_r = await ctx.call_async.infer_product_name(create_r.state_id) + assert isinstance(infer_ms_r, OKResult) + assert infer_ms_r.meta["product_name"] == "sofa" + + if subscription_topic is not None: + assert ctx.credits is None + await ctx.me() + assert isinstance(ctx.credits, int) + + await ctx.sse_stop()