diff --git a/.gitignore b/.gitignore
index e5e071d..8d07fdc 100644
--- a/.gitignore
+++ b/.gitignore
@@ -121,8 +121,10 @@ tests/e2e/toy_examples/deepspeed/synchronous/output.txt
*.lock
# data
+data/
*.parquet
agentfly/agents/data/*
+test_cache/
# local logs
logs
@@ -133,6 +135,10 @@ data/
test_cache/
/*.jpg
/*.png
+slurm/
+*.err
+*.out
+*.log
# Notebooks
agentfly/tests/*.ipynb
@@ -146,3 +152,7 @@ test_outputs/
agentfly/data/
*.ipynb
+# training scripts
+training_scripts/
+verl/training_scripts/
+
diff --git a/agentfly/agents/agent_base.py b/agentfly/agents/agent_base.py
index 2810673..9df34c1 100644
--- a/agentfly/agents/agent_base.py
+++ b/agentfly/agents/agent_base.py
@@ -1,9 +1,10 @@
from abc import ABC, abstractmethod
from collections import defaultdict
+from datetime import datetime
import json
from .utils.messages import MessagesList
from ..templates.templates import get_template
-from ..__init__ import AGENT_DATA_DIR
+from .. import AGENT_DATA_DIR
from .llm_backends import (
AsyncVLLMBackend,
AsyncVerlBackend,
@@ -23,6 +24,7 @@
import logging
from .chain.streaming_observer import ConsoleStreamObserver, StreamingManager
from .utils.tokenizer import create_processor, create_tokenizer
+from ..utils.monitor import JsonlSink, Monitor, WandbSink
try:
from verl.protocol import DataProto
except ImportError:
@@ -51,10 +53,12 @@ def __init__(
backend_config: Any = None,
reward_fn: Callable = None,
log_file: str = "agent",
- project_name: str = None,
- run_name: str = None,
streaming: str = "console",
debug: bool = False,
+ monitors: List[str] = [],
+ wandb_project_name: str = None,
+ wandb_run_name: str = None,
+ local_cache_dir: str = None,
**kwargs # To pass other unused arguments
):
"""
@@ -94,7 +98,6 @@ def __init__(
# Create appropriate tokenizer for trajectory processing
self.tokenizer = create_tokenizer(model_name_or_path)
-
self.processor = create_processor(model_name_or_path)
self._reward_fn = reward_fn
@@ -104,8 +107,12 @@ def __init__(
else:
self.jinja_template = get_template(self.template).jinja_template()
- self.project_name = project_name
- self.run_name = run_name
+ self.wandb_project_name = wandb_project_name
+ self.wandb_run_name = wandb_run_name
+ self.local_cache_dir = local_cache_dir
+ self.local_run_cache_dir = None
+ self._initialize_monitor(monitors)
+
self.streaming_manager = StreamingManager()
if streaming == "console":
self.streaming_manager.add_observer(ConsoleStreamObserver())
@@ -177,6 +184,17 @@ def _preprocess_messages(self, messages: List[Dict]):
return messages_list.to_list()
+ def _initialize_monitor(self, monitors: List[str]) -> None:
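+        """Attach the requested monitoring sinks: "local" writes JSONL files under local_cache_dir, "wandb" logs to Weights & Biases."""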
+ for monitor in monitors:
+ if monitor == "local":
+ assert self.local_cache_dir is not None, "local_cache_dir must be set when using local monitor."
+                self.local_run_cache_dir = os.path.join(self.local_cache_dir, os.path.basename(self.model_name_or_path), datetime.now().strftime('%Y%m%d_%H%M%S'))
+ Monitor.add_sink("jsonl", JsonlSink(f"{self.local_run_cache_dir}/"))
+ elif monitor == "wandb":
+ Monitor.add_sink("wandb", WandbSink(project=self.wandb_project_name, run_name=self.wandb_run_name))
+ else:
+ raise ValueError(f"Monitor {monitor} is not supported.")
+
async def run(self,
messages: Union[List[dict], np.ndarray, Dict],
max_turns: int,
@@ -392,4 +410,4 @@ def get_verl_data_proto(self):
batch = DataProto.from_single_dict(inputs, meta_info={"use_agent": True})
return batch
-
\ No newline at end of file
+
diff --git a/agentfly/agents/chain/chain_base.py b/agentfly/agents/chain/chain_base.py
index c610c23..8ac5053 100644
--- a/agentfly/agents/chain/chain_base.py
+++ b/agentfly/agents/chain/chain_base.py
@@ -137,7 +137,6 @@ def __init__(self):
self.terminal_status = ["terminal", "finish"]
self.global_step = 0
self.finished_chains_count = 0
- self.initialize_monitor()
self.monitor_info = defaultdict(list)
def reset(self) -> None:
@@ -333,7 +332,7 @@ async def _run_single_chain(self,
await done_queue.put((chain_id, chain, current_node))
self.finished_chains_count += 1
- self.monitor_chain()
+ self.monitor_chain(trajectory=current_node.messages.messages)
async def _generate_response(self, current_node, tools, depth, chain_id, enable_streaming):
"""Generate response with optional streaming support."""
@@ -485,7 +484,6 @@ async def _finalize_chain(self, chain_id, chain, current_node, depth):
await self.release_resources(chain_id)
-
async def release_resources(self, id: str) -> None:
for tool in self.tools:
if isinstance(tool, Tool):
@@ -498,10 +496,6 @@ async def set_tools(self, id: str, env_args: Dict[str, Any]) -> None:
if isinstance(tool, Tool):
await tool.set_env(id, env_args)
- def initialize_monitor(self) -> None:
- Monitor.add_sink("jsonl", JsonlSink(f"{AGENT_DATA_DIR}/demo_metrics.jsonl"))
- Monitor.add_sink("wandb", WandbSink(project=self.project_name, run_name=self.run_name))
-
def monitor_step(self) -> None:
messages = self.get_messages()
avg_turns = 0
@@ -589,9 +583,19 @@ def monitor_step(self) -> None:
emit(evt)
- def monitor_chain(self) -> None:
+ def monitor_chain(self, trajectory) -> None:
self.monitor_info['Agent/chains'].append(self.finished_chains_count)
for tool in self.tools:
if tool.is_stateful and tool.pool_size > 0:
self.monitor_info[f"Agent/Tool/{tool.name}/used_env_size"].append(tool.used_env_size)
+        # Only log the trajectory to the local JSONL file; sending it to wandb would use too much bandwidth
+ evt = MetricEvent(
+ sinks=["jsonl"],
+ kind="text",
+ name="Agent/rollout/trajectory",
+ value=json.dumps(serialize_for_json(trajectory), indent=2),
+ x=self.global_step,
+ x_name="Agent/rollout/step"
+ )
+ emit(evt)
diff --git a/agentfly/agents/chain/chain_base_simplified.py b/agentfly/agents/chain/chain_base_simplified.py
new file mode 100644
index 0000000..533fb64
--- /dev/null
+++ b/agentfly/agents/chain/chain_base_simplified.py
@@ -0,0 +1,635 @@
+import asyncio
+from collections import defaultdict
+from dataclasses import dataclass, field
+import json
+import time
+from ..utils.messages import MessagesList, Messages
+from ...utils.timing import Timer
+from typing import Any, Dict, List, Optional, Tuple, Union, Callable
+import uuid
+from termcolor import colored
+import numpy as np
+from copy import deepcopy
+from ...tools.tool_base import Tool, submit_tool_call, submit_tool_calls
+from tqdm.asyncio import tqdm_asyncio
+from ...utils.monitor import JsonlSink, MetricEvent, Monitor, WandbSink, emit, serialize_for_json
+from ... import AGENT_DATA_DIR
+import wandb
+from .streaming_observer import ConsoleStreamObserver, StreamingManager, StreamEvent, StreamEventType
+
+@dataclass
+class SimplifiedNode:
+ """
+ Simplified node that aligns with message types.
+ Each node represents a single message turn and only stores its own content.
+ """
+ role: str # "user", "assistant", "tool", "system"
+ content: Any # The actual message content
+ is_terminal: bool = False
+ is_pruned: bool = False
+ description: str = ""
+ observation: str = "" # For tool nodes, stores the tool result
+ tool_name: Optional[str] = None # For tool nodes
+ tool_call_id: Optional[str] = None # For tool nodes
+ parent: Optional["SimplifiedNode"] = None
+ children: List["SimplifiedNode"] = field(default_factory=list)
+
+ @property
+ def depth(self) -> int:
+ return 0 if self.parent is None else self.parent.depth + 1
+
+ def print_node(self, process_id: int = 0) -> None:
+ if process_id != 0:
+ return
+ color_converter = {
+ "user": "green",
+ "assistant": "blue",
+ "tool": "yellow",
+ "system": "magenta"
+ }
+ color = color_converter.get(self.role, "white")
+ print(colored(f"{self.role.upper()}: {self.description}", color=color))
+ if self.observation:
+ obs = (
+ self.observation
+ if len(self.observation) < 1536
+ else f"{self.observation[:1536]}...(len={len(self.observation)})"
+ )
+ print(colored(f"Observation: {obs}", color="yellow"))
+
+ def to_json(self) -> dict:
+ json_obj = {
+ "role": self.role,
+ "is_terminal": self.is_terminal,
+ "is_pruned": self.is_pruned,
+ "depth": self.depth,
+ "description": self.description,
+ "content": self.content
+ }
+ if self.observation:
+ json_obj["observation"] = self.observation
+ if self.tool_name:
+ json_obj["tool_name"] = self.tool_name
+ if self.tool_call_id:
+ json_obj["tool_call_id"] = self.tool_call_id
+ return json_obj
+
+ def to_json_recursive(self) -> dict:
+ data = self.to_json()
+ data["children"] = [child.to_json_recursive() for child in self.children]
+ return data
+
+
+class SimplifiedChain:
+ """
+ Simplified chain that stores nodes aligned with message types.
+ Each node represents a single message turn.
+ """
+ def __init__(self, info):
+ self.root: Optional[SimplifiedNode] = None
+ self.info: Dict[str, Any] = info
+ self.system_message: Optional[SimplifiedNode] = None
+
+ def add_node(
+ self,
+ role: str,
+ content: Any,
+ is_terminal: bool = False,
+ is_pruned: bool = False,
+ description: str = "",
+ observation: str = "",
+ tool_name: Optional[str] = None,
+ tool_call_id: Optional[str] = None
+ ) -> SimplifiedNode:
+ new_node = SimplifiedNode(
+ role=role,
+ content=content,
+ is_terminal=is_terminal,
+ is_pruned=is_pruned,
+ description=description,
+ observation=observation,
+ tool_name=tool_name,
+ tool_call_id=tool_call_id
+ )
+
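+        # The chain is linear (each node keeps at most one child), so walk to the tail and append there.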
+ if self.root is None:
+ self.root = new_node
+ else:
+ current = self.root
+ while len(current.children) > 0:
+ current = current.children[0]
+ current.children = [new_node]
+ new_node.parent = current
+ return new_node
+
+ def get_full_messages(self) -> List[Dict[str, Any]]:
+ """
+ Reconstruct the full message history from the chain nodes.
+ """
+ messages = []
+ node = self.root
+
+ # Add system message if it exists
+ if self.system_message:
+ messages.append({
+ "role": "system",
+ "content": self.system_message.content
+ })
+
+ # Traverse the chain and collect messages
+ while node:
+ if node.role == "tool":
+ # For tool messages, we need to reconstruct the tool message format
+ messages.append({
+ "role": "tool",
+ "tool_call_id": node.tool_call_id,
+ "tool_name": node.tool_name,
+ "content": [{"type": "text", "text": node.observation}]
+ })
+ else:
+ # For user and assistant messages
+ messages.append({
+ "role": node.role,
+ "content": node.content
+ })
+
+ if node.children:
+ node = node.children[0]
+ else:
+ break
+
+ return messages
+
+ def to_json(self) -> List[dict]:
+ chain_json = []
+ node = self.root
+ while node:
+ chain_json.append(node.to_json())
+ if node.children:
+ node = node.children[0]
+ else:
+ break
+ return chain_json
+
+
+class SimplifiedChainRollout:
+ """
+ Simplified chain-based rollout that uses message-aligned nodes.
+ """
+ def __init__(self):
+ self.reset()
+ self.chains: Dict[str, SimplifiedChain] = {}
+ self.current_nodes: Dict[str, SimplifiedNode] = {}
+ self.timer = Timer()
+ self.terminal_status = ["terminal", "finish"]
+ self.global_step = 0
+ self.finished_chains_count = 0
+ self.monitor_info = defaultdict(list)
+
+ def reset(self) -> None:
+ self.status_code: str = "continue"
+ self.query_count: int = 0
+ self.total_tokens: int = 0
+ self.success_count: int = 0
+        self.chains = {}
+ self.current_nodes = {}
+
+ @property
+ def timing_data(self):
+ return self.timer.timing_data
+
+ def to_json(self) -> dict:
+ return {
+ "finish": [chain.status_code == "success" for chain in self.chains],
+ "chains": [chain.to_json() for chain in self.chains]
+ }
+
+ def initialize_chains(self, messages_list: MessagesList, num_chains: int) -> Tuple[Dict[str, SimplifiedChain], Dict[str, SimplifiedNode]]:
+ chains = {}
+ start_nodes = {}
+ group_ids = [str(uuid.uuid4()) for _ in range(len(messages_list))]
+
+ for group_idx, messages in enumerate(messages_list):
+ group_id = group_ids[group_idx]
+ for j in range(num_chains):
+ ch = SimplifiedChain(messages.meta | {"group_id": group_id})
+
+ # Extract system message if present
+ if messages.messages and messages.messages[0]["role"] == "system":
+ system_content = messages.messages[0]["content"]
+ ch.system_message = SimplifiedNode(
+ role="system",
+ content=system_content,
+ description="System prompt"
+ )
+ user_messages = messages.messages[1:]
+ else:
+ user_messages = messages.messages
+
+ # Create user node with the initial user message(s)
+ user_content = user_messages[0]["content"] if user_messages else ""
+ root = ch.add_node(
+ role="user",
+ content=user_content,
+ description="Initial user input"
+ )
+
+ cid = str(uuid.uuid4())
+ chains[cid] = ch
+ start_nodes[cid] = root
+
+ return chains, start_nodes
+
+ def get_messages(self) -> List[Any]:
+ messages = []
+ for id, node in self.current_nodes.items():
+ info = self.chains[id].info
+ message_item = {}
+ message_item["messages"] = self.chains[id].get_full_messages()
+ message_item.update(info)
+ messages.append(message_item)
+ return messages
+
+ def validate_run_args(self, max_turns: int, num_chains: int, enable_streaming: bool):
+ assert max_turns >= 1, "max_turns must be at least 1."
+ assert num_chains >= 1, "num_chains must be at least 1."
+ for observer in self.streaming_manager.observers:
+ if isinstance(observer, ConsoleStreamObserver) and enable_streaming:
+ assert num_chains == 1, "num_chains must be 1 when ConsoleStreamObserver is used."
+
+
+ async def run_async(self,
+ messages: List[Dict],
+ max_turns: int,
+ num_chains: int,
+ generation_config: Optional[Dict[str, Any]] = None,
+ enable_streaming: bool = False,
+ streaming_callback: Optional[Callable] = None,
+ ):
+ """
+ Run the simplified chain-based rollout.
+ """
+ self.validate_run_args(max_turns, num_chains, enable_streaming)
+ Monitor.ensure_started()
+ self.reset()
+
+ messages_list = MessagesList.from_data(messages)
+ chains, first_nodes = self.initialize_chains(
+ messages_list,
+ num_chains
+ )
+ tool_schemas = [tool.schema for tool in self.tools]
+
+ done_q = asyncio.Queue()
+ tasks = [
+ asyncio.create_task(
+ self._run_single_chain(
+ cid,
+ node,
+ chains[cid],
+ tool_schemas,
+ max_turns=max_turns,
+ done_queue=done_q,
+ enable_streaming=enable_streaming
+ )
+ )
+ for cid, node in first_nodes.items()
+ ]
+
+ await tqdm_asyncio.gather(*tasks)
+
+ self.chains = {}
+ while not done_q.empty():
+ cid, chain, node = done_q.get_nowait()
+ self.chains[cid] = chain
+ self.current_nodes[cid] = node
+
+ self.global_step += 1
+ self.monitor_step()
+
+ async def _run_single_chain(self,
+ chain_id: str,
+ first_node: SimplifiedNode,
+ chain: SimplifiedChain,
+ tools: List[Dict],
+ max_turns: int,
+ done_queue: asyncio.Queue,
+ enable_streaming: bool = False
+ ):
+ """
+ Run a single simplified chain.
+ """
+ current_node = first_node
+ depth = 0
+ have_set_tools = False
+
+ while not current_node.is_terminal and depth < max_turns:
+ # Get current message history for generation
+ current_messages = chain.get_full_messages()
+
+ if not current_node.is_terminal:
+ # Generate assistant response
+ assistant_msg = await self._generate_response(
+ current_messages, tools, depth, chain_id, enable_streaming
+ )
+
+ # Create assistant node
+ assistant_node = chain.add_node(
+ role="assistant",
+ content=assistant_msg.get("content", ""),
+ description=assistant_msg.get("content", "")[:100] + "..." if len(assistant_msg.get("content", "")) > 100 else assistant_msg.get("content", ""),
+ is_terminal=assistant_msg.get("status", "continue") in self.terminal_status
+ )
+ current_node = assistant_node
+
+ # Check if the assistant node is terminal
+ if current_node.is_terminal:
+ break
+
+ # Handle tool calls
+ if assistant_msg.get("tool_calls"):
+ for tool_call in assistant_msg["tool_calls"]:
+ result = await self._execute_tool_call(
+ tool_call, chain, chain_id, depth,
+ have_set_tools, enable_streaming
+ )
+ have_set_tools = True
+
+ # Create tool node
+ tool_node = chain.add_node(
+ role="tool",
+ content=result.get("arguments", ""),
+ description=f"Tool: {result.get('name', 'unknown')}",
+ observation=result["observation"],
+ tool_name=result.get("name"),
+ tool_call_id=tool_call["id"],
+ is_terminal=result["status"] in self.terminal_status
+ )
+ current_node = tool_node
+ else:
+ # No tool calls, chain is finished
+ break
+
+ depth += 1
+
+ # Finalize chain
+ await self._finalize_chain(chain_id, chain, current_node, depth)
+ await done_queue.put((chain_id, chain, current_node))
+
+ self.finished_chains_count += 1
+ self.monitor_chain(trajectory=chain.get_full_messages())
+
+ async def _generate_response(self, current_messages, tools, depth, chain_id, enable_streaming):
+ """Generate response with optional streaming support."""
+ if enable_streaming:
+ # Emit generation start event
+ await self.streaming_manager.emit_event(StreamEvent(
+ event_type=StreamEventType.LLM_GENERATION_START,
+ chain_id=chain_id,
+ timestamp=time.time(),
+ data={"depth": depth},
+ step=depth,
+ depth=depth
+ ))
+
+ # Check if we have streaming capabilities
+ has_streaming = False
+ if hasattr(self, 'generate_streaming'):
+ has_streaming = True
+ elif hasattr(self, 'llm_engine') and hasattr(self.llm_engine, 'generate_streaming'):
+ has_streaming = True
+ # Create a wrapper to use the LLM engine's streaming
+ async def generate_streaming_wrapper(messages_list, **kwargs):
+ async for chunk in self.llm_engine.generate_streaming(messages_list, **kwargs):
+ yield chunk
+ self.generate_streaming = generate_streaming_wrapper
+
+ if has_streaming:
+ # Collect full response from streaming
+ full_response = ""
+ async for chunk in self.generate_streaming([current_messages], tools=tools):
+ await self.streaming_manager.emit_event(StreamEvent(
+ event_type=StreamEventType.LLM_GENERATION_CHUNK,
+ chain_id=chain_id,
+ timestamp=time.time(),
+ data={"content": chunk},
+ step=depth,
+ depth=depth
+ ))
+ full_response = chunk
+
+ # Emit generation end event
+ await self.streaming_manager.emit_event(StreamEvent(
+ event_type=StreamEventType.LLM_GENERATION_END,
+ chain_id=chain_id,
+ timestamp=time.time(),
+ data={"full_response": full_response},
+ step=depth,
+ depth=depth
+ ))
+
+ # Parse response
+ new_msg = self.parse([full_response], tools=self.tools)
+ return new_msg[0]
+ else:
+ # Fallback to non-streaming generation
+ responses = await self.generate_async([current_messages], tools=tools, num_return_sequences=1)
+ new_msg = self.parse(responses, tools=self.tools)
+
+ # Emit a single chunk event for the full response
+ full_response = new_msg[0].get("content", "")
+ if isinstance(full_response, list) and len(full_response) > 0:
+ if isinstance(full_response[0], dict) and "text" in full_response[0]:
+ full_response = full_response[0]["text"]
+ else:
+ full_response = str(full_response)
+ elif not isinstance(full_response, str):
+ full_response = str(full_response)
+
+ await self.streaming_manager.emit_event(StreamEvent(
+ event_type=StreamEventType.LLM_GENERATION_CHUNK,
+ chain_id=chain_id,
+ timestamp=time.time(),
+ data={"content": full_response},
+ step=depth,
+ depth=depth
+ ))
+
+ # Emit generation end event
+ await self.streaming_manager.emit_event(StreamEvent(
+ event_type=StreamEventType.LLM_GENERATION_END,
+ chain_id=chain_id,
+ timestamp=time.time(),
+ data={"full_response": full_response},
+ step=depth,
+ depth=depth
+ ))
+
+ return new_msg[0]
+ else:
+ # Non-streaming generation
+ responses = await self.generate_async([current_messages], tools=tools, num_return_sequences=1)
+ new_msg = self.parse(responses, tools=self.tools)
+ return new_msg[0]
+
+ async def _execute_tool_call(self, tool_call, chain, chain_id, depth, have_set_tools, enable_streaming):
+ """Execute a tool call with optional streaming support."""
+ tool_name = tool_call["function"]["name"]
+ tool_input = tool_call["function"]["arguments"]
+
+ # Set up tools if needed
+ if not have_set_tools:
+ await self.set_tools(chain_id, chain.info)
+ have_set_tools = True
+
+ # Execute tool call
+ result = await submit_tool_call(
+ tool_name,
+ tool_input,
+ id=chain_id,
+ allowed_tool_names=self.tool_names
+ )
+
+ if enable_streaming:
+ # Emit tool observation event
+ await self.streaming_manager.emit_event(StreamEvent(
+ event_type=StreamEventType.TOOL_OBSERVATION,
+ chain_id=chain_id,
+ timestamp=time.time(),
+ data={
+ "tool_name": tool_name,
+ "observation": result["observation"],
+ "status": result["status"]
+ },
+ step=depth,
+ depth=depth
+ ))
+
+ return result
+
+
+ async def _finalize_chain(self, chain_id, chain, current_node, depth):
+ """Finalize the chain with reward calculation and cleanup."""
+ if self._reward_fn is not None:
+ trajectory = chain.get_full_messages()
+ final_response = self.extract_final_response(trajectory)
+ other_args = {k: v for k, v in chain.info.items() if k not in ['prediction', 'trajectory', 'id']}
+ chain.info["reward"] = await self._reward_fn(prediction=final_response, **other_args, trajectory=trajectory, id=chain_id)
+ else:
+ chain.info["reward"] = None
+
+ await self.release_resources(chain_id)
+
+ async def release_resources(self, id: str) -> None:
+ for tool in self.tools:
+ if isinstance(tool, Tool):
+ await tool.release(id=id)
+ if self._reward_fn is not None:
+ await self._reward_fn.release(id=id)
+
+ async def set_tools(self, id: str, env_args: Dict[str, Any]) -> None:
+ for tool in self.tools:
+ if isinstance(tool, Tool):
+ await tool.set_env(id, env_args)
+
+ def monitor_step(self) -> None:
+ messages = self.get_messages()
+ avg_turns = 0
+ avg_tool_calls = 0
+ avg_response_length = 0
+ tool_calls_by_name = defaultdict(int)
+
+ for message in messages:
+ for msg in message['messages']:
+ if msg['role'] == 'assistant':
+ avg_turns += 1
+ if msg['role'] == 'tool':
+ avg_tool_calls += 1
+ tool_call_name = msg['tool_name']
+ tool_calls_by_name[tool_call_name] += 1
+
+ avg_turns /= len(messages)
+ avg_tool_calls /= len(messages)
+
+ ent = MetricEvent(
+ kind="scalar",
+ name=f"Agent/rollout/step",
+ value=self.global_step,
+ x=self.global_step,
+ x_name="Agent/rollout/step"
+ )
+ emit(ent)
+
+ evt = MetricEvent(
+ kind="scalar",
+ name=f"Agent/rollout/avg_turns",
+ value=avg_turns,
+ x=self.global_step,
+ x_name="Agent/rollout/step"
+ )
+ emit(evt)
+
+ evt = MetricEvent(
+ kind="scalar",
+ name=f"Agent/rollout/avg_tool_calls",
+ value=avg_tool_calls,
+ x=self.global_step,
+ x_name="Agent/rollout/step"
+ )
+ emit(evt)
+
+
+ for tool_name, tool_call_count in tool_calls_by_name.items():
+ evt = MetricEvent(
+ kind="scalar",
+ name=f"Agent/rollout/tool_calls/{tool_name}",
+ value=tool_call_count,
+ x=self.global_step,
+ x_name="Agent/rollout/step"
+ )
+ emit(evt)
+
+ evt = MetricEvent(
+ kind="scalar",
+ name=f"Agent/rollout/step",
+ value=self.global_step,
+ x=self.global_step,
+ x_name="Agent/rollout/step"
+ )
+ emit(evt)
+
+ sample_message_json = json.dumps(serialize_for_json(messages[0]), indent=2)
+ evt = MetricEvent(
+ kind="text",
+ name="Agent/rollout/sample_message",
+ value=sample_message_json,
+ x=self.global_step,
+ x_name="Agent/rollout/step"
+ )
+ emit(evt)
+
+ for k, v in self.monitor_info.items():
+ if k != "Agent/chains": # We don't log number of chains
+ evt = MetricEvent(
+ kind="list",
+ name=k,
+ value=v,
+ x=self.monitor_info['Agent/chains'],
+ )
+ emit(evt)
+
+
+ def monitor_chain(self, trajectory) -> None:
+ self.monitor_info['Agent/chains'].append(self.finished_chains_count)
+ for tool in self.tools:
+ if tool.is_stateful and tool.pool_size > 0:
+ self.monitor_info[f"Agent/Tool/{tool.name}/used_env_size"].append(tool.used_env_size)
+
+ evt = MetricEvent(
+ kind="text",
+ name="Agent/rollout/trajectory",
+ value=json.dumps(serialize_for_json(trajectory), indent=2),
+ x=self.global_step,
+ x_name="Agent/rollout/step"
+ )
+ emit(evt)
diff --git a/agentfly/agents/llm_backends/llm_backends.py b/agentfly/agents/llm_backends/llm_backends.py
index 7ab9a4b..df62335 100644
--- a/agentfly/agents/llm_backends/llm_backends.py
+++ b/agentfly/agents/llm_backends/llm_backends.py
@@ -21,8 +21,7 @@
import logging
import PIL
-
-LOGGER = logging.getLogger(__name__)
+logger = logging.getLogger(__name__)
try:
from verl.protocol import DataProto
@@ -340,13 +339,13 @@ async def generate_async(self, messages_list: str, **kwargs) -> str:
inputs = self._process_inputs(prompts, vision_inputs)
if n > 1:
inputs = [_input for _input in inputs for _ in range(n)]
- LOGGER.debug(f"[AsyncVLLMBackend] inputs: {inputs}")
+ logger.debug(f"[AsyncVLLMBackend] inputs: {inputs}")
tasks = [self._generate_single(_input, sampling_params) for _input in inputs]
outputs = await asyncio.gather(*tasks)
# Flatten the outputs
outputs = [output for output_list in outputs for output in output_list]
response_texts = [output.text for output in outputs]
- LOGGER.debug(f"[AsyncVLLMBackend] response_texts: {response_texts}")
+ logger.debug(f"[AsyncVLLMBackend] response_texts: {response_texts}")
return response_texts
@@ -513,7 +512,7 @@ def __init__(
# --------------------------------------------------------------------- #
# Low‑level single request (runs in threadpool so it doesn't block loop)
# --------------------------------------------------------------------- #
- @retry(stop=stop_after_attempt(1), wait=wait_exponential(multiplier=1, min=4, max=15))
+ @retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=4, max=15))
def _blocking_call(self, messages: List[List[Dict]], **kwargs) -> str:
if "num_return_sequences" in kwargs:
n = kwargs.pop("num_return_sequences")
@@ -525,7 +524,6 @@ def _blocking_call(self, messages: List[List[Dict]], **kwargs) -> str:
else:
tool_choice = "none"
- print(f"[ClientBackend] messages: {messages}")
resp = self.client.chat.completions.create(
model=self.model_name,
messages=messages,
@@ -570,7 +568,7 @@ def _convert_to_openai_chat_without_tool_call_processing(self, messages: list) -
return messages
# Public API ‑‑ sync or async depending on caller's context
- def async_generate(
+ def generate(
self,
messages: List[List[Dict]] | List[Dict],
**kwargs,
@@ -588,10 +586,12 @@ def async_generate(
messages_list = [messages] # single
else:
messages_list = messages # batch
- print(f"[ClientBackend] messages_list: {messages_list}")
+ logger.debug(f"[ClientBackend] messages_list: {messages_list}")
messages_list = [self._convert_to_openai_chat_without_tool_call_processing(messages) for messages in messages_list]
async def _runner():
+ # Ensure refiller is running in this event loop
+ self._ensure_refiller_running()
tasks = [asyncio.create_task(self._call(_input, **kwargs)) for _input in messages_list]
# Flatten the response list
response_texts_list_or_dict = await asyncio.gather(*tasks)
@@ -621,7 +621,7 @@ async def _runner():
async def generate_async(self,
messages: List[List[Dict]] | List[Dict],
**kwargs) -> List[str]:
- return await self.async_generate(messages, **kwargs)
+ return await self.generate(messages, **kwargs)
# Background token‑bucket refill (one token each 60/max_rpm seconds)
async def _refill_tokens(self):
@@ -633,11 +633,11 @@ async def _refill_tokens(self):
def _ensure_refiller_running(self):
if self._refill_task is None or self._refill_task.done():
- loop = asyncio.get_event_loop()
- self._refill_task = loop.create_task(self._refill_tokens())
-
- # Automatically start the refiller at first public use
- def __getattribute__(self, name):
- if name == "generate":
- self._ensure_refiller_running()
- return super().__getattribute__(name)
+ try:
+ # Try to get running loop first
+ loop = asyncio.get_running_loop()
+ self._refill_task = loop.create_task(self._refill_tokens())
+ except RuntimeError:
+ # No event loop running, this will be handled by the caller
+ # The refiller will be started when we're in an event loop
+ pass
diff --git a/agentfly/agents/specialized/hf_agent.py b/agentfly/agents/specialized/hf_agent.py
index df75ba5..eee913f 100644
--- a/agentfly/agents/specialized/hf_agent.py
+++ b/agentfly/agents/specialized/hf_agent.py
@@ -1,6 +1,7 @@
from ast import Dict
import json
+import os
from typing import List
from ..agent_base import BaseAgent
from ..parsers import extract_tool_calls
diff --git a/agentfly/rewards/__init__.py b/agentfly/rewards/__init__.py
index 09ba592..ee2e107 100644
--- a/agentfly/rewards/__init__.py
+++ b/agentfly/rewards/__init__.py
@@ -17,4 +17,6 @@
from .alfworld_reward import alfworld_episode_reward
from .scienceworld_reward import scienceworld_reward
from .gui_reward import gui_reward
+from .vlm_as_judge.vlm_as_judge_reward import vlm_as_judge_reward
+from .vlm_as_judge.vlm_as_judge_reward import vlm_as_judge_pass_reward
diff --git a/agentfly/rewards/llm_as_judge/llm_as_judge_client.py b/agentfly/rewards/llm_as_judge/llm_as_judge_client.py
index a33078e..895cf62 100644
--- a/agentfly/rewards/llm_as_judge/llm_as_judge_client.py
+++ b/agentfly/rewards/llm_as_judge/llm_as_judge_client.py
@@ -415,4 +415,3 @@ async def llm_as_judge_client_math_reward(prediction: str, answer: str) -> float
"reward": 0.0,
}
-
diff --git a/agentfly/tests/unit/agents/templates/__init__.py b/agentfly/rewards/vlm_as_judge/__init__.py
similarity index 100%
rename from agentfly/tests/unit/agents/templates/__init__.py
rename to agentfly/rewards/vlm_as_judge/__init__.py
diff --git a/agentfly/rewards/vlm_as_judge/vlm_as_judge_client.py b/agentfly/rewards/vlm_as_judge/vlm_as_judge_client.py
new file mode 100644
index 0000000..79fd483
--- /dev/null
+++ b/agentfly/rewards/vlm_as_judge/vlm_as_judge_client.py
@@ -0,0 +1,478 @@
+import os
+import re
+import json
+import glob
+import time
+import asyncio
+import logging
+from typing import Any, Dict, List, Optional
+
+from pathlib import Path
+from openai import AsyncOpenAI
+from tqdm.asyncio import tqdm_asyncio
+
+from ... import AGENT_HOME
+
+logger = logging.getLogger(__name__)
+
+
+def _resolve_server_status_dir() -> str:
+ """Resolve the directory that contains vLLM server status JSON files.
+
+ Priority:
+ 1) env `VLLM_SERVER_STATUS_DIR`
+ 2) env `DATA_PROCESS_HOME` + /vllm_server/server_status
+ 3) AGENT_HOME + /data-process/vllm_server/server_status
+ 4) Explicit fallback to /mnt/weka/... path as provided by user
+ """
+ # 1) Explicit override
+ override = os.getenv("VLLM_SERVER_STATUS_DIR")
+ if override and os.path.isdir(override):
+ return override
+
+ # 2) data-process home
+ dp_home = os.getenv("DATA_PROCESS_HOME")
+ if dp_home:
+ candidate = os.path.join(dp_home, "vllm_server", "server_status")
+ if os.path.isdir(candidate):
+ return candidate
+
+ # 3) default relative to AGENT_HOME
+ candidate = os.path.join(AGENT_HOME, "data-process", "vllm_server", "server_status")
+ if os.path.isdir(candidate):
+ return candidate
+
+ # Last resort: return the AGENT_HOME-based path even if missing (caller errors out)
+ return candidate
+
+
+def get_server_ips(model: str) -> List[str]:
+ """Get list of server IPs from the most recent complete server instances file for a specific model."""
+ server_status_dir = _resolve_server_status_dir()
+
+ # Clean model name for filename matching (replace / and - with _)
+ model_clean = model.replace('/', '_').replace('-', '_')
+
+ # Try multiple patterns to match the server files
+ patterns = [
+ f"server_instances_complete_{model_clean}_*.json",
+ f"server_instances_complete_vllm_{model_clean}_*.json",
+ f"server_instances_complete_*{model_clean}*.json"
+ ]
+
+ json_files = []
+ for pattern in patterns:
+ search_pattern = os.path.join(server_status_dir, pattern)
+ found_files = glob.glob(search_pattern)
+ if found_files:
+ json_files = found_files
+ break
+
+ if not json_files:
+ # Fallback: try to find any server instances file and filter by model in the JSON content
+ fallback_pattern = os.path.join(server_status_dir, "server_instances_complete_*.json")
+ all_files = glob.glob(fallback_pattern)
+
+ for file in all_files:
+ try:
+ with open(file, 'r') as f:
+ server_info = json.load(f)
+
+ # Check if any server in this file matches our model
+ matching_servers = [info for info in server_info if info.get('model') == model]
+ if matching_servers:
+ json_files = [file]
+ logger.info(f"Found servers for model '{model}' in fallback file: {file}")
+ break
+ except Exception as e:
+ logger.warning(f"Error reading file {file}: {e}")
+ continue
+
+ if not json_files:
+ raise RuntimeError(
+ f"No server instances file found for model '{model}' in {search_pattern} or any fallback file under {server_status_dir}"
+ )
+
+ # Get the most recent file
+ latest_file = max(json_files, key=os.path.getctime)
+
+ with open(latest_file, 'r') as f:
+ server_info = json.load(f)
+
+ # Filter servers by model and extract IPs
+ ips = []
+ for info in server_info:
+ if info.get('model') == model and 'ip' in info:
+ ips.append(info['ip'])
+
+ if not ips:
+ raise RuntimeError(f"No IPs found for model '{model}' in server instances file {latest_file}")
+
+ logger.info(f"Found {len(ips)} server instances for model '{model}': {ips}")
+ return ips
+
+
+class RateLimiter:
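+    """Per-instance concurrency limiter: a thin wrapper around asyncio.Semaphore, one slot per in-flight request."""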
+ def __init__(self, max_window_size: int):
+ self.max_window_size = max_window_size
+ self.semaphore = asyncio.Semaphore(max_window_size)
+
+ async def acquire(self):
+ await self.semaphore.acquire()
+
+ async def release(self):
+ self.semaphore.release()
+
+
+class RoundRobinClient:
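+    """Pool of AsyncOpenAI clients over several vLLM instances; despite the name, each request goes to the instance with the most free slots."""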
+ def __init__(self, ips: List[str], port: int, api_key: str, timeout: int, rate_limiters: List[RateLimiter]):
+ self.ips = ips
+ self.current_index = 0
+ self.port = port
+ self.api_key = api_key
+ self.clients = [
+ AsyncOpenAI(
+ base_url=f"http://{ip}:{port}/v1",
+ api_key=api_key,
+ timeout=timeout,
+ ) for ip in ips
+ ]
+ self.rate_limiters = rate_limiters
+
+ async def get_next_available_client(self) -> tuple[AsyncOpenAI, RateLimiter]:
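+        """Return (client, limiter) for the least-loaded instance; the limiter is already acquired and must be released by the caller."""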
+ # Find the instance with the most available slots
+ max_available = -1
+ best_client = None
+ best_limiter = None
+ best_index = -1
+
+ for i in range(len(self.clients)):
+ available = self.rate_limiters[i].semaphore._value
+ if available > max_available and available > 0:
+ max_available = available
+ best_client = self.clients[i]
+ best_limiter = self.rate_limiters[i]
+ best_index = i
+
+ if best_client is not None:
+ await best_limiter.acquire()
+ return best_client, best_limiter
+
+ # If no instance has available slots, wait on all and race
+ wait_tasks = [(i, asyncio.create_task(limiter.semaphore.acquire()))
+ for i, limiter in enumerate(self.rate_limiters)]
+ done, pending = await asyncio.wait(
+ [task for _, task in wait_tasks],
+ return_when=asyncio.FIRST_COMPLETED
+ )
+
+ for _, task in wait_tasks:
+ if task not in done:
+ task.cancel()
+
+ for i, task in wait_tasks:
+ if task in done:
+ return self.clients[i], self.rate_limiters[i]
+
+ raise RuntimeError("No instance became available despite wait completion")
+
+
+class VLMClient:
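+    """Async client for the judge VLM: discovers server IPs from status files and load-balances requests across them."""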
+ def __init__(self,
+ model: str,
+ timeout_seconds: int = 60,
+ max_window_size_per_instance: int = 10,
+ port: int = 8000,
+ api_key: str = "token-abc123"):
+ self.timeout_seconds = timeout_seconds
+ server_ips = get_server_ips(model)
+ # server_ips = ["10.24.3.24"]
+ rate_limiters = [RateLimiter(max_window_size_per_instance) for _ in server_ips]
+ self.client_manager = RoundRobinClient(server_ips, port, api_key, timeout_seconds, rate_limiters)
+
+ async def single_call(self, inputs, model, **kwargs):
+ try:
+ client, rate_limiter = await self.client_manager.get_next_available_client()
+ try:
+ response = await client.chat.completions.create(
+ model=model,
+ messages=inputs,
+ timeout=self.timeout_seconds,
+ **kwargs
+ )
+ return response.choices[0].message.content
+ finally:
+ await rate_limiter.release()
+ except asyncio.TimeoutError:
+ logger.error(f"Request timed out after {self.timeout_seconds}s")
+ return None
+ except Exception as e:
+ logger.error(f"Error processing request: {e}")
+ return None
+
+ async def process_all_inputs(self, inputs_list, num_generations=1, model=None, **kwargs):
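+        """Fan out each input num_generations times, await all calls, and regroup the responses per input."""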
+ all_tasks = []
+ for inputs in inputs_list:
+ for _ in range(num_generations):
+ all_tasks.append(self.single_call(inputs, model=model, **kwargs))
+
+ responses = await tqdm_asyncio.gather(*all_tasks, desc="Processing VLM requests")
+
+ grouped_responses = []
+ for i in range(0, len(responses), num_generations):
+ grouped_responses.append(responses[i:i + num_generations])
+
+ return grouped_responses
+
+ def check_availability(self) -> Dict[str, Any]:
+ # Mirror llm client basic stats
+ total_capacity = 0
+ total_available = 0
+ total_used = 0
+ instance_details = []
+
+ for i in range(len(self.client_manager.clients)):
+ available = self.client_manager.rate_limiters[i].semaphore._value
+ capacity = self.client_manager.rate_limiters[i].max_window_size
+ used = capacity - available
+
+ total_capacity += capacity
+ total_available += available
+ total_used += used
+
+ instance_details.append({
+ 'instance_id': i,
+ 'ip': self.client_manager.ips[i],
+ 'port': self.client_manager.port,
+ 'available_slots': available,
+ 'used_slots': used,
+ 'total_slots': capacity,
+ })
+
+ return {
+ 'total_instances': len(self.client_manager.clients),
+ 'total_capacity': total_capacity,
+ 'total_available': total_available,
+ 'total_used': total_used,
+ 'has_available_slots': total_available > 0,
+ 'instances': instance_details
+ }
+
+ def is_available(self) -> bool:
+ availability = self.check_availability()
+ return availability['has_available_slots']
+
+
+DEFAULT_VLM_PROMPT_TEMPLATE = """You are given a set of visual verification questions and a description of objects and motion observed in an input medium (e.g., image or video).
+
+Your task is to **evaluate each question** based on whether it is **correctly reflected in the visual content**, considering visual cues, shape changes from viewpoint, and possible symbolic representations.
+
+
+---
+
+ **Visual Reasoning Guidelines**:
+
+1. **Perspective Awareness**:
+ Objects may appear different based on viewpoint. For example:
+ - A **cylinder** may look like a **circle (top view)** or a **rectangle/square (side view)**.
+ - A **circular path** may appear as a **wave-like curve or straight line** in 2D projection.
+
+2. **Symbolic Representations**:
+ Common simplifications may be used. You should **reasonably infer** their meaning:
+ - A series of **dots or circles** may represent **foam markers** or control points.
+ - A **rectangle** may represent a **container** (e.g., cylindrical viewed from the side).
+ - A **line** may represent a **rubber mat** or constraint boundary.
+   - The object and track specifics might not match exactly; if the motion can still be interpreted correctly, the statement is still true.
+   - Color may be used to represent different objects, e.g., a green line to indicate that the flat surface is covered with a felt-like material.
+   - The rotation of an object might not be judgeable from the video; if the overall motion can still be interpreted correctly, the statement is still true.
+
+3. **Container Boundaries**:
+ - If **no container is drawn**, you may assume the **video frame itself is the container boundary**.
+ - If a **container is visible**, treat it as **transparent** if inner content is visible.
+ - If the object is not visible, you should not assume it is in the container.
+
+4. **Focus on Shape & Position**, **not material**:
+ - Ignore assumptions about object **material**, **color**, or **texture**.
+ - Base your decisions entirely on **observable geometry** (e.g., shape, layout, structure) and **motion** (e.g., direction, trajectory).
+ - Use visible movement and positioning to judge truthfulness — even if the object type is unknown.
+ - If the described motion is **sliding down a slope**, but the video shows an **upward movement**, the result should be `"False"` — regardless of material or appearance.
+ - Make geometric and motion-based reasoning the core of your judgment, even when objects are **partially occluded**.
+
+5. **Occlusion Handling**:
+ - If an object is **partially blocked**, assess based on surrounding evidence whether its state or motion can still be inferred.
+
+6. **Avoid excessive uncertainty**:
+ - If there is enough visual context and logical structure, make a **confident judgment**.
+ - Use "Not sure" only when the evidence is **truly insufficient or ambiguous**.
+
+---
+
+ **Input**:
+- Questions: {all_questions}
+- Object and motion description: {summarize}
+
+---
+
+ **For each question**, return:
+- `"index"`: the question index
+- `"question"`: the full question text
+- `"analysis"`: your reasoning process and visual inference
+- `"result"`: one of `"True"`, `"False"`, or `"Not sure"`
+- `"confidence_score"`: an integer from 1 (very uncertain) to 5 (very certain)
+
+---
+
+**Output Format**:
+Return a JSON list like this:
+[
+ {{
+ "index": "1",
+ "question": "The ball rolls along the circular path.",
+ "analysis": "The object follows a closed curve consistent with a circular path from the top view.",
+ "result": "True",
+ "confidence_score": "5"
+ }},
+ ...
+]
+"""
+
+
+def _format_keywords(text: str,
+ keywords: Optional[List[str]] = None,
+ style: str = "bold",
+ case_sensitive: bool = False) -> str:
+ """Format given keywords in text.
+
+ Args:
+ text: The input text to process.
+ keywords: List of keywords to highlight/format.
+ style: One of "bold" (default), "bracket", or "caps".
+ case_sensitive: Whether to match case-sensitively.
+
+ Returns:
+ Formatted text with keywords highlighted.
+ """
+ if not keywords:
+ return text
+
+ # Deduplicate and sort by length (longest first) to avoid partial overlapping matches
+ unique_keywords = [k for k in sorted(set(keywords), key=lambda s: len(s or ""), reverse=True) if k]
+ if not unique_keywords:
+ return text
+
+ flags = 0 if case_sensitive else re.IGNORECASE
+ pattern = re.compile("|".join(re.escape(k) for k in unique_keywords), flags)
+
+ def repl(match: re.Match) -> str:
+ start, end = match.start(), match.end()
+ # Avoid double-formatting if already bolded ("**word**")
+ if style == "bold":
+ prev2 = text[max(0, start - 2):start]
+ next2 = text[end:end + 2]
+ if prev2 == "**" and next2 == "**":
+ return match.group(0)
+ return f"**{match.group(0)}**"
+ elif style == "bracket":
+ return f"[{match.group(0)}]"
+ elif style == "caps":
+ return match.group(0).upper()
+ else:
+ return f"**{match.group(0)}**"
+
+ return pattern.sub(repl, text)
+
+
+class _SafeDict(dict):
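+    """dict for str.format_map that leaves unknown placeholders (e.g. "{placeholder}") intact instead of raising KeyError."""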
+ def __missing__(self, key):
+ return "{" + key + "}"
+
+
+def create_vlm_prompt_from_template(
+ prompt_template: str,
+ variables: Optional[Dict[str, Any]] = None,
+ keywords: Optional[List[str]] = None,
+ style: str = "bold",
+ case_sensitive: bool = False,
+) -> str:
+ """Create a VLM prompt from a template with optional keyword formatting.
+
+ Args:
+ prompt_template: Template string containing placeholders like {summarize}, {all_questions}.
+ variables: Mapping used to format the template.
+ keywords: Keywords to highlight after formatting.
+ style: Keyword highlight style: "bold" (default), "bracket", or "caps".
+ case_sensitive: Whether keyword matching is case-sensitive.
+
+ Returns:
+ Final prompt string.
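+
+    Example (illustrative values):
+
+        create_vlm_prompt_from_template(
+            DEFAULT_VLM_PROMPT_TEMPLATE,
+            variables={"summarize": "A ball rolls along a circular track.",
+                       "all_questions": "1. The ball moves clockwise."},
+            keywords=["ball"],
+        )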
+ """
+ text = prompt_template
+ if variables:
+ try:
+ text = prompt_template.format_map(_SafeDict(variables))
+ except Exception:
+ # Fall back to the original template if formatting fails
+ text = prompt_template
+ return _format_keywords(text, keywords=keywords, style=style, case_sensitive=case_sensitive)
+
+
+def create_vlm_prompt_custom(
+ prompt: str,
+ keywords: Optional[List[str]] = None,
+ style: str = "bold",
+ case_sensitive: bool = False,
+) -> str:
+ """Create a VLM prompt from a raw prompt string and optional keywords.
+
+ This function does not inject any default template; it only formats
+ the provided prompt and highlights keywords as requested.
+
+ Args:
+ prompt: The prompt content to send to the VLM.
+ keywords: List of keywords to format/highlight.
+ style: Highlight style: "bold" (default), "bracket", or "caps".
+ case_sensitive: Whether keyword matching is case-sensitive.
+
+ Returns:
+ Final prompt string.
+ """
+ return _format_keywords(prompt, keywords=keywords, style=style, case_sensitive=case_sensitive)
+
+
+def create_vlm_prompt(summarize: str, all_questions: str) -> str:
+ """Create the default VLM prompt (backward-compatible).
+
+ Existing call sites expect (summarize, all_questions). This delegates to
+ a template-based builder to enable future customization.
+ """
+ return create_vlm_prompt_from_template(
+ DEFAULT_VLM_PROMPT_TEMPLATE,
+ variables={"summarize": summarize, "all_questions": all_questions},
+ )
+
+
+def _extract_json_list(output_str: str) -> List[Dict[str, Any]]:
+ """Extract the VLM JSON list from the model output.
+
+ Tries strict JSON first; falls back to extracting the substring between
+ the first '[' and the last ']' and parsing that.
+ """
+ try:
+ parsed = json.loads(output_str)
+ if isinstance(parsed, list):
+ return parsed
+ except Exception:
+ pass
+
+ # Fallback: try to extract JSON list portion
+ start = output_str.find('[')
+ end = output_str.rfind(']')
+ if start != -1 and end != -1 and end > start:
+ try:
+ parsed = json.loads(output_str[start:end+1])
+ if isinstance(parsed, list):
+ return parsed
+ except Exception:
+ pass
+ raise ValueError("Failed to parse VLM JSON list from output")
+
diff --git a/agentfly/rewards/vlm_as_judge/vlm_as_judge_reward.py b/agentfly/rewards/vlm_as_judge/vlm_as_judge_reward.py
new file mode 100644
index 0000000..572fba7
--- /dev/null
+++ b/agentfly/rewards/vlm_as_judge/vlm_as_judge_reward.py
@@ -0,0 +1,633 @@
+"""VLM as Judge Reward Function for AgentFly RL Training"""
+
+import os
+import re
+import json
+import uuid
+import tempfile
+import subprocess
+import asyncio
+import logging
+import concurrent.futures
+from typing import Dict, Any, Optional, List, Tuple
+from pathlib import Path
+
+# Support running both as a package module and as a standalone script
+try:
+ from ..reward_base import reward
+ from .vlm_as_judge_client import VLMClient, create_vlm_prompt, _extract_json_list
+except ImportError: # Running as a script without package context
+ import sys
+ # Add repo root to sys.path so absolute imports work when invoked directly
+ sys.path.append(str(Path(__file__).resolve().parents[2]))
+ from agentfly.rewards.reward_base import reward
+ from agentfly.rewards.vlm_as_judge.vlm_as_judge_client import (
+ VLMClient,
+ create_vlm_prompt,
+ _extract_json_list,
+ create_vlm_prompt_from_template,
+ create_vlm_prompt_custom,
+ DEFAULT_VLM_PROMPT_TEMPLATE,
+ )
+
+logger = logging.getLogger(__name__)
+
+
+class VideoGenerator:
+ """Helper class to generate videos from code"""
+
+ def __init__(self, output_dir: Optional[str] = None):
+ """Initialize video generator
+
+ Args:
+ output_dir: Directory to save generated videos
+ """
+ # Prefer a shared directory accessible by the VLM server if provided.
+ if output_dir is None:
+ output_dir = os.getenv("VLM_SHARED_VIDEO_DIR", "/tmp/vlm_videos")
+ self.output_dir = output_dir
+ os.makedirs(output_dir, exist_ok=True)
+
+ def extract_code_from_response(self, response: str) -> Optional[str]:
+ """Extract Python code from model response
+
+ Args:
+ response: Model response containing code
+
+ Returns:
+ Extracted Python code or None
+ """
+        # Strip <think>...</think> reasoning blocks if present
+        cleaned = re.sub(r'<think>.*?</think>', '', response, flags=re.DOTALL)
+
+ # Extract code from ```python blocks
+ pattern = r"```python\n(.*?)```"
+ matches = re.findall(pattern, cleaned, re.DOTALL)
+
+ if matches:
+ return matches[0]
+ return None
+
+ def generate_video_from_code(self, code: str, output_path: str) -> bool:
+ """Execute Python code to generate video
+
+ Args:
+ code: Python code to execute
+ output_path: Path to save the generated video
+
+ Returns:
+ True if video generation successful, False otherwise
+ """
+ try:
+ # Create a temporary Python file
+ with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f:
+ # Modify code to use the specified output path
+ modified_code = code
+
+ # Handle sys.argv usage for output filename
+ if 'sys.argv[1]' in code:
+ modified_code = code.replace('sys.argv[1]', f'"{output_path}"')
+ elif 'sys.argv' in code and 'len(sys.argv)' in code:
+ # Add sys.argv mock at the beginning
+ modified_code = f"import sys\nsys.argv = ['script.py', '{output_path}']\n" + code
+ else:
+ # If no sys.argv usage, try to modify output filename assignments
+ # Look for common patterns like output_file = ... or out = cv2.VideoWriter(...)
+ if 'output_file' in code:
+ # Replace output_file assignment
+ modified_code = re.sub(
+ r'output_file\s*=\s*["\'].*?["\']',
+ f'output_file = "{output_path}"',
+ code
+ )
+ elif 'VideoWriter(' in code:
+ # Try to replace the first string argument in VideoWriter
+ modified_code = re.sub(
+ r'VideoWriter\s*\(\s*["\'].*?["\']',
+ f'VideoWriter("{output_path}"',
+ code
+ )
+ else:
+ # Last resort: append output path assignment
+ modified_code = f"output_file = '{output_path}'\n" + code
+
+ f.write(modified_code)
+ temp_file = f.name
+
+ # Execute the code
+ # Always pass the output path as an argument so scripts that expect
+ # sys.argv[1] or check len(sys.argv) continue without exiting.
+ result = subprocess.run(
+ ['python', temp_file, output_path],
+ capture_output=True,
+ text=True,
+ timeout=120, # Increased timeout for video generation
+ cwd=self.output_dir # Run in output directory
+ )
+
+ # Clean up temp file
+ os.unlink(temp_file)
+
+ # Check if video was created and is not empty
+ if os.path.exists(output_path) and os.path.getsize(output_path) > 1000: # At least 1KB
+ logger.info(f"Successfully generated video: {output_path} ({os.path.getsize(output_path)} bytes)")
+ return True
+ else:
+ logger.error(f"Video generation failed or file too small. stderr: {result.stderr}")
+ return False
+
+ except subprocess.TimeoutExpired:
+ logger.error("Video generation timed out")
+ if 'temp_file' in locals():
+ try:
+ os.unlink(temp_file)
+ except:
+ pass
+ return False
+ except Exception as e:
+ logger.error(f"Error generating video: {e}")
+ if 'temp_file' in locals():
+ try:
+ os.unlink(temp_file)
+ except:
+ pass
+ return False
+
+
+def extract_vlm_questions_from_data(data: Dict[str, Any]) -> Tuple[str, str, List[Dict]]:
+ """Extract VLM questions and summary from data
+
+ Args:
+ data: Dictionary containing vlm_questions data
+
+ Returns:
+ Tuple of (all_questions_str, summarize, questions_list)
+ """
+ all_questions = ""
+ summarize = ""
+ questions_list = []
+
+ if "vlm_questions" in data:
+ vlm_data = data["vlm_questions"]
+ if isinstance(vlm_data, dict):
+ # Get summary
+ summarize = vlm_data.get("summarize", "")
+
+ # Extract questions from nested vlm_questions field
+ if "vlm_questions" in vlm_data:
+ questions_list = vlm_data["vlm_questions"]
+ if isinstance(questions_list, list):
+ for q in questions_list:
+ if isinstance(q, dict):
+ idx = q.get("index", "")
+ question = q.get("question", "")
+ all_questions += f"{idx}. {question}\n"
+ else:
+ logger.warning(f"vlm_questions inner field is not a list: {type(questions_list)}")
+ else:
+ logger.warning(f"vlm_questions is not a dict: {type(vlm_data)}")
+ else:
+ logger.warning(f"No vlm_questions field in data. Available fields: {list(data.keys())}")
+
+ all_questions = all_questions.strip()
+
+ if not summarize:
+ summarize = "Evaluate the visual content based on the questions provided."
+
+ logger.info(f"Extracted {len(questions_list)} questions from VLM data")
+
+ return all_questions, summarize, questions_list
+
+
+def calculate_weighted_reward(results: List[Dict], questions_list: List[Dict]) -> float:
+ """Calculate weighted reward based on VLM results and question weights
+
+ Args:
+ results: List of VLM evaluation results
+ questions_list: Original questions with weights
+
+ Returns:
+ Weighted reward score between 0.0 and 1.0
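+
+    Example (illustrative):
+
+        >>> calculate_weighted_reward(
+        ...     [{"index": "1", "result": "True"}, {"index": "2", "result": "False"}],
+        ...     [{"index": "1", "weight": 1.0}, {"index": "2", "weight": 1.0}],
+        ... )
+        0.5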
+ """
+ if not results or not questions_list:
+ return 0.0
+
+ # Create weight mapping
+ weight_map = {}
+ for q in questions_list:
+ idx = str(q.get("index", ""))
+ weight = float(q.get("weight", 1.0))
+ weight_map[idx] = weight
+
+ scores = []
+ weights = []
+
+ for result in results:
+ idx = str(result.get("index", ""))
+ result_value = result.get("result", "Not sure")
+ confidence = int(result.get("confidence_score", "1"))
+
+ # Get weight for this question
+ weight = weight_map.get(idx, 1.0)
+
+ # Calculate score based on result
+ if result_value == "True":
+ score = 1.0
+ elif result_value == "False":
+ score = 0.0
+ else: # "Not sure"
+ if confidence >= 4:
+ score = 0.0 # High confidence "Not sure" -> False
+ else:
+ score = 1.0 # Low confidence "Not sure" -> True
+
+ scores.append(score)
+ weights.append(weight)
+
+    # NOTE: per-question weights are collected but currently unused; the weighted
+    # average below is disabled, so a plain mean of the scores is returned instead.
+    # if weights:
+    #     weighted_sum = sum(s * w for s, w in zip(scores, weights))
+    #     total_weight = sum(weights)
+    #     reward = weighted_sum / total_weight if total_weight > 0 else 0.0
+    # else:
+ reward = sum(scores) / len(scores) if scores else 0.0
+
+ return reward
+
+def pass_fail_reward(results: List[Dict], questions_list: List[Dict]) -> float:
+ """Calculate a binary pass/fail score from VLM results.
+
+ Returns 1.0 only when every question is judged as satisfied (or low-confidence
+ "Not sure"), otherwise returns 0.0.
+ """
+ if not results or not questions_list:
+ return 0.0
+
+ result_map = {
+ str(r.get("index", "")).strip(): r
+ for r in results
+ if str(r.get("index", "")).strip()
+ }
+
+ for question in questions_list:
+ idx = str(question.get("index", "")).strip()
+ if not idx:
+ logger.warning("Question without index encountered in pass/fail reward")
+ return 0.0
+
+ result = result_map.get(idx)
+ if result is None:
+ logger.warning("Missing VLM result for question index %s", idx)
+ return 0.0
+
+ result_value = str(result.get("result", "Not sure")).strip().lower()
+ confidence_raw = result.get("confidence_score", "1")
+ try:
+ confidence = int(confidence_raw)
+ except (TypeError, ValueError):
+ confidence = 1
+
+ if result_value == "true":
+ continue
+ if result_value == "not sure" and confidence < 4:
+ continue
+
+ # Any explicit false or high-confidence uncertainty causes failure
+ return 0.0
+
+ return 1.0
+
+@reward(name="vlm_as_judge_pass_reward")
+async def vlm_as_judge_pass_reward(
+ prediction: str,
+ trajectory: Dict[str, Any] = None,
+ vlm_questions: Dict[str, Any] = None,
+ **data_fields
+) -> Dict[str, float]:
+ """VLM as Judge reward function for evaluating agent trajectories
+
+ This reward function:
+ 1. Extracts Python code from the prediction
+ 2. Generates a video using the code
+ 3. Uses VLM server to evaluate the video against provided questions
+ 4. Returns a binary pass/fail score based on VLM judgments
+
+ Args:
+ prediction: Agent's generated response (should contain Python code)
+ trajectory: Agent trajectory information
+        vlm_questions: Structured verification questions (and summary) for this sample
+        **data_fields: Additional data fields from the RL data, which may also include vlm_questions
+
+ Returns:
+ pass/fail reward score between 0.0 and 1.0
+ """
+ try:
+ # Log incoming data for debugging
+ logger.info(f"=" * 60)
+ logger.info(f"vlm_as_judge_reward called")
+ logger.info(f"Prediction length: {len(prediction) if prediction else 0}")
+
+ # Print the actual prediction content
+ logger.info(f"Prediction content (first 500 chars):")
+ logger.info(f"{prediction[:500] if prediction else 'No prediction'}")
+ if prediction and len(prediction) > 500:
+ logger.info(f"... (truncated, total length: {len(prediction)} chars)")
+
+ logger.info(f"vlm_questions parameter: {vlm_questions is not None}")
+ logger.info(f"Additional data_fields keys: {list(data_fields.keys())}")
+
+ # Initialize video generator
+ video_gen = VideoGenerator()
+
+ # Combine vlm_questions with data_fields for extraction
+ all_data = dict(data_fields)
+ if vlm_questions is not None:
+ all_data['vlm_questions'] = vlm_questions
+ logger.info(f"vlm_questions type: {type(vlm_questions)}")
+ if isinstance(vlm_questions, dict):
+ logger.info(f"vlm_questions keys: {vlm_questions.keys()}")
+ if 'vlm_questions' in vlm_questions:
+ inner_vlm = vlm_questions['vlm_questions']
+ logger.info(f"Inner vlm_questions type: {type(inner_vlm)}")
+ if isinstance(inner_vlm, list):
+ logger.info(f"Number of questions in inner list: {len(inner_vlm)}")
+
+ # Extract VLM questions from data
+ all_questions, summarize, questions_list = extract_vlm_questions_from_data(all_data)
+
+ if not questions_list:
+ logger.warning(f"No VLM questions found in data. Available fields: {list(all_data.keys())}")
+ return {"reward": 0.0}
+
+ # Extract code from prediction
+ code = video_gen.extract_code_from_response(prediction)
+ if not code:
+ logger.warning("No Python code found in prediction")
+ logger.warning(f"Prediction was: {prediction[:1000] if prediction else 'None'}")
+ return {"reward": 0.0}
+
+ logger.info(f"Extracted Python code ({len(code)} chars)")
+ logger.info(f"Code preview (first 300 chars):")
+ logger.info(f"{code[:300]}...")
+ if len(code) > 300:
+ logger.info(f"... (truncated, total length: {len(code)} chars)")
+
+ # Generate unique video filename
+ video_filename = f"video_{uuid.uuid4().hex}.mp4"
+ video_path = os.path.join(video_gen.output_dir, video_filename)
+
+ # Generate video from code
+ success = video_gen.generate_video_from_code(code, video_path)
+ if not success:
+ logger.error("Failed to generate video from code")
+ return {"reward": 0.0}
+
+ # Run VLM evaluation directly since we're already async
+ client = VLMClient(
+ model="Qwen/Qwen2.5-VL-72B-Instruct",
+ timeout_seconds=120
+ )
+
+ # Wait for client availability
+ for _ in range(10):
+ if client.is_available():
+ break
+ await asyncio.sleep(1)
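+        # for/else: the else branch runs only if the loop never breaks, i.e. the client never became available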
+ else:
+ logger.error("VLM client not available")
+ return {"reward": 0.0}
+
+ # Create VLM prompt
+ prompt_text = create_vlm_prompt(summarize, all_questions)
+
+ # Build message using ' in content %}{% set content = content.split('')[-1] %}{% endif %}{{'<|Assistant|>' + content + '<|end▁of▁sentence|>'}}{%- endif %}{%- endif %}{%- if message['role'] == 'tool' %}{%- set ns.is_tool = true -%}{%- if ns.is_output_first %}{{'<|tool▁outputs▁begin|><|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- set ns.is_output_first = false %}{%- else %}{{'\\n<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- endif %}{%- endif %}{%- endfor -%}{% if ns.is_tool %}{{'<|tool▁outputs▁end|>'}}{% endif %}{% if add_generation_prompt and not ns.is_tool %}{{'<|Assistant|>\\n'}}{% endif %}"
)
)
+register_template(
+ Template(
+ name="llemma",
+ system_template="{system_message}",
+ user_template="Input:{content}\n\n",
+ assistant_template="Response:{content}",
+ stop_words=[""]
+ )
+)
+
+
if __name__ == "__main__":
pass
\ No newline at end of file
diff --git a/agentfly/templates/utils.py b/agentfly/templates/utils.py
index 2761822..75906ac 100644
--- a/agentfly/templates/utils.py
+++ b/agentfly/templates/utils.py
@@ -270,7 +270,7 @@ def tokenize_conversations(
concatenated_mm_inputs = {}
if concatenate_mm_inputs:
for key in batch_mm_inputs[0].keys():
- if mm_inputs[key]:
+ if isinstance(mm_inputs[key], torch.Tensor):
concatenated_mm_inputs[key] = torch.cat([mm_inputs[key] for mm_inputs in batch_mm_inputs if mm_inputs[key] is not None], dim=0)
inputs = dict(
diff --git a/agentfly/tests/scripts/test_cpu_runs.sh b/agentfly/tests/scripts/test_cpu_runs.sh
new file mode 100644
index 0000000..c3f0fc9
--- /dev/null
+++ b/agentfly/tests/scripts/test_cpu_runs.sh
@@ -0,0 +1,10 @@
+#! /bin/bash
+
+# Test CPU runs
+
+
+pytest -x agentfly/tests/unit/tools/
+pytest -x agentfly/tests/unit/envs/
+pytest -x agentfly/tests/unit/rewards/
+
+pytest -x agentfly/tests/unit/templates/
\ No newline at end of file
diff --git a/agentfly/tests/scripts/test_gpu_runs.sh b/agentfly/tests/scripts/test_gpu_runs.sh
index e033a51..4f10934 100644
--- a/agentfly/tests/scripts/test_gpu_runs.sh
+++ b/agentfly/tests/scripts/test_gpu_runs.sh
@@ -2,9 +2,9 @@
# Test GPU runs
-python -m pytest -x tests/unit/agents/test_initialization.py || exit 1
-python -m pytest -x tests/unit/agents/test_auto_agent.py || exit 1
-python -m pytest -x tests/unit/agents/test_code_agent.py || exit 1
-python -m pytest -x tests/unit/agents/test_react_agent.py || exit 1
-python -m pytest -x tests/unit/agents/test_webshop_agent.py || exit 1
-python -m pytest -x tests/unit/agents/test_vision_agent.py || exit 1
\ No newline at end of file
+pytest -x agentfly/tests/unit/agents/test_initialization.py || exit 1
+pytest -x agentfly/tests/unit/agents/test_auto_agent.py || exit 1
+pytest -x agentfly/tests/unit/agents/test_code_agent.py || exit 1
+pytest -x agentfly/tests/unit/agents/test_react_agent.py || exit 1
+pytest -x agentfly/tests/unit/agents/test_webshop_agent.py || exit 1
+pytest -x agentfly/tests/unit/agents/test_vision_agent.py || exit 1
\ No newline at end of file
diff --git a/agentfly/tests/unit/agents/backends/__init__.py b/agentfly/tests/unit/agents/backends/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/agentfly/tests/unit/agents/backends/test_client_backend.py b/agentfly/tests/unit/agents/backends/test_client_backend.py
new file mode 100644
index 0000000..228decb
--- /dev/null
+++ b/agentfly/tests/unit/agents/backends/test_client_backend.py
@@ -0,0 +1,318 @@
+#!/usr/bin/env python3
+# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
+# SPDX-License-Identifier: Apache-2.0
+
+import asyncio
+import time
+import pytest
+import threading
+from unittest.mock import Mock, patch, AsyncMock, MagicMock
+from typing import List, Dict, Any
+import statistics
+import sys
+import os
+
+# Add the parent directory to the Python path
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))))
+
+from agentfly.agents.llm_backends.llm_backends import ClientBackend
+
+
+class TestClientBackendWorkload:
+ """Test suite for ClientBackend workload and rate limiting functionality."""
+
+ @pytest.fixture
+ def mock_openai_client(self):
+ """Mock OpenAI client for testing."""
+ mock_client = Mock()
+ mock_response = Mock()
+ mock_response.dict.return_value = {
+ "choices": [
+ {
+ "message": {
+ "content": "Test response",
+ "tool_calls": None
+ }
+ }
+ ]
+ }
+ mock_client.chat.completions.create.return_value = mock_response
+ return mock_client
+
+ @pytest.fixture
+ def client_backend(self, mock_openai_client):
+ """Create a ClientBackend instance for testing."""
+ with patch('openai.OpenAI', return_value=mock_openai_client):
+ backend = ClientBackend(
+ model_name_or_path="test-model",
+ template="test-template",
+ base_url="http://localhost:8000/v1",
+ max_requests_per_minute=10, # Low limit for testing
+ timeout=30,
+ api_key="test-key"
+ )
+ return backend
+
+ def test_basic_functionality(self, client_backend, mock_openai_client):
+ """Test basic ClientBackend functionality."""
+ messages = [{"role": "user", "content": "Hello"}]
+
+ # Test sync generation
+ response = client_backend.generate(messages)
+
+ assert isinstance(response, list)
+ assert len(response) == 1
+ assert response[0] == "Test response"
+ mock_openai_client.chat.completions.create.assert_called_once()
+
+ @pytest.mark.asyncio
+ async def test_async_generation(self, client_backend, mock_openai_client):
+ """Test async generation functionality."""
+ messages = [{"role": "user", "content": "Hello"}]
+
+ response = await client_backend.generate_async(messages)
+
+ assert isinstance(response, list)
+ assert len(response) == 1
+ assert response[0] == "Test response"
+
+ def test_rate_limiting_basic(self, client_backend):
+ """Test basic rate limiting functionality."""
+ # Verify semaphore is initialized correctly
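+        # _tokens is the backend's internal rate-limit semaphore; its counter should start at the configured max_requests_per_minute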
+ assert client_backend._tokens._value == 10 # max_requests_per_minute
+ assert client_backend._max_tokens == 10
+
+ @pytest.mark.asyncio
+ async def test_rate_limiting_under_load(self, client_backend, mock_openai_client):
+ """Test rate limiting under high load with 100 concurrent requests."""
+ # Configure mock to simulate API delay
+ def mock_api_call(*args, **kwargs):
+ time.sleep(1) # 1s delay per request
+ return Mock(dict=lambda: {
+ "choices": [{"message": {"content": "Test response", "tool_calls": None}}]
+ })
+
+ mock_openai_client.chat.completions.create = mock_api_call
+
+ # Create 100 concurrent requests (10x the rate limit)
+ num_requests = 100
+ messages = [{"role": "user", "content": f"Request {i}"} for i in range(num_requests)]
+
+ start_time = time.time()
+
+ # Send all requests concurrently
+ tasks = [client_backend.generate_async([msg]) for msg in messages]
+ responses = await asyncio.gather(*tasks)
+
+ total_time = time.time() - start_time
+
+ # Verify all requests completed
+ assert len(responses) == num_requests
+ assert all(isinstance(r, list) and len(r) == 1 for r in responses)
+
+        # Verify rate limiting worked (should take longer than if unlimited)
+        # With a strict 10 RPM limit, 100 requests would take roughly 10 minutes in theory,
+        # but with the mocked 1 s delay per request we only assert a 10-second lower bound
+        assert total_time >= 10  # At least 10s due to our mock delay
+
+ print(f"Rate limiting test: {num_requests} requests completed in {total_time:.2f}s")
+ print(f"Effective rate: {num_requests/total_time:.2f} requests/second")
+
+
+ @pytest.mark.asyncio
+ async def test_concurrent_burst_requests(self, client_backend, mock_openai_client):
+ """Test handling of burst requests that exceed the rate limit."""
+ # Configure mock with realistic delay
+ async def mock_api_call(*args, **kwargs):
+ await asyncio.sleep(0.05) # 50ms delay per request
+ return Mock(dict=lambda: {
+ "choices": [{"message": {"content": "Burst response", "tool_calls": None}}]
+ })
+
+ mock_openai_client.chat.completions.create = mock_api_call
+
+ # Send burst of 20 requests (2x the rate limit)
+ burst_size = 20
+ messages = [{"role": "user", "content": f"Burst {i}"} for i in range(burst_size)]
+
+ start_time = time.time()
+
+ # Send all requests at once
+ tasks = [client_backend.generate_async([msg]) for msg in messages]
+ responses = await asyncio.gather(*tasks)
+
+ total_time = time.time() - start_time
+
+ # Verify all requests completed
+ assert len(responses) == burst_size
+ assert all("Burst response" in r[0] for r in responses)
+
+        # Verify rate limiting was applied
+        # Without any throttling, 20 concurrent 50 ms mocks would finish almost immediately,
+        # so rate limiting should push the total well past 0.5 s
+        assert total_time >= 0.5  # At least 500ms due to rate limiting
+
+ print(f"Burst test: {burst_size} requests completed in {total_time:.2f}s")
+
+ @pytest.mark.asyncio
+ async def test_sustained_load_performance(self, client_backend, mock_openai_client):
+ """Test performance under sustained load over time."""
+ # Configure mock with realistic delay
+ async def mock_api_call(*args, **kwargs):
+ await asyncio.sleep(0.02) # 20ms delay per request
+ return Mock(dict=lambda: {
+ "choices": [{"message": {"content": "Sustained response", "tool_calls": None}}]
+ })
+
+ mock_openai_client.chat.completions.create = mock_api_call
+
+ # Send requests in batches over time
+ batch_size = 5
+ num_batches = 4
+ total_requests = batch_size * num_batches
+
+ all_responses = []
+ batch_times = []
+
+ for batch in range(num_batches):
+ batch_start = time.time()
+
+ messages = [{"role": "user", "content": f"Sustained batch {batch} req {i}"}
+ for i in range(batch_size)]
+
+ tasks = [client_backend.generate_async([msg]) for msg in messages]
+ batch_responses = await asyncio.gather(*tasks)
+
+ batch_time = time.time() - batch_start
+ batch_times.append(batch_time)
+ all_responses.extend(batch_responses)
+
+ # Small delay between batches
+ await asyncio.sleep(0.1)
+
+ # Verify all requests completed
+ assert len(all_responses) == total_requests
+ assert all("Sustained response" in r[0] for r in all_responses)
+
+ # Analyze performance
+ avg_batch_time = statistics.mean(batch_times)
+ total_time = sum(batch_times)
+
+ print(f"Sustained load test: {total_requests} requests in {num_batches} batches")
+ print(f"Average batch time: {avg_batch_time:.3f}s")
+ print(f"Total time: {total_time:.3f}s")
+ print(f"Effective rate: {total_requests/total_time:.2f} requests/second")
+
+
+ @pytest.mark.asyncio
+ async def test_mixed_sync_async_workload(self, client_backend, mock_openai_client):
+ """Test mixed sync and async calls under load."""
+ # Configure mock
+ async def mock_api_call(*args, **kwargs):
+ await asyncio.sleep(0.01)
+ return Mock(dict=lambda: {
+ "choices": [{"message": {"content": "Mixed response", "tool_calls": None}}]
+ })
+
+ mock_openai_client.chat.completions.create = mock_api_call
+
+ # Mix of sync and async calls
+ sync_messages = [{"role": "user", "content": f"Sync {i}"} for i in range(5)]
+ async_messages = [{"role": "user", "content": f"Async {i}"} for i in range(5)]
+
+ # Run sync calls in a separate thread to avoid blocking
+ def run_sync_calls():
+ return [client_backend.generate([msg]) for msg in sync_messages]
+
+ # Execute sync calls in thread pool
+ loop = asyncio.get_running_loop()
+ sync_task = loop.run_in_executor(None, run_sync_calls)
+
+ # Execute async calls
+ async_tasks = [client_backend.generate_async([msg]) for msg in async_messages]
+ async_results = await asyncio.gather(*async_tasks)
+
+ # Wait for sync calls to complete
+ sync_results = await sync_task
+
+ # Verify all calls completed
+ assert len(sync_results) == 5
+ assert len(async_results) == 5
+
+ # All should be successful
+ all_results = sync_results + async_results
+ assert all("Mixed response" in r[0] for r in all_results)
+
+ print(f"Mixed workload test: {len(all_results)} total requests completed")
+
+ def test_refiller_startup_edge_cases(self, client_backend):
+ """Test refiller startup in different contexts."""
+ # Test startup when no event loop exists
+ with patch('asyncio.get_event_loop', side_effect=RuntimeError("No event loop")):
+ # Should not crash
+ client_backend._ensure_refiller_running()
+ assert client_backend._refill_task is None
+
+ # Test startup when event loop exists but is not running
+ mock_loop = Mock()
+ mock_loop.is_running.return_value = False
+ mock_loop.create_task.return_value = Mock()
+
+ with patch('asyncio.get_event_loop', return_value=mock_loop):
+ client_backend._ensure_refiller_running()
+ mock_loop.create_task.assert_called_once()
+
+ @pytest.mark.asyncio
+ async def test_large_scale_workload(self, client_backend, mock_openai_client):
+ """Test with a large number of requests (1000) to verify system stability."""
+ # Configure mock with minimal delay
+ async def mock_api_call(*args, **kwargs):
+ await asyncio.sleep(0.001) # 1ms delay per request
+ return Mock(dict=lambda: {
+ "choices": [{"message": {"content": "Large scale response", "tool_calls": None}}]
+ })
+
+ mock_openai_client.chat.completions.create = mock_api_call
+
+ # Create 1000 requests
+ num_requests = 1000
+ messages = [{"role": "user", "content": f"Large scale {i}"} for i in range(num_requests)]
+
+ start_time = time.time()
+
+ # Send all requests concurrently
+ tasks = [client_backend.generate_async([msg]) for msg in messages]
+ responses = await asyncio.gather(*tasks)
+
+ total_time = time.time() - start_time
+
+ # Verify all requests completed
+ assert len(responses) == num_requests
+ assert all("Large scale response" in r[0] for r in responses)
+
+ print(f"Large scale test: {num_requests} requests completed in {total_time:.2f}s")
+ print(f"Effective rate: {num_requests/total_time:.2f} requests/second")
+
+        # Verify rate limiting was applied (the backend is configured for ~10 RPM, i.e. ~0.167 RPS)
+        # A strict limiter would stretch 1000 requests over ~6000 seconds; this test only
+        # sanity-checks a lenient lower bound so it stays fast with the mocked delay
+        assert total_time >= 1.0  # At least 1 second due to rate limiting
+
+ def test_workload_with_real_api(self):
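+        # Not mocked: requires a live OpenAI-compatible server at http://localhost:8000/v1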
+ client_backend = ClientBackend(
+ model_name_or_path="deepseek-ai/DeepSeek-R1-Distill-Qwen-7B",
+ template="deepseek",
+ base_url="http://localhost:8000/v1",
+ max_requests_per_minute=60,
+ timeout=300,
+ api_key="EMPTY"
+ )
+        messages = [[{"role": "user", "content": "Let $x,y$ and $z$ be positive real numbers that satisfy the following system of equations:\n\\[\\log_2\\left({x \\over yz}\\right) = {1 \\over 2}\\]\n\\[\\log_2\\left({y \\over xz}\\right) = {1 \\over 3}\\]\n\\[\\log_2\\left({z \\over xy}\\right) = {1 \\over 4}\\]\nThen the value of $\\left|\\log_2(x^4y^3z^2)\\right|$ is $\\tfrac{m}{n}$ where $m$ and $n$ are relatively prime positive integers. Find $m+n$."}]] * 100
+ time_start = time.time()
+ response = client_backend.generate(messages)
+ time_end = time.time()
+
+ assert len(response) == 100
+ print(f"Time taken: {time_end - time_start} seconds")
+ print(f"Effective rate: {100/(time_end - time_start)} requests/second")
+
+
\ No newline at end of file
diff --git a/agentfly/tests/unit/envs/test_webshop_text_env.py b/agentfly/tests/unit/envs/test_webshop_text_env.py
index 8b00a5b..bf3aa1a 100644
--- a/agentfly/tests/unit/envs/test_webshop_text_env.py
+++ b/agentfly/tests/unit/envs/test_webshop_text_env.py
@@ -86,7 +86,7 @@ async def test_env_full_shopping_flow():
assert 'observation' in observation
assert 'reward' in observation
- await env.close()
+ await env.aclose()
# @pytest.mark.asyncio
# async def test_pagination_navigation():
diff --git a/agentfly/tests/unit/rewards/test_vlm_as_judge_reward.py b/agentfly/tests/unit/rewards/test_vlm_as_judge_reward.py
new file mode 100644
index 0000000..0cf2dd0
--- /dev/null
+++ b/agentfly/tests/unit/rewards/test_vlm_as_judge_reward.py
@@ -0,0 +1,282 @@
+import asyncio
+import json
+import os
+import re
+import subprocess
+import sys
+import tempfile
+import time
+import traceback
+import uuid
+import warnings
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple, Any
+
+from ....rewards.vlm_as_judge.vlm_as_judge_client import VLMClient, create_vlm_prompt, _extract_json_list
+from ....rewards.vlm_as_judge.vlm_as_judge_reward import VideoGenerator, extract_vlm_questions_from_data, calculate_weighted_reward, pass_fail_reward, vlm_as_judge_pass_reward
+
+
+if __name__ == "__main__":
+ """Test VLM client functionality"""
+ import asyncio
+
+ async def test_client():
+ print("="*70)
+ print("Testing VLM Client")
+ print("="*70)
+
+ # Test data
+ test_questions = {
+ "vlm_questions": {
+ "summarize": "A ball rolls down a ramp",
+ "vlm_questions": [
+ {"index": "1", "question": "A ball is visible", "weight": 1.0},
+ {"index": "2", "question": "The ball moves downward", "weight": 1.0}
+ ]
+ }
+ }
+
+ try:
+ # Test client initialization
+ client = VLMClient(
+ model="Qwen/Qwen2.5-VL-72B-Instruct",
+ timeout_seconds=60
+ )
+ print(f"✓ Client initialized")
+
+ # Check availability
+ is_available = client.is_available()
+ print(f"✓ Client available: {is_available}")
+
+ # Test prompt creation
+ all_q = "1. A ball is visible\n2. The ball moves downward"
+ prompt = create_vlm_prompt("A ball rolls down a ramp", all_q)
+ print(f"✓ Prompt created ({len(prompt)} chars)")
+
+ # Test JSON extraction
+ test_response = '''[{"index": "1", "result": "True", "confidence_score": "5"}]'''
+ results = _extract_json_list(test_response)
+ print(f"✓ JSON extraction works: {len(results)} results")
+
+ print("\nAll tests passed!")
+
+ except Exception as e:
+ print(f"✗ Test failed: {e}")
+ import traceback
+ traceback.print_exc()
+
+ asyncio.run(test_client())
+
+
+ """Test VLM-as-judge reward function"""
+ import sys
+
+ # Test data - real physics example with charged sphere
+ test_data = {
+ "question": "A charged 0.6 kg aluminum sphere is placed at the center of a 1.5-meter by 1-meter by 1-meter glass tank filled with air at 25°C and 1 atm. The tank is horizontally divided into two equal compartments by a non-conductive partition. When a 450 N/C vertical electric field is applied, the sphere rises, clears the partition, and settles in the upper compartment over 13 seconds, as the field balances the gravitational force.",
+ "Level": 3,
+ "vlm_questions": {
+ "enableAnnotator": "Yes",
+ "summarize": "A charged 0.6 kg aluminum sphere is placed at the center of a 1.5m x 1m x 1m glass tank filled with air at 25°C and 1 atm. The tank is divided by a non-conductive partition. A 450 N/C vertical electric field causes the sphere to rise, clear the partition, and settle in the upper compartment, balancing gravitational force over 13 seconds.",
+ "vlm_questions": [
+ {
+ "index": "1",
+ "question": "A non-conductive partition divides the tank horizontally into two equal compartments.",
+ "weight": 1.0
+ },
+ {
+ "index": "2",
+ "question": "The sphere is initially placed at the center of the tank.",
+ "weight": 1.0
+ },
+ {
+ "index": "3",
+ "question": "The sphere rises vertically when the electric field is applied.",
+ "weight": 1.0
+ },
+ {
+ "index": "4",
+ "question": "The sphere clears the partition and enters the upper compartment.",
+ "weight": 1.0
+ },
+ {
+ "index": "5",
+ "question": "The sphere settles in the upper compartment after moving.",
+ "weight": 1.0
+ }
+ ]
+ }
+ }
+
+ # Sample physics simulation code
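+    # (a self-contained OpenCV renderer that draws the tank, partition, sphere and force vectors,
+    #  and writes the MP4 whose path is passed as sys.argv[1])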
+ sample_code = '''
+import sys
+import subprocess
+import importlib
+
+required_libraries = ['cv2', 'numpy']
+for lib in required_libraries:
+ try:
+ importlib.import_module(lib)
+ except ImportError:
+ subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'opencv-python', 'numpy'])
+ break
+
+import cv2
+import numpy as np
+
+if len(sys.argv) < 2:
+ print("Usage: python script.py output_filename.mp4")
+ sys.exit(1)
+
+output_file = sys.argv[1]
+
+# Physical parameters
+tank_length = 1.5
+tank_width = 1.0
+tank_height = 1.0
+partition_height = 0.5
+initial_z = 0.25
+final_z = 0.75
+total_time = 13.0
+mass = 0.6
+gravity = 9.8
+E_field = 450.0
+charge = (mass * gravity) / E_field
+
+# Video parameters
+fps = 30
+width, height = 1280, 720
+num_frames = int(total_time * fps)
+fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+out = cv2.VideoWriter(output_file, fourcc, fps, (width, height))
+
+# Scaling factors for visualization
+margin = 50
+x_scale = (width - 2 * margin) / tank_length
+z_scale = (height - 2 * margin) / tank_height
+sphere_radius_px = 15
+force_scale = 30
+
+def world_to_pixel(x, z):
+ px = int(margin + x * x_scale)
+ pz = int(height - margin - z * z_scale)
+ return px, pz
+
+for frame_idx in range(num_frames):
+ t = frame_idx / fps
+ progress = min(1.0, t / total_time)
+ current_z = initial_z + (final_z - initial_z) * progress
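+    # Linearly interpolate the sphere's height from initial_z to final_z over the full duration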
+ current_pos = [tank_length/2, tank_width/2, current_z]
+ velocity = (final_z - initial_z) / total_time
+
+ # Create white background
+ img = np.ones((height, width, 3), dtype=np.uint8) * 255
+
+ # Draw tank
+ tank_tl = world_to_pixel(0, tank_height)
+ tank_br = world_to_pixel(tank_length, 0)
+ cv2.rectangle(img, tank_tl, tank_br, (200, 200, 255), 2)
+
+ # Draw partition
+ part_start = world_to_pixel(0, partition_height)
+ part_end = world_to_pixel(tank_length, partition_height)
+ cv2.line(img, part_start, part_end, (100, 100, 100), 2)
+
+ # Draw sphere
+ sphere_pos = world_to_pixel(tank_length/2, current_z)
+ cv2.circle(img, sphere_pos, sphere_radius_px, (0, 0, 255), -1)
+
+ # Draw force vectors
+ g_vector_end = (sphere_pos[0], sphere_pos[1] + force_scale)
+ cv2.arrowedLine(img, sphere_pos, g_vector_end, (0, 150, 0), 2, tipLength=0.3)
+
+ e_vector_end = (sphere_pos[0], sphere_pos[1] - force_scale)
+ cv2.arrowedLine(img, sphere_pos, e_vector_end, (255, 0, 0), 2, tipLength=0.3)
+
+ # Draw text overlays
+ cv2.putText(img, f"Time: {t:.2f}s / {total_time}s", (20, 40),
+ cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 0), 2)
+ cv2.putText(img, f"Mass: {mass} kg", (width-300, 40),
+ cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 0), 2)
+ cv2.putText(img, f"Gravity: {gravity} m/s^2", (width-300, 80),
+ cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 0), 2)
+ cv2.putText(img, f"E-Field: {E_field} N/C", (width-300, 120),
+ cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 0), 2)
+ cv2.putText(img, f"Charge: {charge:.5f} C", (width-300, 160),
+ cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 0), 2)
+ cv2.putText(img, f"Velocity: {velocity:.4f} m/s", (width-300, 200),
+ cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 0), 2)
+ cv2.putText(img, f"Position: (0.75, 0.50, {current_z:.2f}) m", (width-300, 240),
+ cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 0), 2)
+
+ # Write frame
+ out.write(img)
+
+out.release()
+'''
+
+ print("="*70)
+ print("Testing VLM-as-Judge Reward")
+ print("="*70)
+
+ # Test video generator
+ print("\n1. Testing VideoGenerator...")
+ gen = VideoGenerator()
+
+ # Test code extraction
+ test_response = f"```python\n{sample_code}\n```"
+ code = gen.extract_code_from_response(test_response)
+ print(f" ✓ Extracted {len(code) if code else 0} chars of code")
+
+ # Test question extraction
+ print("\n2. Testing question extraction...")
+ all_q, summary, q_list = extract_vlm_questions_from_data(test_data)
+ print(f" ✓ Extracted {len(q_list)} questions")
+ print(f" ✓ Summary: {summary[:50]}...")
+
+ # Test reward calculation
+ print("\n3. Testing reward calculation...")
+ test_results = [
+ {"index": "1", "result": "True", "confidence_score": "5"},
+ {"index": "2", "result": "True", "confidence_score": "4"},
+ {"index": "3", "result": "False", "confidence_score": "3"}
+ ]
+ # reward = calculate_weighted_reward(test_results, q_list)
+ reward = pass_fail_reward(test_results, q_list)
+ print(f" ✓ Calculated reward: {reward:.3f}")
+
+ print("\n4. Testing full reward function...")
+ print(" Note: This requires VLM server to be running")
+
+ async def test_reward():
+ """Async wrapper for testing the reward function"""
+ try:
+ # Test with physics simulation prediction including think tags
+            prediction_with_think = f"<think>\n{test_data.get('think', 'Analyzing the physics problem...')}\n</think>\n\n```python\n{sample_code}\n```"
+
+ # Alternative: Test with just code
+ prediction = f"```python\n{sample_code}\n```"
+
+ reward_value = await vlm_as_judge_pass_reward(
+ prediction=prediction,
+ trajectory={},
+ **test_data
+ )
+ print(f" ✓ Reward function returned: {reward_value}")
+ return reward_value
+ except Exception as e:
+ print(f" ⚠ Reward function test failed (expected if VLM server not running)")
+ print(f" Error: {e}")
+ return None
+
+ # Run the async test
+ import asyncio
+ result = asyncio.run(test_reward())
+
+ print("\nTest complete!")
+ if result:
+ print(f"Final reward score: {result.get('reward', 0.0):.3f}")
\ No newline at end of file
diff --git a/agentfly/tests/unit/templates/__init__.py b/agentfly/tests/unit/templates/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/agentfly/tests/unit/agents/templates/test_qwen3_prompt.py b/agentfly/tests/unit/templates/test_qwen3_prompt.py
similarity index 98%
rename from agentfly/tests/unit/agents/templates/test_qwen3_prompt.py
rename to agentfly/tests/unit/templates/test_qwen3_prompt.py
index 3c3faae..7eb6b3e 100644
--- a/agentfly/tests/unit/agents/templates/test_qwen3_prompt.py
+++ b/agentfly/tests/unit/templates/test_qwen3_prompt.py
@@ -1,4 +1,4 @@
-from .....templates.utils import compare_hf_template
+from ....templates.utils import compare_hf_template
import pytest
from transformers import AutoTokenizer
diff --git a/agentfly/tests/unit/agents/templates/test_qwen3_tokenize.py b/agentfly/tests/unit/templates/test_qwen3_tokenize.py
similarity index 97%
rename from agentfly/tests/unit/agents/templates/test_qwen3_tokenize.py
rename to agentfly/tests/unit/templates/test_qwen3_tokenize.py
index fb47144..e006a05 100644
--- a/agentfly/tests/unit/agents/templates/test_qwen3_tokenize.py
+++ b/agentfly/tests/unit/templates/test_qwen3_tokenize.py
@@ -7,11 +7,11 @@
Since the align for textual prompt is already tested in other files, we only need to test the tokenization of the templates.
"""
-from .....templates.utils import tokenize_conversation
+from ....templates.utils import tokenize_conversation
import pytest
from transformers import AutoTokenizer
import torch
-from .....templates.templates import Chat
+from ....templates.templates import Chat
@pytest.mark.parametrize("template", ["qwen3"])
@pytest.mark.parametrize("messages", [
diff --git a/agentfly/tests/unit/templates/test_single_turn_template_tokenize.py b/agentfly/tests/unit/templates/test_single_turn_template_tokenize.py
new file mode 100644
index 0000000..9bf1c63
--- /dev/null
+++ b/agentfly/tests/unit/templates/test_single_turn_template_tokenize.py
@@ -0,0 +1,38 @@
+from ....templates.utils import tokenize_conversation
+import pytest
+from transformers import AutoTokenizer
+import torch
+from ....templates.templates import Chat
+
+@pytest.mark.parametrize("template", ["deepseek-r1-distill-qwen"])
+@pytest.mark.parametrize("messages", [
+ [
+ {"role": "system", "content": "You are a helpful assistant."},
+ {"role": "user", "content": "Hello, how are you?"},
+ {"role": "assistant", "content": "I am fine, thank you."},
+ ],
+ [
+ {"role": "user", "content": "Hello, how are you?"},
+        {"role": "assistant", "content": "<think> This is test thinking content.</think> I am fine, thank you."},
+ ],
+])
+@pytest.mark.parametrize("tools", [
+ None,
+])
+@pytest.mark.parametrize("add_generation_prompt", [False, True])
+def test_template_tokenize(template, messages, tools, add_generation_prompt):
+ template_tokenizer_mapping = {
+ "qwen2.5": "Qwen/Qwen2.5-3B-Instruct",
+ "llama-3.2": "meta-llama/Llama-3.2-3B-Instruct",
+ "deepseek-r1-distill-qwen": "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B",
+ }
+ tokenizer = AutoTokenizer.from_pretrained(template_tokenizer_mapping[template], trust_remote_code=True)
+
+ chat = Chat(template, messages, tools=tools)
+ prompt = chat.prompt(add_generation_prompt=add_generation_prompt, tools=tools)
+
+ hf_inputs = tokenizer(prompt, return_tensors="pt")
+
+ implemented_inputs = tokenize_conversation(messages, tokenizer, template, max_length=2048, tools=tools, add_generation_prompt=add_generation_prompt, return_tensors="pt")
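+    # The template's own tokenization must produce exactly the input_ids HF yields for the rendered prompt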
+
+ assert torch.equal(hf_inputs["input_ids"], implemented_inputs["input_ids"]), f"template: {template}\n\nmessages: {messages}\n\ntools: {tools}\n\nadd_generation_prompt: {add_generation_prompt}\n\nprompt: {prompt}\n\nimplemented_prompt: {tokenizer.decode(implemented_inputs['input_ids'][0])}\n\nhf_inputs: {hf_inputs}\n\nimplemented_inputs: {implemented_inputs}"
\ No newline at end of file
diff --git a/agentfly/tests/unit/templates/test_single_turn_templates.py b/agentfly/tests/unit/templates/test_single_turn_templates.py
new file mode 100644
index 0000000..e968ef6
--- /dev/null
+++ b/agentfly/tests/unit/templates/test_single_turn_templates.py
@@ -0,0 +1,54 @@
+from ....templates import compare_hf_template
+from transformers import AutoTokenizer
+import pytest
+
+
+# "qwen2.5-think", "qwen2.5", "qwen2.5-no-tool",
+# "llama-3.2", "mistral", "glm-4", "internlm2.5", "phi-3.5", "phi-4"
+@pytest.mark.parametrize("template", ["deepseek-r1-distill-qwen"])
+@pytest.mark.parametrize("messages", [
+ [
+ {"role": "system", "content": "You are a helpful assistant."},
+ {"role": "user", "content": "Hello, how are you?"},
+ {"role": "assistant", "content": "I am fine, thank you."},
+ ],
+ [
+ {"role": "user", "content": "Hello, how are you?"},
+        {"role": "assistant", "content": "<think> This is test thinking content.</think> I am fine, thank you."},
+ ],
+ [
+ {"role": "user", "content": "Hello, how are you?"},
+ ],
+])
+@pytest.mark.parametrize("tools", [
+ None,
+])
+@pytest.mark.parametrize("add_generation_prompt", [True, False])
+def test_chat_template_equal(template, messages, tools, add_generation_prompt):
+ # Filter invalid combinations
+ if add_generation_prompt and messages[-1]['role'] == 'assistant':
+ return
+
+
+ template_tokenizer_mapping = {
+ "qwen2.5": "Qwen/Qwen2.5-3B-Instruct",
+ "qwen2.5-think": "Qwen/Qwen2.5-3B-Instruct",
+ "qwen2.5-no-system-tool": "Qwen/Qwen2.5-3B-Instruct",
+ "deepseek-prover-v2": "deepseek-ai/DeepSeek-Prover-V2-7B",
+ "llama-3.2": "meta-llama/Llama-3.2-3B-Instruct",
+ "mistral": "mistralai/Mistral-7B-Instruct-v0.3",
+ "glm-4": "THUDM/glm-4-9b-chat",
+ "internlm2.5": "internlm/internlm2_5-7b-chat",
+ "phi-3.5": "microsoft/Phi-3.5-mini-instruct",
+ "phi-4": "microsoft/Phi-4",
+ "nemotron": "nvidia/Llama-3.1-Nemotron-Nano-8B-v1",
+ "deepseek-r1-distill-qwen": "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B",
+ }
+ tokenizer = AutoTokenizer.from_pretrained(template_tokenizer_mapping[template], trust_remote_code=True)
+
+    is_equal, is_equal_between_implemented_prompts, is_equal_between_jinja_prompts, official_prompt, implemented_prompt, implemented_jinja_prompt, highlighted_prompt = compare_hf_template(tokenizer, template, messages=messages, tools=tools, add_generation_prompt=add_generation_prompt)
+
+ print(f"Official prompt:\n\n{official_prompt}")
+ print(f"Highlighted prompt:\n\n{highlighted_prompt}")
+ assert is_equal, f"Template: {template}\n\nMessages: {messages}\n\ntools: {tools}\n\nadd_generation_prompt: {add_generation_prompt}\n\nOfficial prompt:\n\n{official_prompt}\n\nImplemented prompt:\n\n{implemented_prompt}"
+ assert is_equal_between_jinja_prompts, f"Template: {template}\n\nMessages: {messages}\n\ntools: {tools}\n\nadd_generation_prompt: {add_generation_prompt}\n\nImplemented prompt:\n\n{implemented_prompt}\n\nJinja prompt:\n\n{implemented_jinja_prompt}"
diff --git a/agentfly/tests/unit/agents/templates/test_template_utilities.py b/agentfly/tests/unit/templates/test_template_utilities.py
similarity index 83%
rename from agentfly/tests/unit/agents/templates/test_template_utilities.py
rename to agentfly/tests/unit/templates/test_template_utilities.py
index 2cdbf4b..246f1cc 100644
--- a/agentfly/tests/unit/agents/templates/test_template_utilities.py
+++ b/agentfly/tests/unit/templates/test_template_utilities.py
@@ -1,5 +1,5 @@
-from .....agents.templates.templates import get_template, register_template, Template
-from .....agents.templates.vision_processor import get_processor
+from ....templates.templates import get_template, register_template, Template
+from ....templates.vision_processor import get_processor
def test_template_registration():
register_template(
diff --git a/agentfly/tests/unit/agents/templates/test_text_templates_full_align.py b/agentfly/tests/unit/templates/test_text_templates_full_align.py
similarity index 99%
rename from agentfly/tests/unit/agents/templates/test_text_templates_full_align.py
rename to agentfly/tests/unit/templates/test_text_templates_full_align.py
index 652ffa4..0e4e57d 100644
--- a/agentfly/tests/unit/agents/templates/test_text_templates_full_align.py
+++ b/agentfly/tests/unit/templates/test_text_templates_full_align.py
@@ -8,7 +8,7 @@
"""
-from .....templates import compare_hf_template
+from ....templates import compare_hf_template
from transformers import AutoTokenizer
import pytest
diff --git a/agentfly/tests/unit/agents/templates/test_text_templates_partial_align.py b/agentfly/tests/unit/templates/test_text_templates_partial_align.py
similarity index 96%
rename from agentfly/tests/unit/agents/templates/test_text_templates_partial_align.py
rename to agentfly/tests/unit/templates/test_text_templates_partial_align.py
index ea26532..86f322d 100644
--- a/agentfly/tests/unit/agents/templates/test_text_templates_partial_align.py
+++ b/agentfly/tests/unit/templates/test_text_templates_partial_align.py
@@ -1,7 +1,7 @@
import pytest
from transformers import AutoTokenizer
-from agentfly.agents.templates.templates import get_template
-from agentfly.agents.templates.utils import compare_hf_template
+from ....templates.templates import get_template
+from ....templates.utils import compare_hf_template
# nemotron, phi-4, glm-4
@pytest.mark.parametrize("template_name", ["qwen2.5-think", "qwen2.5-no-system-tool",])
diff --git a/agentfly/tests/unit/agents/templates/test_text_templates_tokenize.py b/agentfly/tests/unit/templates/test_text_templates_tokenize.py
similarity index 96%
rename from agentfly/tests/unit/agents/templates/test_text_templates_tokenize.py
rename to agentfly/tests/unit/templates/test_text_templates_tokenize.py
index 079537b..f92fdc2 100644
--- a/agentfly/tests/unit/agents/templates/test_text_templates_tokenize.py
+++ b/agentfly/tests/unit/templates/test_text_templates_tokenize.py
@@ -7,11 +7,11 @@
Since the align for textual prompt is already tested in other files, we only need to test the tokenization of the templates.
"""
-from .....agents.templates.utils import tokenize_conversation
+from ....templates.utils import tokenize_conversation
import pytest
from transformers import AutoTokenizer
import torch
-from .....agents.templates.templates import Chat
+from ....templates.templates import Chat
@pytest.mark.parametrize("template", ["llama-3.2", "qwen2.5"])
@pytest.mark.parametrize("messages", [
diff --git a/agentfly/tests/unit/agents/templates/test_vision_templates_full_align.py b/agentfly/tests/unit/templates/test_vision_templates_full_align.py
similarity index 98%
rename from agentfly/tests/unit/agents/templates/test_vision_templates_full_align.py
rename to agentfly/tests/unit/templates/test_vision_templates_full_align.py
index 0d98620..9514529 100644
--- a/agentfly/tests/unit/agents/templates/test_vision_templates_full_align.py
+++ b/agentfly/tests/unit/templates/test_vision_templates_full_align.py
@@ -9,7 +9,7 @@
"""
-from .....agents.templates.utils import compare_hf_template
+from ....templates.utils import compare_hf_template
from transformers import AutoTokenizer
import pytest
# "qwen2.5-think", "qwen2.5", "qwen2.5-no-tool",
diff --git a/agentfly/tests/unit/agents/templates/test_vision_templates_tokenize.py b/agentfly/tests/unit/templates/test_vision_templates_tokenize.py
similarity index 97%
rename from agentfly/tests/unit/agents/templates/test_vision_templates_tokenize.py
rename to agentfly/tests/unit/templates/test_vision_templates_tokenize.py
index 4117a5a..bfc380a 100644
--- a/agentfly/tests/unit/agents/templates/test_vision_templates_tokenize.py
+++ b/agentfly/tests/unit/templates/test_vision_templates_tokenize.py
@@ -8,8 +8,8 @@
"""
-from .....agents.templates.templates import Chat
-from .....agents.templates.utils import compare_hf_template, tokenize_conversation
+from ....templates.templates import Chat
+from ....templates.utils import compare_hf_template, tokenize_conversation
from transformers import AutoTokenizer
import pytest
import torch
diff --git a/agentfly/utils/deploy.py b/agentfly/utils/deploy.py
new file mode 100644
index 0000000..bf2101c
--- /dev/null
+++ b/agentfly/utils/deploy.py
@@ -0,0 +1,51 @@
+
+
+import os
+from ..templates import get_template
+from .. import AGENT_DATA_DIR
+import click
+
+
+def vllm_serve(model_name_or_path, template, tp, pp, dp, gpu_memory_utilization):
+ port = 8000
+
+
+ if template is None:
+ template_option = ""
+ else:
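+        # Render the registered template to a Jinja file so vLLM can load it via --chat-template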
+ jinja_template = get_template(template).jinja_template()
+ if not os.path.exists(f"{AGENT_DATA_DIR}/cache"):
+ os.makedirs(f"{AGENT_DATA_DIR}/cache")
+ with open(f"{AGENT_DATA_DIR}/cache/jinja_template.jinja", "w") as f:
+ f.write(jinja_template)
+ template_option = f"--chat-template {AGENT_DATA_DIR}/cache/jinja_template.jinja"
+ # command = f"vllm serve {model_name_or_path} --chat-template {AGENT_DATA_DIR}/cache/jinja_template.jinja --tensor-parallel-size {tp} --pipeline-parallel-size {pp} --data-parallel-size {dp} --port {port} --enable-auto-tool-choice --tool-call-parser hermes --expand-tools-even-if-tool-choice-none"
+ command = f"""vllm serve {model_name_or_path} \
+{template_option} \
+--tensor-parallel-size {tp} \
+--pipeline-parallel-size {pp} \
+--data-parallel-size {dp} --port {port} \
+--gpu-memory-utilization {gpu_memory_utilization} \
+--enable-auto-tool-choice --tool-call-parser hermes"""
+
+ print(command)
+ os.system(command)
+
+
+
+@click.command()
+@click.option("--model_name_or_path")
+@click.option("--template", default=None)
+@click.option("--tp", type=int, default=1)
+@click.option("--pp", type=int, default=1)
+@click.option("--dp", type=int, default=1)
+@click.option("--gpu_memory_utilization", type=float, default=0.8)
+def main(model_name_or_path, template, tp, pp, dp, gpu_memory_utilization):
+ vllm_serve(model_name_or_path, template, tp, pp, dp, gpu_memory_utilization)
+
+
+if __name__ == "__main__":
+    # Example usage:
+    #   python -m agentfly.utils.deploy --model_name_or_path Qwen/Qwen2.5-3B-Instruct --template qwen2.5 --tp 2 --dp 2
+    #   python -m agentfly.utils.deploy --model_name_or_path openai/gpt-oss-20b --tp 1 --dp 1
+    #   python -m agentfly.utils.deploy --model_name_or_path deepseek-ai/DeepSeek-R1-Distill-Qwen-7B --tp 1 --dp 1
+    main()
\ No newline at end of file
diff --git a/agentfly/utils/monitor.py b/agentfly/utils/monitor.py
index 504f1f5..2dd66f1 100644
--- a/agentfly/utils/monitor.py
+++ b/agentfly/utils/monitor.py
@@ -38,11 +38,13 @@ class MetricEvent:
step Integer training step or episode counter.
timestamp Unix seconds (auto‑filled if omitted).
tags Arbitrary key/value pairs for filtering (e.g. run_id, module).
+ sinks List of sink names to send this event to. If None, sends to all sinks.
"""
kind: Literal["scalar", "hist", "text", "resource", "list"]
name: str
value: Any
+ sinks: Optional[List[str]] = None
step: Optional[int] = None
x: Optional[int] = None
x_name: Optional[str] = "x_axis"
@@ -55,7 +57,6 @@ def __post_init__(self) -> None:
self.timestamp = time.time()
-
class BaseSink(abc.ABC):
"""Abstract writer backend."""
@@ -100,21 +101,40 @@ def serialize_for_json(obj):
class JsonlSink(BaseSink):
"""Append events as JSON-Lines - human & machine friendly."""
- def __init__(self, path: str) -> None:
- os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
- self._f = open(path, "a", buffering=1, encoding="utf-8")
+ def __init__(self, directory: str) -> None:
+        os.makedirs(directory, exist_ok=True)
+ self.directory = directory
+ if os.path.isdir(directory):
+ default_file = os.path.join(directory, "default.jsonl")
+ with open(default_file, 'w') as f:
+ pass
+
+ self.log_files = {"default": open(default_file, "a", buffering=1, encoding="utf-8")}
+ else:
+ self.log_files = {}
+
self._lock = asyncio.Lock()
async def log(self, evt: MetricEvent) -> None:
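+        # Route each metric to its own JSONL file named after the event ("/" in names is replaced with "-")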
+ evt_name = evt.name.replace("/", "-")
+ if evt_name not in self.log_files:
+ file_name = os.path.join(self.directory, f"{evt_name}.jsonl")
+ with open(file_name, 'w') as f:
+ pass
+ self.log_files[evt_name] = open(file_name, "a", buffering=1, encoding="utf-8")
+ file_obj = self.log_files[evt_name]
+
async with self._lock:
- self._f.write(json.dumps(serialize_for_json(asdict(evt)), ensure_ascii=False) + "\n")
+ file_obj.write(json.dumps(serialize_for_json(asdict(evt)), ensure_ascii=False) + "\n")
async def flush(self) -> None:
- self._f.flush()
+ for file_obj in self.log_files.values():
+ file_obj.flush()
async def close(self) -> None:
await super().close()
- self._f.close()
+ for file_obj in self.log_files.values():
+ file_obj.close()
class WandbSink(BaseSink):
@@ -130,6 +150,9 @@ def __init__(self, project: str, **wandb_init_kwargs: Any) -> None: # noqa: D40
self.tables: Dict[str, wandb.Table] = {}
async def log(self, evt: MetricEvent) -> None: # pragma: no cover
+ """
+ Log the event to wandb.
+ """
if wandb.run is not None:
payload = {evt.name: evt.value, **evt.tags}
if evt.x is not None:
@@ -249,7 +272,10 @@ async def _consumer_loop(cls) -> None:
evt = await cls._queue.get()
if evt is None: # sentinel
break
- for sink in list(cls._sinks.values()):
+ for sink_name, sink in list(cls._sinks.items()):
+ # Check if this sink should receive this event
+ if evt.sinks is not None and sink_name not in evt.sinks:
+ continue
try:
await sink.log(evt)
except Exception as exc:
diff --git a/agentfly/utils/trajectories.py b/agentfly/utils/trajectories.py
new file mode 100644
index 0000000..71332b2
--- /dev/null
+++ b/agentfly/utils/trajectories.py
@@ -0,0 +1,69 @@
+"""
+This module is used to collect trajectories from the agent.
+"""
+from typing import Dict, List
+import asyncio
+import json
+import os
+from ..agents.llm_backends.backend_configs import ClientConfig
+from ..agents import HFAgent
+import click
+
+def gather_responses(trajectories: List[Dict]):
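+    # Assumes each trajectory's final message stores content as a list of parts, e.g. [{"type": "text", "text": "..."}]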
+ responses = []
+ for trajectory in trajectories:
+ responses.append({
+ "id": trajectory["id"],
+ "response": trajectory["messages"][-1]["content"][0]["text"],
+ })
+
+ return responses
+
+@click.command()
+@click.option("--model_name_or_path", type=str, required=True)
+@click.option("--api_key", type=str, default="EMPTY")
+@click.option("--max_turns", type=int, required=True)
+@click.option("--num_chains", type=int, default=1, required=True)
+@click.option("--data_file", type=str, required=True)
+@click.option("--output_dir", type=str, required=True)
+def main(
+ model_name_or_path: str,
+ api_key: str,
+ max_turns: int,
+ num_chains: int,
+ data_file: str,
+ output_dir: str,
+):
+ async def run_agent():
+ with open(data_file, 'r') as f:
+ messages = json.load(f)
+
+ agent = HFAgent(
+ model_name_or_path=model_name_or_path,
+ tools=[],
+ backend="client",
+ backend_config=ClientConfig(
+ base_url="http://0.0.0.0:8000/v1",
+ api_key=api_key,
+ max_new_tokens=30720,
+ timeout=2400,
+ ),
+ local_cache_dir="test_cache",
+ )
+ await agent.run(
+ messages=messages,
+ num_chains=num_chains,
+ max_turns=max_turns,
+ )
+ responses = gather_responses(agent.trajectories)
+
+ os.makedirs(output_dir, exist_ok=True)
+ with open(os.path.join(output_dir, os.path.basename(model_name_or_path) + '_responses.json'), 'w') as f:
+ json.dump(responses, f, indent=2)
+
+ asyncio.run(run_agent())
+
+if __name__ == "__main__":
+    # Example usage:
+    #   python -m agentfly.utils.trajectories --model_name_or_path Qwen/Qwen2.5-3B-Instruct --api_key EMPTY --max_turns 1 --data_file ../../datasets/viphy_test.json --output_dir data/trajectories/
+    #   python -m agentfly.utils.trajectories --model_name_or_path openai/gpt-oss-20b --api_key EMPTY --max_turns 1 --data_file ../../datasets/viphy_test.json --output_dir data/trajectories/
+    main()
diff --git a/test_qwen3_template.py b/test_qwen3_template.py
deleted file mode 100644
index 172228c..0000000
--- a/test_qwen3_template.py
+++ /dev/null
@@ -1,83 +0,0 @@
-#!/usr/bin/env python3
-"""
-Test script for Qwen3Template implementation
-"""
-
-import sys
-import os
-sys.path.append(os.path.dirname(os.path.abspath(__file__)))
-
-from agentfly.templates.templates import Qwen3Template
-
-def test_qwen3_template():
- """Test the Qwen3Template with various scenarios"""
-
- # Create a Qwen3Template instance
- template = Qwen3Template(
- name="qwen3-test",
- system_template="<|im_start|>system\n{system_message}<|im_end|>\n",
- user_template="<|im_start|>user\n{content}<|im_end|>\n",
- assistant_template="<|im_start|>assistant\n{content}<|im_end|>\n",
- stop_words=["<|im_end|>"],
- generation_prompt="<|im_start|>assistant\n",
- )
-
- # Test case 1: Basic conversation without thinking
- print("=== Test Case 1: Basic conversation without thinking ===")
- messages1 = [
- {"role": "user", "content": "Hello, how are you?"},
- {"role": "assistant", "content": "I'm doing well, thank you!"}
- ]
-
- prompt1, elements1, roles1 = template.render(messages1, add_generation_prompt=False, enable_thinking=False)
- print("Prompt:")
- print(prompt1)
- print()
-
- # Test case 2: Conversation with thinking content that should be cleaned
- print("=== Test Case 2: Conversation with thinking content (should be cleaned) ===")
- messages2 = [
- {"role": "user", "content": "What is 2+2?"},
-        {"role": "assistant", "content": "<think>I need to add 2 and 2 together. This is basic arithmetic.</think>The answer is 4."}
- ]
-
- prompt2, elements2, roles2 = template.render(messages2, add_generation_prompt=False, enable_thinking=False)
- print("Prompt (thinking content should be removed):")
- print(prompt2)
- print()
-
- # Test case 3: With add_generation_prompt=True and enable_thinking=False
- print("=== Test Case 3: With generation prompt and enable_thinking=False ===")
- messages3 = [
- {"role": "user", "content": "Tell me a joke"}
- ]
-
- prompt3, elements3, roles3 = template.render(messages3, add_generation_prompt=True, enable_thinking=False)
- print("Prompt (should include empty think tokens):")
- print(prompt3)
- print()
-
- # Test case 4: With add_generation_prompt=True and enable_thinking=True
- print("=== Test Case 4: With generation prompt and enable_thinking=True ===")
- prompt4, elements4, roles4 = template.render(messages3, add_generation_prompt=True, enable_thinking=True)
- print("Prompt (should NOT include empty think tokens):")
- print(prompt4)
- print()
-
- # Test case 5: Last message is assistant with enable_thinking=False
- print("=== Test Case 5: Last message is assistant with enable_thinking=False ===")
- messages5 = [
- {"role": "user", "content": "What's the weather like?"},
- {"role": "assistant", "content": "I don't have access to current weather data."}
- ]
-
- prompt5, elements5, roles5 = template.render(messages5, add_generation_prompt=False, enable_thinking=False)
- print("Prompt (last assistant message should have empty think tokens):")
- print(prompt5)
- print()
-
- print("All tests completed!")
-
-if __name__ == "__main__":
- test_qwen3_template()
-