2 changes: 2 additions & 0 deletions pyproject.toml
@@ -9,6 +9,8 @@ dependencies = [
"typer>=0.15.2",
"litellm==1.74.1",
"weave>=0.51.51",
"uvicorn[standard]",
"fastapi",
]

[project.optional-dependencies]
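Side note, not part of the PR: the two new runtime dependencies are the standard FastAPI/uvicorn pairing. A minimal sketch of what they provide (hypothetical app, endpoint, and port):

```python
# Not from this PR: a minimal smoke test of the two new runtime dependencies,
# an ASGI app served by uvicorn.
from fastapi import FastAPI
import uvicorn

app = FastAPI()

@app.get("/healthz")
def healthz() -> dict[str, bool]:
    return {"ok": True}

if __name__ == "__main__":
    # `uvicorn[standard]` additionally pulls in the faster event-loop and
    # HTTP-parser extras.
    uvicorn.run(app, host="0.0.0.0", port=8000)
```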
18 changes: 18 additions & 0 deletions src/art/dev/openai_server.py
@@ -1,9 +1,17 @@
import warnings
from typing import Literal

from typing_extensions import TypedDict

from .engine import EngineArgs

ENGINE_INIT_ONLY_ARGS = {
"max_logprobs",
"gpu_memory_utilization",
"tensor_parallel_size",
"max_model_len",
}


def get_openai_server_config(
model_name: str,
@@ -35,6 +43,16 @@ def get_openai_server_config(
generation_config="vllm",
)
engine_args.update(config.get("engine_args", {}))
user_engine_args = config.get("engine_args", {})
ignored_args = set(user_engine_args.keys()) & ENGINE_INIT_ONLY_ARGS
if ignored_args:
warnings.warn(
f"OpenAIServerConfig.engine_args contains {ignored_args} which will be "
f"ignored. The vLLM engine is initialized by Unsloth before this config "
f"is applied. Use TrainableModel._internal_config.engine_args instead.",
UserWarning,
stacklevel=2,
)
return OpenAIServerConfig(
log_file=log_file, server_args=server_args, engine_args=engine_args
)
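A minimal standalone sketch of the check added above (hypothetical helper name and example values; the real code pulls the args from OpenAIServerConfig):

```python
import warnings

ENGINE_INIT_ONLY_ARGS = {
    "max_logprobs",
    "gpu_memory_utilization",
    "tensor_parallel_size",
    "max_model_len",
}

def warn_on_init_only_engine_args(user_engine_args: dict) -> None:
    # Hypothetical standalone helper: report keys that only take effect when
    # the vLLM engine is first initialized and are therefore ignored here.
    ignored = set(user_engine_args) & ENGINE_INIT_ONLY_ARGS
    if ignored:
        warnings.warn(
            f"engine_args contains {ignored}, which will be ignored; "
            "set these where the engine is initialized instead.",
            UserWarning,
            stacklevel=2,
        )

# Example: warns about gpu_memory_utilization only; max_num_seqs passes through.
warn_on_init_only_engine_args({"gpu_memory_utilization": 0.8, "max_num_seqs": 16})
```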
20 changes: 20 additions & 0 deletions src/art/local/backend.py
@@ -412,6 +412,26 @@ async def _train_model(
dev_config: dev.TrainConfig,
verbose: bool = False,
) -> AsyncIterator[dict[str, float]]:
print("[DEBUG _train_model] Received trajectory_groups")
for tg_idx, tg in enumerate(trajectory_groups):
rewards = [t.reward for t in tg.trajectories]
print(f"[DEBUG _train_model] tg={tg_idx} rewards={rewards}")
for traj_idx, traj in enumerate(tg.trajectories):
for msg_idx, msg in enumerate(traj.messages_and_choices):
if isinstance(msg, dict) and msg.get("role") == "assistant":
print(f"[DEBUG _train_model] tg={tg_idx} traj={traj_idx} msg={msg_idx}")
print(f"[DEBUG _train_model] Assistant msg keys: {list(msg.keys())}")
print(f"[DEBUG _train_model] has logprobs: {'logprobs' in msg}")
if 'logprobs' in msg:
lp = msg['logprobs']
print(f"[DEBUG _train_model] logprobs type: {type(lp)}, truthy: {bool(lp)}")
if isinstance(lp, dict):
print(f"[DEBUG _train_model] logprobs keys: {list(lp.keys())}")
if 'values' in lp:
print(f"[DEBUG _train_model] logprobs['values'] len: {len(lp['values'])}")
print(f"[DEBUG _train_model] token_ids present: {'token_ids' in msg and msg.get('token_ids') is not None}")
if 'token_ids' in msg and msg.get('token_ids') is not None:
print(f"[DEBUG _train_model] token_ids len: {len(msg['token_ids'])}")
if verbose:
print("Starting _train_model")
service = await self._get_service(model)
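For reference, a hedged example of the assistant-message shape these debug prints probe for; the field names come from the checks above, the concrete values are invented:

```python
# Shape only; "4", the ids, and the logprob values are made up.
assistant_msg = {
    "role": "assistant",
    "content": "4",
    # Used directly by tokenization when present (see tokenize.py below).
    "token_ids": [19, 151645],
    # One sampled logprob per token id.
    "logprobs": {"values": [-0.01, -0.0003]},
}
```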
21 changes: 19 additions & 2 deletions src/art/loss.py
@@ -16,6 +16,9 @@ class Loss(BaseModel):
mean_kl: torch.Tensor
mean_entropy: torch.Tensor | None
probs_corr: torch.Tensor
frac_old_logprobs_valid: float
mean_importance_ratio: torch.Tensor
clip_fraction: torch.Tensor


def loss_fn(
@@ -32,6 +35,9 @@ def loss_fn(
)
weights = shift_tensor(inputs["weights"], 0.0)
old_logprobs_mask = ~torch.isnan(old_logprobs)
frac_old_logprobs_valid = (
old_logprobs_mask.float().sum() / (old_logprobs.numel() + 1e-6)
).item()
probs_corr = torch.corrcoef(
torch.stack(
[
@@ -77,15 +83,23 @@ def loss_fn(
)
if tau := experimental_config.get("kimi_k2_tau", None):
advantages -= tau * logprob_diff.detach()
clipped_ratio = torch.clip(prob_ratio, 1 - epsilon, 1 + epsilon_high)
is_clipped = (prob_ratio < 1 - epsilon) | (prob_ratio > 1 + epsilon_high)
clip_fraction = (is_clipped.float() * assistant_mask).sum() / (
assistant_mask.sum() + 1e-6
)
mean_importance_ratio = (prob_ratio * assistant_mask).sum() / (
assistant_mask.sum() + 1e-6
)
if experimental_config.get("ppo", True):
policy_loss = -torch.min(
prob_ratio * advantages,
torch.clip(prob_ratio, 1 - epsilon, 1 + epsilon_high) * advantages,
clipped_ratio * advantages,
)
else:
# Modified REINFORCE or Clipped IS-weight Policy Optimization (CISPO)
policy_loss = -(
torch.clip(prob_ratio.detach(), 1 - epsilon, 1 + epsilon_high)
clipped_ratio.detach()
* advantages
* new_logprobs
)
@@ -123,6 +137,9 @@ def loss_fn(
mean_kl=mean_kl,
mean_entropy=mean_entropy,
probs_corr=probs_corr,
frac_old_logprobs_valid=frac_old_logprobs_valid,
mean_importance_ratio=mean_importance_ratio,
clip_fraction=clip_fraction,
)


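The two new diagnostics are plain masked statistics over assistant tokens. A standalone sketch (hypothetical helper; the epsilon defaults are assumptions, the real values come from the trainer config):

```python
import torch

def masked_clip_stats(
    prob_ratio: torch.Tensor,
    assistant_mask: torch.Tensor,
    epsilon: float = 0.2,
    epsilon_high: float = 0.2,
) -> tuple[torch.Tensor, torch.Tensor]:
    # clip_fraction: share of assistant tokens whose importance ratio fell
    # outside [1 - epsilon, 1 + epsilon_high] and was therefore clipped.
    # mean_importance_ratio: masked mean of the new/old probability ratio.
    is_clipped = (prob_ratio < 1 - epsilon) | (prob_ratio > 1 + epsilon_high)
    denom = assistant_mask.sum() + 1e-6
    clip_fraction = (is_clipped.float() * assistant_mask).sum() / denom
    mean_importance_ratio = (prob_ratio * assistant_mask).sum() / denom
    return clip_fraction, mean_importance_ratio
```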
26 changes: 24 additions & 2 deletions src/art/model.py
@@ -1,8 +1,17 @@
from typing import TYPE_CHECKING, Generic, Iterable, Optional, TypeVar, cast, overload
from typing import (
TYPE_CHECKING,
Any,
Generic,
Iterable,
Optional,
TypeVar,
cast,
overload,
)

import httpx
from openai import AsyncOpenAI, DefaultAsyncHttpxClient
from pydantic import BaseModel
from pydantic import BaseModel, model_validator
from typing_extensions import Never

from . import dev
@@ -279,6 +288,19 @@ def __init__(
# Bypass BaseModel __setattr__ to allow setting private attr
object.__setattr__(self, "_internal_config", _internal_config)

@model_validator(mode="wrap")
@classmethod
def _preserve_internal_config(
cls, data: Any, handler: Any
) -> "TrainableModel[ModelConfig]":
internal_config = None
if isinstance(data, dict) and "_internal_config" in data:
internal_config = data.pop("_internal_config")
model = handler(data)
if internal_config is not None:
object.__setattr__(model, "_internal_config", internal_config)
return model

@overload
def __new__(
cls,
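The wrap validator is there so that validating a raw dict that carries _internal_config keeps the private attribute rather than silently dropping it. A minimal sketch of the same pattern on a toy model (not the real TrainableModel):

```python
from typing import Any
from pydantic import BaseModel, PrivateAttr, model_validator

class Toy(BaseModel):
    name: str
    _internal_config: Any = PrivateAttr(default=None)

    @model_validator(mode="wrap")
    @classmethod
    def _preserve_internal_config(cls, data: Any, handler: Any) -> "Toy":
        # Pop the private attr out of the raw input, run normal validation,
        # then re-attach it so it survives the round-trip.
        internal_config = None
        if isinstance(data, dict) and "_internal_config" in data:
            internal_config = data.pop("_internal_config")
        model = handler(data)
        if internal_config is not None:
            object.__setattr__(model, "_internal_config", internal_config)
        return model

toy = Toy.model_validate(
    {"name": "x", "_internal_config": {"engine_args": {"max_model_len": 4096}}}
)
assert toy._internal_config == {"engine_args": {"max_model_len": 4096}}
```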
95 changes: 71 additions & 24 deletions src/art/preprocessing/tokenize.py
@@ -159,23 +159,28 @@ def tokenize_trajectory(
if history.tools is not None
else None
)
chat = cast(
str,
tokenizer.apply_chat_template(
cast(list[dict], messages),
tools=tools, # type: ignore
continue_final_message=True,
tokenize=False,
),
)
original_token_ids = cast(
list[int],
tokenizer.apply_chat_template(
cast(list[dict], messages),
tools=tools, # type: ignore
continue_final_message=True,
),
)
try:
chat = cast(
str,
tokenizer.apply_chat_template(
cast(list[dict], messages),
tools=tools, # type: ignore
continue_final_message=True,
tokenize=False,
),
)
original_token_ids = cast(
list[int],
tokenizer.apply_chat_template(
cast(list[dict], messages),
tools=tools, # type: ignore
continue_final_message=True,
),
)
except ValueError as e:
if "continue_final_message" in str(e):
return None
raise
sentinal_token_id = max(
set(range(cast(int, tokenizer.vocab_size))) - set(original_token_ids)
)
@@ -216,13 +221,55 @@ def tokenize_trajectory(
if isinstance(message, dict):
content = message.get("content")
assert isinstance(content, str)
content_token_ids = tokenizer.encode(
content,
add_special_tokens=False,
)
token_ids[start:end] = content_token_ids
logprobs[start:end] = [float("nan")] * len(content_token_ids)
assistant_mask[start:end] = [1] * len(content_token_ids)
msg_token_ids = message.get("token_ids")
dict_logprobs = message.get("logprobs")
print(f"[TOKENIZE DEBUG] Processing assistant dict message:")
print(f" msg_token_ids is not None: {msg_token_ids is not None}")
print(f" dict_logprobs truthy: {bool(dict_logprobs)}")
if dict_logprobs:
print(f" dict_logprobs type: {type(dict_logprobs).__name__}")
print(f" dict_logprobs keys: {list(dict_logprobs.keys()) if isinstance(dict_logprobs, dict) else 'N/A'}")
print(f" 'values' in dict_logprobs: {'values' in dict_logprobs if isinstance(dict_logprobs, dict) else 'N/A'}")
if (
msg_token_ids is not None
and dict_logprobs
and "values" in dict_logprobs
):
print(f" -> Using provided token_ids ({len(msg_token_ids)}) and logprobs.values ({len(dict_logprobs['values'])})")
token_ids[start:end] = msg_token_ids
logprobs[start:end] = dict_logprobs["values"]
assistant_mask[start:end] = [1] * len(msg_token_ids)
elif (
dict_logprobs
and "content" in dict_logprobs
and dict_logprobs["content"]
):
token_logprobs = dict_logprobs["content"]
try:
token_ids[start:end] = [
int(lp["token"].split(":")[1]) for lp in token_logprobs
]
except (IndexError, ValueError, KeyError):
token_ids[start:end] = [
token_id if token_id is not None else tokenizer.eos_token_id
for token_id in tokenizer.convert_tokens_to_ids(
[
lp.get("token") or tokenizer.eos_token
for lp in token_logprobs
]
)
]
logprobs[start:end] = [lp["logprob"] for lp in token_logprobs]
assistant_mask[start:end] = [1] * len(token_logprobs)
else:
print(f" -> FALLBACK: re-tokenizing content, logprobs will be NaN")
content_token_ids = tokenizer.encode(
content,
add_special_tokens=False,
)
token_ids[start:end] = content_token_ids
logprobs[start:end] = [float("nan")] * len(content_token_ids)
assistant_mask[start:end] = [1] * len(content_token_ids)
else:
choice = message
assert choice.logprobs or allow_training_without_logprobs, (
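The first fallback branch recovers token ids directly from the logprob entries. A small illustration with a hypothetical helper; the "token_id:&lt;id&gt;" string form is an assumption about the vLLM-style encoding that the split(":") above targets:

```python
def parse_token_id(token: str, fallback_id: int) -> int:
    # Entries like "token_id:151645" carry the id directly; anything else
    # (a plain token string) falls back to the caller-supplied id.
    try:
        return int(token.split(":")[1])
    except (IndexError, ValueError):
        return fallback_id

assert parse_token_id("token_id:151645", fallback_id=0) == 151645
assert parse_token_id("Hello", fallback_id=0) == 0
```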
4 changes: 3 additions & 1 deletion src/art/rewards/ruler.py
@@ -20,6 +20,7 @@
from rich import print

import art
from art.utils.strip_logprobs import strip_logprobs


class TrajectoryScore(BaseModel):
@@ -287,9 +288,10 @@ async def ruler_score_group(
new_trajectories.append(new_traj)

# Extract message lists and preserve original rewards for comparison
# Strip logprobs to avoid sending huge token probability data to the judge
message_lists: list[list[ChatCompletionMessageParam]] = []
for traj in new_trajectories:
message_lists.append(traj.messages())
message_lists.append(strip_logprobs(traj.messages()))
traj.metrics["independent_reward"] = traj.reward

try:
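strip_logprobs is an existing helper in art.utils; a rough stand-in sketch of its assumed effect (per-token logprob payloads dropped before the messages go to the judge):

```python
from typing import Any

def strip_logprobs_sketch(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
    # Assumed behaviour: drop the (potentially huge) logprob payloads while
    # leaving the rest of each message untouched.
    return [{k: v for k, v in m.items() if k != "logprobs"} for m in messages]

judge_messages = strip_logprobs_sketch(
    [{"role": "assistant", "content": "4", "logprobs": {"values": [-0.01]}}]
)
assert judge_messages == [{"role": "assistant", "content": "4"}]
```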
5 changes: 4 additions & 1 deletion src/art/skypilot/backend.py
@@ -104,7 +104,10 @@ async def initialize_cluster(
)
print("Art server task already running, using it…")
else:
art_server_task = sky.Task(name="art_server", run="uv run art")
art_server_task = sky.Task(
name="art_server",
run="source $HOME/.local/bin/env && uv sync --extra backend && uv run art",
)

clusters = await to_thread_typed(
lambda: sky.stream_and_get(
4 changes: 2 additions & 2 deletions src/art/skypilot/utils.py
@@ -52,7 +52,7 @@ async def wait_for_task_to_start(cluster_name: str, task_name: str) -> None:
task_status = await get_task_status(cluster_name, task_name)

num_checks = 0
while num_checks < 12:
while num_checks < 120:
task_status = await get_task_status(cluster_name, task_name)
if task_status is None:
raise ValueError(f"Task {task_name} not found in cluster {cluster_name}")
@@ -62,7 +62,7 @@ async def wait_for_task_to_start(cluster_name: str, task_name: str) -> None:
num_checks += 1

raise ValueError(
f"Task {task_name} in cluster {cluster_name} failed to start within 60s"
f"Task {task_name} in cluster {cluster_name} failed to start within 600s"
)


5 changes: 4 additions & 1 deletion src/art/unsloth/train.py
@@ -167,7 +167,10 @@ def compute_loss(
trainer._metrics["train"]["learning_rate"].append(config.learning_rate)
trainer._metrics["train"]["policy_loss"].append(loss.mean_policy_loss.item())
if loss.mean_entropy is not None:
trainer._metrics["train"]["entropy"].append(loss.mean_entropy.item()) # type: ignore
trainer._metrics["train"]["entropy"].append(loss.mean_entropy.item())
trainer._metrics["train"]["frac_old_logprobs_valid"].append(loss.frac_old_logprobs_valid)
trainer._metrics["train"]["mean_importance_ratio"].append(loss.mean_importance_ratio.item())
trainer._metrics["train"]["clip_fraction"].append(loss.clip_fraction.item())
if config.beta > 0.0:
trainer._metrics["train"]["kl_div"].append(loss.mean_kl.item())
return loss.mean_policy_loss + config.beta * loss.mean_kl # type: ignore
4 changes: 4 additions & 0 deletions uv.lock

(Diff contents for the generated uv.lock file are not rendered.)