Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion finegrain/src/finegrain/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,12 @@ class EditorAPIContext:
user_agent: str

token: str | None
subscription_topic: str | None = None
logger: logging.Logger

# Note: this is always set when calling `login` or `me` and updated using SSE.
# If you use a subscription topic it is *not* updated, use the value
# from the response metadata or call `me`.
credits: int | None = None

_client: httpx.AsyncClient | None
Expand All @@ -317,6 +322,7 @@ class EditorAPIContext:

def __init__(
self,
*,
credentials: str | None = None,
api_key: str | None = None,
user: str | None = None,
Expand All @@ -325,12 +331,14 @@ def __init__(
priority: Priority = "standard",
verify: bool | str = True,
default_timeout: float = 60.0,
subscription_topic: str | None = None,
user_agent: str | None = None,
) -> None:
self.base_url = base_url or "https://api.finegrain.ai/editor"
self.priority = priority
self.verify = verify
self.default_timeout = default_timeout
self.subscription_topic = subscription_topic

if credentials is not None:
if (m := API_KEY_PATTERN.match(credentials)) is not None:
Expand Down Expand Up @@ -407,6 +415,12 @@ async def login(self) -> None:
self.credits = r["user"]["credits"]
self.token = r["token"]

async def me(self) -> dict[str, Any]:
response = await self.request("GET", "auth/me")
r = response.json()
self.credits = r["credits"]
return r

async def request(
self,
method: Literal["GET", "POST"],
Expand Down Expand Up @@ -441,7 +455,11 @@ async def _q() -> httpx.Response:
return r

async def get_sub_url(self) -> str:
response = await self.request("POST", "sub-auth")
if self.subscription_topic is not None:
params = {"subscription_topic": self.subscription_topic}
else:
params = None
response = await self.request("POST", "sub-auth", json=params)
jdata = response.json()
sub_token = jdata["token"]
self._ping_interval = float(jdata.get("ping_interval", 0.0))
Expand Down Expand Up @@ -557,6 +575,8 @@ async def call_skill(
timeout = timeout or self.default_timeout
user_timeout = max(int(timeout), 1)
params = {"priority": self.priority, "user_timeout": user_timeout} | (params or {})
if self.subscription_topic is not None:
params["subscription_topic"] = self.subscription_topic
response = await self.request("POST", f"skills/{url}", json=params)
state_id: StateID = response.json()["state"]
status = await self.sse_await(state_id, timeout=timeout)
Expand Down Expand Up @@ -926,6 +946,8 @@ async def _create_state(
data: dict[str, str] = {}
if file_url is not None:
data["file_url"] = file_url
if self.ctx.subscription_topic is not None:
data["subscription_topic"] = self.ctx.subscription_topic
if meta is not None:
data["meta"] = json.dumps(meta)
response = await self.ctx.request("POST", "state/create", files=files, data=data)
Expand Down
21 changes: 16 additions & 5 deletions finegrain/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,25 @@ def event_loop():


@pytest.fixture(scope="session")
async def fgctx() -> AsyncGenerator[EditorAPIContext, None]:
def fx_base_url() -> str:
with env.prefixed("FG_API_"):
credentials = env.str("CREDENTIALS", None)
url = env.str("URL", "https://api.finegrain.ai/editor")
assert credentials and url, "set FG_API_CREDENTIALS"
return url


@pytest.fixture(scope="session")
def fx_credentials() -> str:
with env.prefixed("FG_API_"):
credentials = env.str("CREDENTIALS", None)
assert credentials, "set FG_API_CREDENTIALS"
return credentials


@pytest.fixture(scope="session")
async def fgctx(fx_base_url: str, fx_credentials: str) -> AsyncGenerator[EditorAPIContext, None]:
ctx = EditorAPIContext(
credentials=credentials,
base_url=url,
base_url=fx_base_url,
credentials=fx_credentials,
user_agent="finegrain-python-tests",
)

Expand Down
43 changes: 43 additions & 0 deletions finegrain/tests/test_subscription_topic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import pytest

from finegrain import EditorAPIContext, OKResult


@pytest.mark.parametrize("subscription_topic", ["fg-test-topic", None])
async def test_subscription_topic(
fx_base_url: str,
fx_credentials: str,
sofa_cushion_bytes: bytes,
subscription_topic: str | None,
) -> None:
ctx = EditorAPIContext(
base_url=fx_base_url,
credentials=fx_credentials,
user_agent="finegrain-python-tests",
subscription_topic=subscription_topic,
)

await ctx.login()
await ctx.sse_start()

create_r = await ctx.call_async.create_state(
file=sofa_cushion_bytes,
meta={"test-key": "test-value"},
)
assert isinstance(create_r, OKResult)
assert create_r.meta["test-key"] == "test-value"

# to test credits update mechanism
assert isinstance(ctx.credits, int)
ctx.credits = None

infer_ms_r = await ctx.call_async.infer_product_name(create_r.state_id)
assert isinstance(infer_ms_r, OKResult)
assert infer_ms_r.meta["product_name"] == "sofa"

if subscription_topic is not None:
assert ctx.credits is None
await ctx.me()
assert isinstance(ctx.credits, int)

await ctx.sse_stop()
Loading