3 changes: 0 additions & 3 deletions .github/workflows/matrix-checks.yml
@@ -5,9 +5,6 @@ on:
python_version:
required: false
type: string
secrets:
CI_SSH_KEY:
required: true

jobs:
matrix-checks:
3 changes: 0 additions & 3 deletions .github/workflows/tests-macos.yml
@@ -5,9 +5,6 @@ on:
python_version:
required: false
type: string
secrets:
CI_SSH_KEY:
required: true

jobs:
tests-macos:
1 change: 1 addition & 0 deletions pyproject.toml
@@ -38,6 +38,7 @@ dependencies = [
"dnet-p2p @ file://${PROJECT_ROOT}/lib/dnet-p2p/bindings/py",
"rich>=13.0.0",
"psutil>=5.9.0",
"outlines>=1.2.0",
]

[project.optional-dependencies]
76 changes: 65 additions & 11 deletions src/dnet/api/inference.py
@@ -1,9 +1,11 @@
import asyncio
import time
import uuid
import json
import mlx.core as mx
import numpy as np
from typing import Optional, Any, List
from builtins import aiter, anext
from dnet.core.tensor import to_bytes

from .models import (
@@ -14,11 +16,13 @@
ChatUsage,
ChatCompletionReason,
ChatLogProbs,
StructuredOutputsParams,
)
from .cluster import ClusterManager
from .model_manager import ModelManager
from .strategies.base import ApiAdapterBase
from dnet.core.decoding.config import DecodingConfig
from dnet.utils.logger import logger


async def arange(count: int):
@@ -64,9 +68,9 @@ async def connect_to_ring(
self._api_callback_addr = api_callback_addr

async def generate_stream(self, req: ChatRequestModel):
"""
Generator for chat completion chunks.
"""
"""Generator for chat completion chunks."""
logger.debug(f"generate_stream called: model={req.model}")

if not self.model_manager.tokenizer:
raise RuntimeError(
"Inference manager not ready (ring not connected or tokenizer not loaded)"
@@ -79,20 +83,26 @@ async def generate_stream(self, req: ChatRequestModel):
hasattr(tokenizer, "chat_template")
and tokenizer.chat_template is not None
):
message_dicts = [
{"role": m.role, "content": m.content} for m in req.messages
]
# Convert messages to dict format
message_dicts = []
for m in req.messages:
msg_dict = {"role": m.role, "content": m.content or ""}
message_dicts.append(msg_dict)

prompt_text = tokenizer.apply_chat_template(
message_dicts,
add_generation_prompt=True,
tokenize=False,
)
else:
prompt_text = (
"\n".join(m.content for m in req.messages) + "\nAssistant:"
"\n".join(m.content or "" for m in req.messages) + "\nAssistant:"
)
except Exception:
prompt_text = "\n".join(m.content for m in req.messages) + "\nAssistant:"
except Exception as e:
logger.warning(f"Failed to apply chat template: {e}, using fallback")
prompt_text = (
"\n".join(m.content or "" for m in req.messages) + "\nAssistant:"
)

prompt_tokens = tokenizer.encode(prompt_text)
prompt_array = mx.array(prompt_tokens)
@@ -104,6 +114,16 @@ async def generate_stream(self, req: ChatRequestModel):
tokenizer.encode(stop_word, add_special_tokens=False)
)

# Convert OpenAI response_format to internal structured_outputs format
if req.response_format and req.response_format.get("type") == "json_schema":
json_schema = req.response_format["json_schema"]["schema"]
req.structured_outputs = StructuredOutputsParams(json_schema=json_schema)

# Get grammar JSON schema for structured output
grammar_json_schema = None
if req.structured_outputs and req.structured_outputs.json_schema:
grammar_json_schema = json.dumps(req.structured_outputs.json_schema)

nonce = f"chatcmpl-{uuid.uuid4()}"
t_start = time.perf_counter()
t_first_token: Optional[float] = None
@@ -152,6 +172,7 @@ async def generate_stream(self, req: ChatRequestModel):
min_tokens_to_keep=req.min_tokens_to_keep
if hasattr(req, "min_tokens_to_keep")
else 1,
grammar_json_schema=grammar_json_schema,
)

# Send tokens to first shard
@@ -209,9 +230,29 @@ async def generate_stream(self, req: ChatRequestModel):
if token == tokenizer.eos_token_id:
completion_reason = ChatCompletionReason.STOP
break

y = mx.array([token], dtype=mx.int32)

detokenizer.finalize()
final_text = detokenizer.text

# Strip special tokens from output
# mlx-lm's NaiveStreamingDetokenizer calls tokenizer.decode() without skip_special_tokens=True
# (see: https://github.com/ml-explore/mlx-lm/blob/main/mlx_lm/tokenizer_utils.py)
# So we strip them manually as a post-processing step
SPECIAL_TOKENS_TO_STRIP = [
"<|im_end|>", # Qwen, ChatML format
"<|im_start|>", # Qwen, ChatML format
"<|endoftext|>", # GPT/generic
"</s>", # Llama, Mistral
"<|eot_id|>", # Llama 3
"<|end|>", # Phi
"<|assistant|>", # Some chat templates
"<|user|>", # Some chat templates
]
for token in SPECIAL_TOKENS_TO_STRIP:
final_text = final_text.replace(token, "")
final_text = final_text.strip()

metrics_dict = None
t_end = time.perf_counter()
@@ -232,13 +273,19 @@ async def generate_stream(self, req: ChatRequestModel):
),
}

# Final chunk with finish reason
final_message = ChatMessage(
role="assistant",
content=final_text,
)

# Final chunk
yield ChatResponseModel(
id=nonce,
choices=[
ChatChoice(
index=0,
delta=ChatMessage(role="assistant", content=""),
delta=None,
message=final_message,
finish_reason=completion_reason,
)
],
@@ -288,6 +335,13 @@ async def chat_completions(self, req: ChatRequestModel) -> ChatResponseModel:
if chunk.usage:
usage = chunk.usage

# Clean up structured output responses - remove end tokens
if req.structured_outputs and req.structured_outputs.json_schema:
full_content = full_content.strip()
for token in ["<|im_end|>", "<|endoftext|>", "</s>"]:
if token in full_content:
full_content = full_content.split(token)[0].strip()

return ChatResponseModel(
id=nonce,
choices=[
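Note on the new response_format path in inference.py: generate_stream now accepts the OpenAI-style json_schema envelope and converts it into the internal StructuredOutputsParams before the schema reaches DecodingConfig. Below is a minimal client sketch that would exercise this path; the host, port, route, and model name are assumptions for illustration, not values taken from this PR.

```python
import json
import urllib.request

# Hypothetical endpoint and model name; adjust for your dnet deployment.
url = "http://localhost:8080/v1/chat/completions"

payload = {
    "model": "my-model",
    "messages": [{"role": "user", "content": "Give me two colors as JSON."}],
    # OpenAI-compatible envelope: the handler reads
    # response_format["json_schema"]["schema"] and wraps it in
    # StructuredOutputsParams(json_schema=...).
    "response_format": {
        "type": "json_schema",
        "json_schema": {
            "name": "colors",
            "schema": {
                "type": "object",
                "properties": {
                    "colors": {"type": "array", "items": {"type": "string"}}
                },
                "required": ["colors"],
            },
        },
    },
}

request = urllib.request.Request(
    url,
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(request) as resp:
    print(json.load(resp))
```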
41 changes: 36 additions & 5 deletions src/dnet/api/models.py
@@ -29,6 +29,29 @@ class ChatCompletionReason(str, Enum):
STOP = "stop"


class StructuredOutputsParams(BaseModel):
"""Parameters for structured output generation."""

json_schema: Optional[Dict[str, Any]] = Field(default=None)

@field_validator("json_schema")
@classmethod
def validate_json_schema(cls, v):
if v is None:
return v
if not isinstance(v, dict):
raise ValueError("JSON schema must be a dictionary")
if "type" not in v:
raise ValueError("JSON schema must have a 'type' field")
try:
import json

json.dumps(v)
except (TypeError, ValueError) as e:
raise ValueError(f"JSON schema must be JSON serializable: {e}")
return v


class RingInferenceError(BaseModel):
"""Error response for ring inference."""

@@ -49,10 +72,13 @@ class RingInferenceError(BaseModel):


class ChatMessage(BaseModel):
"""A single message in a chat conversation."""
"""A single message in a chat conversation.

Compatible with OpenAI format.
"""

role: str # "system" | "user" | "assistant" | "tool" | "developer" # TODO: use Literal?
content: str
role: str # "system" | "user" | "assistant" | "developer" # TODO: use Literal?
content: Optional[str] = None


class ChatParams(BaseModel):
@@ -78,7 +104,12 @@ class ChatParams(BaseModel):
# prediction: NOT USED
# presence_penalty: float = Field(default=0.0, ge=-2.0, le=2.0) # NOTE: unused
# prompt_cache_key: Optional[str] = Field(default=None) # NOTE: unused
# TODO: response_format:
structured_outputs: Optional[StructuredOutputsParams] = Field(
default=None
) # Structured output parameters for grammar-constrained generation
response_format: Optional[Dict[str, Any]] = Field(
default=None
) # OpenAI-compatible response format (json_schema, etc.)
# safety_identifier: Optional[str] = Field(default=None) # NOTE: unused
# service_tier: Optional[str] = Field(default=None) # NOTE: unused
stop: Union[str, List[str]] = Field(default_factory=list)
@@ -298,7 +329,7 @@ class ListModelsResponseModel(BaseModel):
data: List[ModelObject]


type RetrieveModelResponseModel = ModelObject
RetrieveModelResponseModel = ModelObject


# ------------------------
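The validator on StructuredOutputsParams enforces three things: the schema must be a dict, it must carry a top-level "type" field, and it must be JSON-serializable. A short sketch of that behavior (the import path follows this file's location in the diff):

```python
from pydantic import ValidationError

from dnet.api.models import StructuredOutputsParams

# Accepted: a dict with a top-level "type" that serializes cleanly.
params = StructuredOutputsParams(
    json_schema={"type": "object", "properties": {"name": {"type": "string"}}}
)

# Rejected: missing the top-level "type" field.
try:
    StructuredOutputsParams(json_schema={"properties": {}})
except ValidationError as e:
    print(e)

# Rejected: a set value is not JSON-serializable.
try:
    StructuredOutputsParams(json_schema={"type": "object", "enum": {1, 2}})
except ValidationError as e:
    print(e)
```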
3 changes: 3 additions & 0 deletions src/dnet/api/strategies/ring.py
@@ -175,6 +175,9 @@ async def send_tokens(
min_tokens_to_keep=decoding_config.min_tokens_to_keep
if decoding_config
else 1,
grammar_json_schema=decoding_config.grammar_json_schema
if decoding_config and hasattr(decoding_config, "grammar_json_schema")
else None,
)
req = msg.to_proto(tokens)

3 changes: 3 additions & 0 deletions src/dnet/core/decoding/config.py
@@ -1,4 +1,5 @@
from dataclasses import dataclass
from typing import Optional


@dataclass
@@ -12,3 +13,5 @@ class DecodingConfig:
logit_bias: dict[int, float] | None = None
min_p: float = 0.0
min_tokens_to_keep: int = 1
# Structured output support
grammar_json_schema: Optional[str] = None
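Downstream, grammar_json_schema travels as a JSON string: inference.py serializes the schema with json.dumps before building DecodingConfig, and strategies/ring.py forwards it unchanged (or None) to the first shard. A sketch of that hand-off, assuming the DecodingConfig fields elided from this hunk also have defaults:

```python
import json

from dnet.core.decoding.config import DecodingConfig

schema = {
    "type": "object",
    "properties": {"answer": {"type": "string"}},
    "required": ["answer"],
}

# The API layer stores the schema as a string, not a dict (illustrative values;
# fields not shown in this diff are assumed to have defaults).
config = DecodingConfig(
    min_p=0.0,
    min_tokens_to_keep=1,
    grammar_json_schema=json.dumps(schema),
)

# strategies/ring.py reads config.grammar_json_schema and passes it along.
print(config.grammar_json_schema)
```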