diff --git a/.gitignore b/.gitignore
index e5e071d..8d07fdc 100644
--- a/.gitignore
+++ b/.gitignore
@@ -121,8 +121,10 @@ tests/e2e/toy_examples/deepspeed/synchronous/output.txt
 *.lock
 
 # data
+data/
 *.parquet
 agentfly/agents/data/*
+test_cache/
 
 # local logs
 logs
@@ -133,6 +135,10 @@ data/
 test_cache/
 /*.jpg
 /*.png
+slurm/
+*.err
+*.out
+*.log
 
 # Notebooks
 agentfly/tests/*.ipynb
@@ -146,3 +152,7 @@ test_outputs/
 agentfly/data/
 *.ipynb
 
+# training scripts
+training_scripts/
+verl/training_scripts/
+
diff --git a/agentfly/agents/agent_base.py b/agentfly/agents/agent_base.py
index 2810673..9df34c1 100644
--- a/agentfly/agents/agent_base.py
+++ b/agentfly/agents/agent_base.py
@@ -1,9 +1,10 @@
 from abc import ABC, abstractmethod
 from collections import defaultdict
+from datetime import datetime
 import json
 from .utils.messages import MessagesList
 from ..templates.templates import get_template
-from ..__init__ import AGENT_DATA_DIR
+from .. import AGENT_DATA_DIR
 from .llm_backends import (
     AsyncVLLMBackend,
     AsyncVerlBackend,
@@ -23,6 +24,7 @@ import logging
 
 from .chain.streaming_observer import ConsoleStreamObserver, StreamingManager
 from .utils.tokenizer import create_processor, create_tokenizer
+from ..utils.monitor import JsonlSink, Monitor, WandbSink
 try:
     from verl.protocol import DataProto
 except ImportError:
@@ -51,10 +53,12 @@ def __init__(
         backend_config: Any = None,
         reward_fn: Callable = None,
         log_file: str = "agent",
-        project_name: str = None,
-        run_name: str = None,
         streaming: str = "console",
         debug: bool = False,
+        monitors: List[str] = [],
+        wandb_project_name: str = None,
+        wandb_run_name: str = None,
+        local_cache_dir: str = None,
         **kwargs # To pass other unused arguments
     ):
         """
@@ -94,7 +98,6 @@ def __init__(
 
         # Create appropriate tokenizer for trajectory processing
         self.tokenizer = create_tokenizer(model_name_or_path)
-        self.processor = create_processor(model_name_or_path)
 
         self._reward_fn = reward_fn
 
@@ -104,8 +107,12 @@ def __init__(
         else:
             self.jinja_template = get_template(self.template).jinja_template()
 
-        self.project_name = project_name
-        self.run_name = run_name
+        self.wandb_project_name = wandb_project_name
+        self.wandb_run_name = wandb_run_name
+        self.local_cache_dir = local_cache_dir
+        self.local_run_cache_dir = None
+        self._initialize_monitor(monitors)
+
         self.streaming_manager = StreamingManager()
         if streaming == "console":
             self.streaming_manager.add_observer(ConsoleStreamObserver())
@@ -177,6 +184,17 @@ def _preprocess_messages(self, messages: List[Dict]):
 
         return messages_list.to_list()
 
+    def _initialize_monitor(self, monitors: List[str]) -> None:
+        for monitor in monitors:
+            if monitor == "local":
+                assert self.local_cache_dir is not None, "local_cache_dir must be set when using local monitor."
+                self.local_run_cache_dir = f"{os.path.join(self.local_cache_dir, os.path.basename(self.model_name_or_path), datetime.now().strftime('%Y%m%d_%H%M%S'))}"
+                Monitor.add_sink("jsonl", JsonlSink(f"{self.local_run_cache_dir}/"))
+            elif monitor == "wandb":
+                Monitor.add_sink("wandb", WandbSink(project=self.wandb_project_name, run_name=self.wandb_run_name))
+            else:
+                raise ValueError(f"Monitor {monitor} is not supported.")
+
     async def run(self,
             messages: Union[List[dict], np.ndarray, Dict],
             max_turns: int,
@@ -392,4 +410,4 @@ def get_verl_data_proto(self):
         batch = DataProto.from_single_dict(inputs, meta_info={"use_agent": True})
 
         return batch
-    
\ No newline at end of file
+    
diff --git a/agentfly/agents/chain/chain_base.py b/agentfly/agents/chain/chain_base.py
index c610c23..8ac5053 100644
--- a/agentfly/agents/chain/chain_base.py
+++ b/agentfly/agents/chain/chain_base.py
@@ -137,7 +137,6 @@ def __init__(self):
         self.terminal_status = ["terminal", "finish"]
         self.global_step = 0
         self.finished_chains_count = 0
-        self.initialize_monitor()
         self.monitor_info = defaultdict(list)
 
     def reset(self) -> None:
@@ -333,7 +332,7 @@ async def _run_single_chain,
             await done_queue.put((chain_id, chain, current_node))
 
         self.finished_chains_count += 1
-        self.monitor_chain()
+        self.monitor_chain(trajectory=current_node.messages.messages)
 
     async def _generate_response(self, current_node, tools, depth, chain_id, enable_streaming):
         """Generate response with optional streaming support."""
@@ -485,7 +484,6 @@ async def _finalize_chain(self, chain_id, chain, current_node, depth):
 
         await self.release_resources(chain_id)
 
-
     async def release_resources(self, id: str) -> None:
         for tool in self.tools:
             if isinstance(tool, Tool):
@@ -498,10 +496,6 @@ async def set_tools(self, id: str, env_args: Dict[str, Any]) -> None:
             if isinstance(tool, Tool):
                 await tool.set_env(id, env_args)
 
-    def initialize_monitor(self) -> None:
-        Monitor.add_sink("jsonl", JsonlSink(f"{AGENT_DATA_DIR}/demo_metrics.jsonl"))
-        Monitor.add_sink("wandb", WandbSink(project=self.project_name, run_name=self.run_name))
-
     def monitor_step(self) -> None:
         messages = self.get_messages()
         avg_turns = 0
@@ -589,9 +583,19 @@ def monitor_step(self) -> None:
 
         emit(evt)
 
 
-    def monitor_chain(self) -> None:
+    def monitor_chain(self, trajectory) -> None:
         self.monitor_info['Agent/chains'].append(self.finished_chains_count)
         for tool in self.tools:
             if tool.is_stateful and tool.pool_size > 0:
                 self.monitor_info[f"Agent/Tool/{tool.name}/used_env_size"].append(tool.used_env_size)
+        # We only log the trajectory to the local jsonl file; logging it to wandb would take too much bandwidth
+        evt = MetricEvent(
+            sinks=["jsonl"],
+            kind="text",
+            name="Agent/rollout/trajectory",
+            value=json.dumps(serialize_for_json(trajectory), indent=2),
+            x=self.global_step,
+            x_name="Agent/rollout/step"
+        )
+        emit(evt)
diff --git a/agentfly/agents/chain/chain_base_simplified.py b/agentfly/agents/chain/chain_base_simplified.py
new file mode 100644
index 0000000..533fb64
--- /dev/null
+++ b/agentfly/agents/chain/chain_base_simplified.py
@@ -0,0 +1,635 @@
+import asyncio
+from collections import defaultdict
+from dataclasses import dataclass, field
+import json
+import time
+from ..utils.messages import MessagesList, Messages
+from ...utils.timing import Timer
+from typing import Any, Dict, List, Optional, Tuple, Union, Callable
+import uuid
+from termcolor import colored
+import numpy as np
+from copy import deepcopy
+from ...tools.tool_base import Tool, submit_tool_call, submit_tool_calls
+from tqdm.asyncio import tqdm_asyncio
+from ...utils.monitor import JsonlSink, MetricEvent, Monitor, WandbSink, emit, serialize_for_json +from ... import AGENT_DATA_DIR +import wandb +from .streaming_observer import ConsoleStreamObserver, StreamingManager, StreamEvent, StreamEventType + +@dataclass +class SimplifiedNode: + """ + Simplified node that aligns with message types. + Each node represents a single message turn and only stores its own content. + """ + role: str # "user", "assistant", "tool", "system" + content: Any # The actual message content + is_terminal: bool = False + is_pruned: bool = False + description: str = "" + observation: str = "" # For tool nodes, stores the tool result + tool_name: Optional[str] = None # For tool nodes + tool_call_id: Optional[str] = None # For tool nodes + parent: Optional["SimplifiedNode"] = None + children: List["SimplifiedNode"] = field(default_factory=list) + + @property + def depth(self) -> int: + return 0 if self.parent is None else self.parent.depth + 1 + + def print_node(self, process_id: int = 0) -> None: + if process_id != 0: + return + color_converter = { + "user": "green", + "assistant": "blue", + "tool": "yellow", + "system": "magenta" + } + color = color_converter.get(self.role, "white") + print(colored(f"{self.role.upper()}: {self.description}", color=color)) + if self.observation: + obs = ( + self.observation + if len(self.observation) < 1536 + else f"{self.observation[:1536]}...(len={len(self.observation)})" + ) + print(colored(f"Observation: {obs}", color="yellow")) + + def to_json(self) -> dict: + json_obj = { + "role": self.role, + "is_terminal": self.is_terminal, + "is_pruned": self.is_pruned, + "depth": self.depth, + "description": self.description, + "content": self.content + } + if self.observation: + json_obj["observation"] = self.observation + if self.tool_name: + json_obj["tool_name"] = self.tool_name + if self.tool_call_id: + json_obj["tool_call_id"] = self.tool_call_id + return json_obj + + def to_json_recursive(self) -> dict: + data = self.to_json() + data["children"] = [child.to_json_recursive() for child in self.children] + return data + + +class SimplifiedChain: + """ + Simplified chain that stores nodes aligned with message types. + Each node represents a single message turn. + """ + def __init__(self, info): + self.root: Optional[SimplifiedNode] = None + self.info: Dict[str, Any] = info + self.system_message: Optional[SimplifiedNode] = None + + def add_node( + self, + role: str, + content: Any, + is_terminal: bool = False, + is_pruned: bool = False, + description: str = "", + observation: str = "", + tool_name: Optional[str] = None, + tool_call_id: Optional[str] = None + ) -> SimplifiedNode: + new_node = SimplifiedNode( + role=role, + content=content, + is_terminal=is_terminal, + is_pruned=is_pruned, + description=description, + observation=observation, + tool_name=tool_name, + tool_call_id=tool_call_id + ) + + if self.root is None: + self.root = new_node + else: + current = self.root + while len(current.children) > 0: + current = current.children[0] + current.children = [new_node] + new_node.parent = current + return new_node + + def get_full_messages(self) -> List[Dict[str, Any]]: + """ + Reconstruct the full message history from the chain nodes. 
+ """ + messages = [] + node = self.root + + # Add system message if it exists + if self.system_message: + messages.append({ + "role": "system", + "content": self.system_message.content + }) + + # Traverse the chain and collect messages + while node: + if node.role == "tool": + # For tool messages, we need to reconstruct the tool message format + messages.append({ + "role": "tool", + "tool_call_id": node.tool_call_id, + "tool_name": node.tool_name, + "content": [{"type": "text", "text": node.observation}] + }) + else: + # For user and assistant messages + messages.append({ + "role": node.role, + "content": node.content + }) + + if node.children: + node = node.children[0] + else: + break + + return messages + + def to_json(self) -> List[dict]: + chain_json = [] + node = self.root + while node: + chain_json.append(node.to_json()) + if node.children: + node = node.children[0] + else: + break + return chain_json + + +class SimplifiedChainRollout: + """ + Simplified chain-based rollout that uses message-aligned nodes. + """ + def __init__(self): + self.reset() + self.chains: Dict[str, SimplifiedChain] = {} + self.current_nodes: Dict[str, SimplifiedNode] = {} + self.timer = Timer() + self.terminal_status = ["terminal", "finish"] + self.global_step = 0 + self.finished_chains_count = 0 + self.monitor_info = defaultdict(list) + + def reset(self) -> None: + self.status_code: str = "continue" + self.query_count: int = 0 + self.total_tokens: int = 0 + self.success_count: int = 0 + self.chains = [] + self.current_nodes = {} + + @property + def timing_data(self): + return self.timer.timing_data + + def to_json(self) -> dict: + return { + "finish": [chain.status_code == "success" for chain in self.chains], + "chains": [chain.to_json() for chain in self.chains] + } + + def initialize_chains(self, messages_list: MessagesList, num_chains: int) -> Tuple[Dict[str, SimplifiedChain], Dict[str, SimplifiedNode]]: + chains = {} + start_nodes = {} + group_ids = [str(uuid.uuid4()) for _ in range(len(messages_list))] + + for group_idx, messages in enumerate(messages_list): + group_id = group_ids[group_idx] + for j in range(num_chains): + ch = SimplifiedChain(messages.meta | {"group_id": group_id}) + + # Extract system message if present + if messages.messages and messages.messages[0]["role"] == "system": + system_content = messages.messages[0]["content"] + ch.system_message = SimplifiedNode( + role="system", + content=system_content, + description="System prompt" + ) + user_messages = messages.messages[1:] + else: + user_messages = messages.messages + + # Create user node with the initial user message(s) + user_content = user_messages[0]["content"] if user_messages else "" + root = ch.add_node( + role="user", + content=user_content, + description="Initial user input" + ) + + cid = str(uuid.uuid4()) + chains[cid] = ch + start_nodes[cid] = root + + return chains, start_nodes + + def get_messages(self) -> List[Any]: + messages = [] + for id, node in self.current_nodes.items(): + info = self.chains[id].info + message_item = {} + message_item["messages"] = self.chains[id].get_full_messages() + message_item.update(info) + messages.append(message_item) + return messages + + def validate_run_args(self, max_turns: int, num_chains: int, enable_streaming: bool): + assert max_turns >= 1, "max_turns must be at least 1." + assert num_chains >= 1, "num_chains must be at least 1." 
+ for observer in self.streaming_manager.observers: + if isinstance(observer, ConsoleStreamObserver) and enable_streaming: + assert num_chains == 1, "num_chains must be 1 when ConsoleStreamObserver is used." + + + async def run_async(self, + messages: List[Dict], + max_turns: int, + num_chains: int, + generation_config: Optional[Dict[str, Any]] = None, + enable_streaming: bool = False, + streaming_callback: Optional[Callable] = None, + ): + """ + Run the simplified chain-based rollout. + """ + self.validate_run_args(max_turns, num_chains, enable_streaming) + Monitor.ensure_started() + self.reset() + + messages_list = MessagesList.from_data(messages) + chains, first_nodes = self.initialize_chains( + messages_list, + num_chains + ) + tool_schemas = [tool.schema for tool in self.tools] + + done_q = asyncio.Queue() + tasks = [ + asyncio.create_task( + self._run_single_chain( + cid, + node, + chains[cid], + tool_schemas, + max_turns=max_turns, + done_queue=done_q, + enable_streaming=enable_streaming + ) + ) + for cid, node in first_nodes.items() + ] + + await tqdm_asyncio.gather(*tasks) + + self.chains = {} + while not done_q.empty(): + cid, chain, node = done_q.get_nowait() + self.chains[cid] = chain + self.current_nodes[cid] = node + + self.global_step += 1 + self.monitor_step() + + async def _run_single_chain(self, + chain_id: str, + first_node: SimplifiedNode, + chain: SimplifiedChain, + tools: List[Dict], + max_turns: int, + done_queue: asyncio.Queue, + enable_streaming: bool = False + ): + """ + Run a single simplified chain. + """ + current_node = first_node + depth = 0 + have_set_tools = False + + while not current_node.is_terminal and depth < max_turns: + # Get current message history for generation + current_messages = chain.get_full_messages() + + if not current_node.is_terminal: + # Generate assistant response + assistant_msg = await self._generate_response( + current_messages, tools, depth, chain_id, enable_streaming + ) + + # Create assistant node + assistant_node = chain.add_node( + role="assistant", + content=assistant_msg.get("content", ""), + description=assistant_msg.get("content", "")[:100] + "..." 
if len(assistant_msg.get("content", "")) > 100 else assistant_msg.get("content", ""), + is_terminal=assistant_msg.get("status", "continue") in self.terminal_status + ) + current_node = assistant_node + + # Check if the assistant node is terminal + if current_node.is_terminal: + break + + # Handle tool calls + if assistant_msg.get("tool_calls"): + for tool_call in assistant_msg["tool_calls"]: + result = await self._execute_tool_call( + tool_call, chain, chain_id, depth, + have_set_tools, enable_streaming + ) + have_set_tools = True + + # Create tool node + tool_node = chain.add_node( + role="tool", + content=result.get("arguments", ""), + description=f"Tool: {result.get('name', 'unknown')}", + observation=result["observation"], + tool_name=result.get("name"), + tool_call_id=tool_call["id"], + is_terminal=result["status"] in self.terminal_status + ) + current_node = tool_node + else: + # No tool calls, chain is finished + break + + depth += 1 + + # Finalize chain + await self._finalize_chain(chain_id, chain, current_node, depth) + await done_queue.put((chain_id, chain, current_node)) + + self.finished_chains_count += 1 + self.monitor_chain(trajectory=chain.get_full_messages()) + + async def _generate_response(self, current_messages, tools, depth, chain_id, enable_streaming): + """Generate response with optional streaming support.""" + if enable_streaming: + # Emit generation start event + await self.streaming_manager.emit_event(StreamEvent( + event_type=StreamEventType.LLM_GENERATION_START, + chain_id=chain_id, + timestamp=time.time(), + data={"depth": depth}, + step=depth, + depth=depth + )) + + # Check if we have streaming capabilities + has_streaming = False + if hasattr(self, 'generate_streaming'): + has_streaming = True + elif hasattr(self, 'llm_engine') and hasattr(self.llm_engine, 'generate_streaming'): + has_streaming = True + # Create a wrapper to use the LLM engine's streaming + async def generate_streaming_wrapper(messages_list, **kwargs): + async for chunk in self.llm_engine.generate_streaming(messages_list, **kwargs): + yield chunk + self.generate_streaming = generate_streaming_wrapper + + if has_streaming: + # Collect full response from streaming + full_response = "" + async for chunk in self.generate_streaming([current_messages], tools=tools): + await self.streaming_manager.emit_event(StreamEvent( + event_type=StreamEventType.LLM_GENERATION_CHUNK, + chain_id=chain_id, + timestamp=time.time(), + data={"content": chunk}, + step=depth, + depth=depth + )) + full_response = chunk + + # Emit generation end event + await self.streaming_manager.emit_event(StreamEvent( + event_type=StreamEventType.LLM_GENERATION_END, + chain_id=chain_id, + timestamp=time.time(), + data={"full_response": full_response}, + step=depth, + depth=depth + )) + + # Parse response + new_msg = self.parse([full_response], tools=self.tools) + return new_msg[0] + else: + # Fallback to non-streaming generation + responses = await self.generate_async([current_messages], tools=tools, num_return_sequences=1) + new_msg = self.parse(responses, tools=self.tools) + + # Emit a single chunk event for the full response + full_response = new_msg[0].get("content", "") + if isinstance(full_response, list) and len(full_response) > 0: + if isinstance(full_response[0], dict) and "text" in full_response[0]: + full_response = full_response[0]["text"] + else: + full_response = str(full_response) + elif not isinstance(full_response, str): + full_response = str(full_response) + + await self.streaming_manager.emit_event(StreamEvent( + 
event_type=StreamEventType.LLM_GENERATION_CHUNK, + chain_id=chain_id, + timestamp=time.time(), + data={"content": full_response}, + step=depth, + depth=depth + )) + + # Emit generation end event + await self.streaming_manager.emit_event(StreamEvent( + event_type=StreamEventType.LLM_GENERATION_END, + chain_id=chain_id, + timestamp=time.time(), + data={"full_response": full_response}, + step=depth, + depth=depth + )) + + return new_msg[0] + else: + # Non-streaming generation + responses = await self.generate_async([current_messages], tools=tools, num_return_sequences=1) + new_msg = self.parse(responses, tools=self.tools) + return new_msg[0] + + async def _execute_tool_call(self, tool_call, chain, chain_id, depth, have_set_tools, enable_streaming): + """Execute a tool call with optional streaming support.""" + tool_name = tool_call["function"]["name"] + tool_input = tool_call["function"]["arguments"] + + # Set up tools if needed + if not have_set_tools: + await self.set_tools(chain_id, chain.info) + have_set_tools = True + + # Execute tool call + result = await submit_tool_call( + tool_name, + tool_input, + id=chain_id, + allowed_tool_names=self.tool_names + ) + + if enable_streaming: + # Emit tool observation event + await self.streaming_manager.emit_event(StreamEvent( + event_type=StreamEventType.TOOL_OBSERVATION, + chain_id=chain_id, + timestamp=time.time(), + data={ + "tool_name": tool_name, + "observation": result["observation"], + "status": result["status"] + }, + step=depth, + depth=depth + )) + + return result + + + async def _finalize_chain(self, chain_id, chain, current_node, depth): + """Finalize the chain with reward calculation and cleanup.""" + if self._reward_fn is not None: + trajectory = chain.get_full_messages() + final_response = self.extract_final_response(trajectory) + other_args = {k: v for k, v in chain.info.items() if k not in ['prediction', 'trajectory', 'id']} + chain.info["reward"] = await self._reward_fn(prediction=final_response, **other_args, trajectory=trajectory, id=chain_id) + else: + chain.info["reward"] = None + + await self.release_resources(chain_id) + + async def release_resources(self, id: str) -> None: + for tool in self.tools: + if isinstance(tool, Tool): + await tool.release(id=id) + if self._reward_fn is not None: + await self._reward_fn.release(id=id) + + async def set_tools(self, id: str, env_args: Dict[str, Any]) -> None: + for tool in self.tools: + if isinstance(tool, Tool): + await tool.set_env(id, env_args) + + def monitor_step(self) -> None: + messages = self.get_messages() + avg_turns = 0 + avg_tool_calls = 0 + avg_response_length = 0 + tool_calls_by_name = defaultdict(int) + + for message in messages: + for msg in message['messages']: + if msg['role'] == 'assistant': + avg_turns += 1 + if msg['role'] == 'tool': + avg_tool_calls += 1 + tool_call_name = msg['tool_name'] + tool_calls_by_name[tool_call_name] += 1 + + avg_turns /= len(messages) + avg_tool_calls /= len(messages) + + ent = MetricEvent( + kind="scalar", + name=f"Agent/rollout/step", + value=self.global_step, + x=self.global_step, + x_name="Agent/rollout/step" + ) + emit(ent) + + evt = MetricEvent( + kind="scalar", + name=f"Agent/rollout/avg_turns", + value=avg_turns, + x=self.global_step, + x_name="Agent/rollout/step" + ) + emit(evt) + + evt = MetricEvent( + kind="scalar", + name=f"Agent/rollout/avg_tool_calls", + value=avg_tool_calls, + x=self.global_step, + x_name="Agent/rollout/step" + ) + emit(evt) + + + for tool_name, tool_call_count in tool_calls_by_name.items(): + evt = 
MetricEvent( + kind="scalar", + name=f"Agent/rollout/tool_calls/{tool_name}", + value=tool_call_count, + x=self.global_step, + x_name="Agent/rollout/step" + ) + emit(evt) + + evt = MetricEvent( + kind="scalar", + name=f"Agent/rollout/step", + value=self.global_step, + x=self.global_step, + x_name="Agent/rollout/step" + ) + emit(evt) + + sample_message_json = json.dumps(serialize_for_json(messages[0]), indent=2) + evt = MetricEvent( + kind="text", + name="Agent/rollout/sample_message", + value=sample_message_json, + x=self.global_step, + x_name="Agent/rollout/step" + ) + emit(evt) + + for k, v in self.monitor_info.items(): + if k != "Agent/chains": # We don't log number of chains + evt = MetricEvent( + kind="list", + name=k, + value=v, + x=self.monitor_info['Agent/chains'], + ) + emit(evt) + + + def monitor_chain(self, trajectory) -> None: + self.monitor_info['Agent/chains'].append(self.finished_chains_count) + for tool in self.tools: + if tool.is_stateful and tool.pool_size > 0: + self.monitor_info[f"Agent/Tool/{tool.name}/used_env_size"].append(tool.used_env_size) + + evt = MetricEvent( + kind="text", + name="Agent/rollout/trajectory", + value=json.dumps(serialize_for_json(trajectory), indent=2), + x=self.global_step, + x_name="Agent/rollout/step" + ) + emit(evt) diff --git a/agentfly/agents/llm_backends/llm_backends.py b/agentfly/agents/llm_backends/llm_backends.py index 7ab9a4b..df62335 100644 --- a/agentfly/agents/llm_backends/llm_backends.py +++ b/agentfly/agents/llm_backends/llm_backends.py @@ -21,8 +21,7 @@ import logging import PIL - -LOGGER = logging.getLogger(__name__) +logger = logging.getLogger(__name__) try: from verl.protocol import DataProto @@ -340,13 +339,13 @@ async def generate_async(self, messages_list: str, **kwargs) -> str: inputs = self._process_inputs(prompts, vision_inputs) if n > 1: inputs = [_input for _input in inputs for _ in range(n)] - LOGGER.debug(f"[AsyncVLLMBackend] inputs: {inputs}") + logger.debug(f"[AsyncVLLMBackend] inputs: {inputs}") tasks = [self._generate_single(_input, sampling_params) for _input in inputs] outputs = await asyncio.gather(*tasks) # Flatten the outputs outputs = [output for output_list in outputs for output in output_list] response_texts = [output.text for output in outputs] - LOGGER.debug(f"[AsyncVLLMBackend] response_texts: {response_texts}") + logger.debug(f"[AsyncVLLMBackend] response_texts: {response_texts}") return response_texts @@ -513,7 +512,7 @@ def __init__( # --------------------------------------------------------------------- # # Low‑level single request (runs in threadpool so it doesn't block loop) # --------------------------------------------------------------------- # - @retry(stop=stop_after_attempt(1), wait=wait_exponential(multiplier=1, min=4, max=15)) + @retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=4, max=15)) def _blocking_call(self, messages: List[List[Dict]], **kwargs) -> str: if "num_return_sequences" in kwargs: n = kwargs.pop("num_return_sequences") @@ -525,7 +524,6 @@ def _blocking_call(self, messages: List[List[Dict]], **kwargs) -> str: else: tool_choice = "none" - print(f"[ClientBackend] messages: {messages}") resp = self.client.chat.completions.create( model=self.model_name, messages=messages, @@ -570,7 +568,7 @@ def _convert_to_openai_chat_without_tool_call_processing(self, messages: list) - return messages # Public API ‑‑ sync or async depending on caller's context - def async_generate( + def generate( self, messages: List[List[Dict]] | List[Dict], **kwargs, @@ -588,10 
+586,12 @@ def async_generate( messages_list = [messages] # single else: messages_list = messages # batch - print(f"[ClientBackend] messages_list: {messages_list}") + logger.debug(f"[ClientBackend] messages_list: {messages_list}") messages_list = [self._convert_to_openai_chat_without_tool_call_processing(messages) for messages in messages_list] async def _runner(): + # Ensure refiller is running in this event loop + self._ensure_refiller_running() tasks = [asyncio.create_task(self._call(_input, **kwargs)) for _input in messages_list] # Flatten the response list response_texts_list_or_dict = await asyncio.gather(*tasks) @@ -621,7 +621,7 @@ async def _runner(): async def generate_async(self, messages: List[List[Dict]] | List[Dict], **kwargs) -> List[str]: - return await self.async_generate(messages, **kwargs) + return await self.generate(messages, **kwargs) # Background token‑bucket refill (one token each 60/max_rpm seconds) async def _refill_tokens(self): @@ -633,11 +633,11 @@ async def _refill_tokens(self): def _ensure_refiller_running(self): if self._refill_task is None or self._refill_task.done(): - loop = asyncio.get_event_loop() - self._refill_task = loop.create_task(self._refill_tokens()) - - # Automatically start the refiller at first public use - def __getattribute__(self, name): - if name == "generate": - self._ensure_refiller_running() - return super().__getattribute__(name) + try: + # Try to get running loop first + loop = asyncio.get_running_loop() + self._refill_task = loop.create_task(self._refill_tokens()) + except RuntimeError: + # No event loop running, this will be handled by the caller + # The refiller will be started when we're in an event loop + pass diff --git a/agentfly/agents/specialized/hf_agent.py b/agentfly/agents/specialized/hf_agent.py index df75ba5..eee913f 100644 --- a/agentfly/agents/specialized/hf_agent.py +++ b/agentfly/agents/specialized/hf_agent.py @@ -1,6 +1,7 @@ from ast import Dict import json +import os from typing import List from ..agent_base import BaseAgent from ..parsers import extract_tool_calls diff --git a/agentfly/rewards/__init__.py b/agentfly/rewards/__init__.py index 09ba592..ee2e107 100644 --- a/agentfly/rewards/__init__.py +++ b/agentfly/rewards/__init__.py @@ -17,4 +17,6 @@ from .alfworld_reward import alfworld_episode_reward from .scienceworld_reward import scienceworld_reward from .gui_reward import gui_reward +from .vlm_as_judge.vlm_as_judge_reward import vlm_as_judge_reward +from .vlm_as_judge.vlm_as_judge_reward import vlm_as_judge_pass_reward diff --git a/agentfly/rewards/llm_as_judge/llm_as_judge_client.py b/agentfly/rewards/llm_as_judge/llm_as_judge_client.py index a33078e..895cf62 100644 --- a/agentfly/rewards/llm_as_judge/llm_as_judge_client.py +++ b/agentfly/rewards/llm_as_judge/llm_as_judge_client.py @@ -415,4 +415,3 @@ async def llm_as_judge_client_math_reward(prediction: str, answer: str) -> float "reward": 0.0, } - diff --git a/agentfly/tests/unit/agents/templates/__init__.py b/agentfly/rewards/vlm_as_judge/__init__.py similarity index 100% rename from agentfly/tests/unit/agents/templates/__init__.py rename to agentfly/rewards/vlm_as_judge/__init__.py diff --git a/agentfly/rewards/vlm_as_judge/vlm_as_judge_client.py b/agentfly/rewards/vlm_as_judge/vlm_as_judge_client.py new file mode 100644 index 0000000..79fd483 --- /dev/null +++ b/agentfly/rewards/vlm_as_judge/vlm_as_judge_client.py @@ -0,0 +1,478 @@ +import os +import re +import json +import glob +import time +import asyncio +import logging +from typing import Any, 
Dict, List, Optional + +from pathlib import Path +from openai import AsyncOpenAI +from tqdm.asyncio import tqdm_asyncio + +from ... import AGENT_HOME + +logger = logging.getLogger(__name__) + + +def _resolve_server_status_dir() -> str: + """Resolve the directory that contains vLLM server status JSON files. + + Priority: + 1) env `VLLM_SERVER_STATUS_DIR` + 2) env `DATA_PROCESS_HOME` + /vllm_server/server_status + 3) AGENT_HOME + /data-process/vllm_server/server_status + 4) Explicit fallback to /mnt/weka/... path as provided by user + """ + # 1) Explicit override + override = os.getenv("VLLM_SERVER_STATUS_DIR") + if override and os.path.isdir(override): + return override + + # 2) data-process home + dp_home = os.getenv("DATA_PROCESS_HOME") + if dp_home: + candidate = os.path.join(dp_home, "vllm_server", "server_status") + if os.path.isdir(candidate): + return candidate + + # 3) default relative to AGENT_HOME + candidate = os.path.join(AGENT_HOME, "data-process", "vllm_server", "server_status") + if os.path.isdir(candidate): + return candidate + + # Last resort: return the AGENT_HOME-based path even if missing (caller errors out) + return candidate + + +def get_server_ips(model: str) -> List[str]: + """Get list of server IPs from the most recent complete server instances file for a specific model.""" + server_status_dir = _resolve_server_status_dir() + + # Clean model name for filename matching (replace / and - with _) + model_clean = model.replace('/', '_').replace('-', '_') + + # Try multiple patterns to match the server files + patterns = [ + f"server_instances_complete_{model_clean}_*.json", + f"server_instances_complete_vllm_{model_clean}_*.json", + f"server_instances_complete_*{model_clean}*.json" + ] + + json_files = [] + for pattern in patterns: + search_pattern = os.path.join(server_status_dir, pattern) + found_files = glob.glob(search_pattern) + if found_files: + json_files = found_files + break + + if not json_files: + # Fallback: try to find any server instances file and filter by model in the JSON content + fallback_pattern = os.path.join(server_status_dir, "server_instances_complete_*.json") + all_files = glob.glob(fallback_pattern) + + for file in all_files: + try: + with open(file, 'r') as f: + server_info = json.load(f) + + # Check if any server in this file matches our model + matching_servers = [info for info in server_info if info.get('model') == model] + if matching_servers: + json_files = [file] + logger.info(f"Found servers for model '{model}' in fallback file: {file}") + break + except Exception as e: + logger.warning(f"Error reading file {file}: {e}") + continue + + if not json_files: + raise RuntimeError( + f"No server instances file found for model '{model}' in {search_pattern} or any fallback file under {server_status_dir}" + ) + + # Get the most recent file + latest_file = max(json_files, key=os.path.getctime) + + with open(latest_file, 'r') as f: + server_info = json.load(f) + + # Filter servers by model and extract IPs + ips = [] + for info in server_info: + if info.get('model') == model and 'ip' in info: + ips.append(info['ip']) + + if not ips: + raise RuntimeError(f"No IPs found for model '{model}' in server instances file {latest_file}") + + logger.info(f"Found {len(ips)} server instances for model '{model}': {ips}") + return ips + + +class RateLimiter: + def __init__(self, max_window_size: int): + self.max_window_size = max_window_size + self.semaphore = asyncio.Semaphore(max_window_size) + + async def acquire(self): + await self.semaphore.acquire() + + 
async def release(self): + self.semaphore.release() + + +class RoundRobinClient: + def __init__(self, ips: List[str], port: int, api_key: str, timeout: int, rate_limiters: List[RateLimiter]): + self.ips = ips + self.current_index = 0 + self.port = port + self.api_key = api_key + self.clients = [ + AsyncOpenAI( + base_url=f"http://{ip}:{port}/v1", + api_key=api_key, + timeout=timeout, + ) for ip in ips + ] + self.rate_limiters = rate_limiters + + async def get_next_available_client(self) -> tuple[AsyncOpenAI, RateLimiter]: + # Find the instance with the most available slots + max_available = -1 + best_client = None + best_limiter = None + best_index = -1 + + for i in range(len(self.clients)): + available = self.rate_limiters[i].semaphore._value + if available > max_available and available > 0: + max_available = available + best_client = self.clients[i] + best_limiter = self.rate_limiters[i] + best_index = i + + if best_client is not None: + await best_limiter.acquire() + return best_client, best_limiter + + # If no instance has available slots, wait on all and race + wait_tasks = [(i, asyncio.create_task(limiter.semaphore.acquire())) + for i, limiter in enumerate(self.rate_limiters)] + done, pending = await asyncio.wait( + [task for _, task in wait_tasks], + return_when=asyncio.FIRST_COMPLETED + ) + + for _, task in wait_tasks: + if task not in done: + task.cancel() + + for i, task in wait_tasks: + if task in done: + return self.clients[i], self.rate_limiters[i] + + raise RuntimeError("No instance became available despite wait completion") + + +class VLMClient: + def __init__(self, + model: str, + timeout_seconds: int = 60, + max_window_size_per_instance: int = 10, + port: int = 8000, + api_key: str = "token-abc123"): + self.timeout_seconds = timeout_seconds + server_ips = get_server_ips(model) + # server_ips = ["10.24.3.24"] + rate_limiters = [RateLimiter(max_window_size_per_instance) for _ in server_ips] + self.client_manager = RoundRobinClient(server_ips, port, api_key, timeout_seconds, rate_limiters) + + async def single_call(self, inputs, model, **kwargs): + try: + client, rate_limiter = await self.client_manager.get_next_available_client() + try: + response = await client.chat.completions.create( + model=model, + messages=inputs, + timeout=self.timeout_seconds, + **kwargs + ) + return response.choices[0].message.content + finally: + await rate_limiter.release() + except asyncio.TimeoutError: + logger.error(f"Request timed out after {self.timeout_seconds}s") + return None + except Exception as e: + logger.error(f"Error processing request: {e}") + return None + + async def process_all_inputs(self, inputs_list, num_generations=1, model=None, **kwargs): + all_tasks = [] + for inputs in inputs_list: + for _ in range(num_generations): + all_tasks.append(self.single_call(inputs, model=model, **kwargs)) + + responses = await tqdm_asyncio.gather(*all_tasks, desc="Processing VLM requests") + + grouped_responses = [] + for i in range(0, len(responses), num_generations): + grouped_responses.append(responses[i:i + num_generations]) + + return grouped_responses + + def check_availability(self) -> Dict[str, Any]: + # Mirror llm client basic stats + total_capacity = 0 + total_available = 0 + total_used = 0 + instance_details = [] + + for i in range(len(self.client_manager.clients)): + available = self.client_manager.rate_limiters[i].semaphore._value + capacity = self.client_manager.rate_limiters[i].max_window_size + used = capacity - available + + total_capacity += capacity + total_available += 
available
+            total_used += used
+
+            instance_details.append({
+                'instance_id': i,
+                'ip': self.client_manager.ips[i],
+                'port': self.client_manager.port,
+                'available_slots': available,
+                'used_slots': used,
+                'total_slots': capacity,
+            })
+
+        return {
+            'total_instances': len(self.client_manager.clients),
+            'total_capacity': total_capacity,
+            'total_available': total_available,
+            'total_used': total_used,
+            'has_available_slots': total_available > 0,
+            'instances': instance_details
+        }
+
+    def is_available(self) -> bool:
+        availability = self.check_availability()
+        return availability['has_available_slots']
+
+
+DEFAULT_VLM_PROMPT_TEMPLATE = """You are given a set of visual verification questions and a description of objects and motion observed in an input medium (e.g., image or video).
+
+Your task is to **evaluate each question** based on whether it is **correctly reflected in the visual content**, considering visual cues, shape changes from viewpoint, and possible symbolic representations.
+
+
+---
+
+ **Visual Reasoning Guidelines**:
+
+1. **Perspective Awareness**:
+   Objects may appear different based on viewpoint. For example:
+   - A **cylinder** may look like a **circle (top view)** or a **rectangle/square (side view)**.
+   - A **circular path** may appear as a **wave-like curve or straight line** in 2D projection.
+
+2. **Symbolic Representations**:
+   Common simplifications may be used. You should **reasonably infer** their meaning:
+   - A series of **dots or circles** may represent **foam markers** or control points.
+   - A **rectangle** may represent a **container** (e.g., cylindrical viewed from the side).
+   - A **line** may represent a **rubber mat** or constraint boundary.
+   - The object and track specifics might not match directly; if the motion can be interpreted correctly, it is still true.
+   - Color may be used to represent different objects, such as a green line indicating that the flat surface is covered with a felt-like material.
+   - The rotation of the object may be impossible to judge from the video; if the motion can be interpreted correctly, it is still true.
+
+3. **Container Boundaries**:
+   - If **no container is drawn**, you may assume the **video frame itself is the container boundary**.
+   - If a **container is visible**, treat it as **transparent** if inner content is visible.
+   - If the object is not visible, you should not assume it is in the container.
+
+4. **Focus on Shape & Position**, **not material**:
+   - Ignore assumptions about object **material**, **color**, or **texture**.
+   - Base your decisions entirely on **observable geometry** (e.g., shape, layout, structure) and **motion** (e.g., direction, trajectory).
+   - Use visible movement and positioning to judge truthfulness — even if the object type is unknown.
+   - If the described motion is **sliding down a slope**, but the video shows an **upward movement**, the result should be `"False"` — regardless of material or appearance.
+   - Make geometric and motion-based reasoning the core of your judgment, even when objects are **partially occluded**.
+
+5. **Occlusion Handling**:
+   - If an object is **partially blocked**, assess based on surrounding evidence whether its state or motion can still be inferred.
+
+6. **Avoid excessive uncertainty**:
+   - If there is enough visual context and logical structure, make a **confident judgment**.
+   - Use "Not sure" only when the evidence is **truly insufficient or ambiguous**.
+ +--- + + **Input**: +- Questions: {all_questions} +- Object and motion description: {summarize} + +--- + + **For each question**, return: +- `"index"`: the question index +- `"question"`: the full question text +- `"analysis"`: your reasoning process and visual inference +- `"result"`: one of `"True"`, `"False"`, or `"Not sure"` +- `"confidence_score"`: an integer from 1 (very uncertain) to 5 (very certain) + +--- + +**Output Format**: +Return a JSON list like this: +[ + {{ + "index": "1", + "question": "The ball rolls along the circular path.", + "analysis": "The object follows a closed curve consistent with a circular path from the top view.", + "result": "True", + "confidence_score": "5" + }}, + ... +] +""" + + +def _format_keywords(text: str, + keywords: Optional[List[str]] = None, + style: str = "bold", + case_sensitive: bool = False) -> str: + """Format given keywords in text. + + Args: + text: The input text to process. + keywords: List of keywords to highlight/format. + style: One of "bold" (default), "bracket", or "caps". + case_sensitive: Whether to match case-sensitively. + + Returns: + Formatted text with keywords highlighted. + """ + if not keywords: + return text + + # Deduplicate and sort by length (longest first) to avoid partial overlapping matches + unique_keywords = [k for k in sorted(set(keywords), key=lambda s: len(s or ""), reverse=True) if k] + if not unique_keywords: + return text + + flags = 0 if case_sensitive else re.IGNORECASE + pattern = re.compile("|".join(re.escape(k) for k in unique_keywords), flags) + + def repl(match: re.Match) -> str: + start, end = match.start(), match.end() + # Avoid double-formatting if already bolded ("**word**") + if style == "bold": + prev2 = text[max(0, start - 2):start] + next2 = text[end:end + 2] + if prev2 == "**" and next2 == "**": + return match.group(0) + return f"**{match.group(0)}**" + elif style == "bracket": + return f"[{match.group(0)}]" + elif style == "caps": + return match.group(0).upper() + else: + return f"**{match.group(0)}**" + + return pattern.sub(repl, text) + + +class _SafeDict(dict): + def __missing__(self, key): + return "{" + key + "}" + + +def create_vlm_prompt_from_template( + prompt_template: str, + variables: Optional[Dict[str, Any]] = None, + keywords: Optional[List[str]] = None, + style: str = "bold", + case_sensitive: bool = False, +) -> str: + """Create a VLM prompt from a template with optional keyword formatting. + + Args: + prompt_template: Template string containing placeholders like {summarize}, {all_questions}. + variables: Mapping used to format the template. + keywords: Keywords to highlight after formatting. + style: Keyword highlight style: "bold" (default), "bracket", or "caps". + case_sensitive: Whether keyword matching is case-sensitive. + + Returns: + Final prompt string. + """ + text = prompt_template + if variables: + try: + text = prompt_template.format_map(_SafeDict(variables)) + except Exception: + # Fall back to the original template if formatting fails + text = prompt_template + return _format_keywords(text, keywords=keywords, style=style, case_sensitive=case_sensitive) + + +def create_vlm_prompt_custom( + prompt: str, + keywords: Optional[List[str]] = None, + style: str = "bold", + case_sensitive: bool = False, +) -> str: + """Create a VLM prompt from a raw prompt string and optional keywords. + + This function does not inject any default template; it only formats + the provided prompt and highlights keywords as requested. 
+ + Args: + prompt: The prompt content to send to the VLM. + keywords: List of keywords to format/highlight. + style: Highlight style: "bold" (default), "bracket", or "caps". + case_sensitive: Whether keyword matching is case-sensitive. + + Returns: + Final prompt string. + """ + return _format_keywords(prompt, keywords=keywords, style=style, case_sensitive=case_sensitive) + + +def create_vlm_prompt(summarize: str, all_questions: str) -> str: + """Create the default VLM prompt (backward-compatible). + + Existing call sites expect (summarize, all_questions). This delegates to + a template-based builder to enable future customization. + """ + return create_vlm_prompt_from_template( + DEFAULT_VLM_PROMPT_TEMPLATE, + variables={"summarize": summarize, "all_questions": all_questions}, + ) + + +def _extract_json_list(output_str: str) -> List[Dict[str, Any]]: + """Extract the VLM JSON list from the model output. + + Tries strict JSON first; falls back to extracting the substring between + the first '[' and the last ']' and parsing that. + """ + try: + parsed = json.loads(output_str) + if isinstance(parsed, list): + return parsed + except Exception: + pass + + # Fallback: try to extract JSON list portion + start = output_str.find('[') + end = output_str.rfind(']') + if start != -1 and end != -1 and end > start: + try: + parsed = json.loads(output_str[start:end+1]) + if isinstance(parsed, list): + return parsed + except Exception: + pass + raise ValueError("Failed to parse VLM JSON list from output") + diff --git a/agentfly/rewards/vlm_as_judge/vlm_as_judge_reward.py b/agentfly/rewards/vlm_as_judge/vlm_as_judge_reward.py new file mode 100644 index 0000000..572fba7 --- /dev/null +++ b/agentfly/rewards/vlm_as_judge/vlm_as_judge_reward.py @@ -0,0 +1,633 @@ +"""VLM as Judge Reward Function for AgentFly RL Training""" + +import os +import re +import json +import uuid +import tempfile +import subprocess +import asyncio +import logging +import concurrent.futures +from typing import Dict, Any, Optional, List, Tuple +from pathlib import Path + +# Support running both as a package module and as a standalone script +try: + from ..reward_base import reward + from .vlm_as_judge_client import VLMClient, create_vlm_prompt, _extract_json_list +except ImportError: # Running as a script without package context + import sys + # Add repo root to sys.path so absolute imports work when invoked directly + sys.path.append(str(Path(__file__).resolve().parents[2])) + from agentfly.rewards.reward_base import reward + from agentfly.rewards.vlm_as_judge.vlm_as_judge_client import ( + VLMClient, + create_vlm_prompt, + _extract_json_list, + create_vlm_prompt_from_template, + create_vlm_prompt_custom, + DEFAULT_VLM_PROMPT_TEMPLATE, + ) + +logger = logging.getLogger(__name__) + + +class VideoGenerator: + """Helper class to generate videos from code""" + + def __init__(self, output_dir: Optional[str] = None): + """Initialize video generator + + Args: + output_dir: Directory to save generated videos + """ + # Prefer a shared directory accessible by the VLM server if provided. 
+ if output_dir is None: + output_dir = os.getenv("VLM_SHARED_VIDEO_DIR", "/tmp/vlm_videos") + self.output_dir = output_dir + os.makedirs(output_dir, exist_ok=True) + + def extract_code_from_response(self, response: str) -> Optional[str]: + """Extract Python code from model response + + Args: + response: Model response containing code + + Returns: + Extracted Python code or None + """ + # Remove tags if present + cleaned = re.sub(r'.*?', '', response, flags=re.DOTALL) + + # Extract code from ```python blocks + pattern = r"```python\n(.*?)```" + matches = re.findall(pattern, cleaned, re.DOTALL) + + if matches: + return matches[0] + return None + + def generate_video_from_code(self, code: str, output_path: str) -> bool: + """Execute Python code to generate video + + Args: + code: Python code to execute + output_path: Path to save the generated video + + Returns: + True if video generation successful, False otherwise + """ + try: + # Create a temporary Python file + with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f: + # Modify code to use the specified output path + modified_code = code + + # Handle sys.argv usage for output filename + if 'sys.argv[1]' in code: + modified_code = code.replace('sys.argv[1]', f'"{output_path}"') + elif 'sys.argv' in code and 'len(sys.argv)' in code: + # Add sys.argv mock at the beginning + modified_code = f"import sys\nsys.argv = ['script.py', '{output_path}']\n" + code + else: + # If no sys.argv usage, try to modify output filename assignments + # Look for common patterns like output_file = ... or out = cv2.VideoWriter(...) + if 'output_file' in code: + # Replace output_file assignment + modified_code = re.sub( + r'output_file\s*=\s*["\'].*?["\']', + f'output_file = "{output_path}"', + code + ) + elif 'VideoWriter(' in code: + # Try to replace the first string argument in VideoWriter + modified_code = re.sub( + r'VideoWriter\s*\(\s*["\'].*?["\']', + f'VideoWriter("{output_path}"', + code + ) + else: + # Last resort: append output path assignment + modified_code = f"output_file = '{output_path}'\n" + code + + f.write(modified_code) + temp_file = f.name + + # Execute the code + # Always pass the output path as an argument so scripts that expect + # sys.argv[1] or check len(sys.argv) continue without exiting. + result = subprocess.run( + ['python', temp_file, output_path], + capture_output=True, + text=True, + timeout=120, # Increased timeout for video generation + cwd=self.output_dir # Run in output directory + ) + + # Clean up temp file + os.unlink(temp_file) + + # Check if video was created and is not empty + if os.path.exists(output_path) and os.path.getsize(output_path) > 1000: # At least 1KB + logger.info(f"Successfully generated video: {output_path} ({os.path.getsize(output_path)} bytes)") + return True + else: + logger.error(f"Video generation failed or file too small. 
stderr: {result.stderr}") + return False + + except subprocess.TimeoutExpired: + logger.error("Video generation timed out") + if 'temp_file' in locals(): + try: + os.unlink(temp_file) + except: + pass + return False + except Exception as e: + logger.error(f"Error generating video: {e}") + if 'temp_file' in locals(): + try: + os.unlink(temp_file) + except: + pass + return False + + +def extract_vlm_questions_from_data(data: Dict[str, Any]) -> Tuple[str, str, List[Dict]]: + """Extract VLM questions and summary from data + + Args: + data: Dictionary containing vlm_questions data + + Returns: + Tuple of (all_questions_str, summarize, questions_list) + """ + all_questions = "" + summarize = "" + questions_list = [] + + if "vlm_questions" in data: + vlm_data = data["vlm_questions"] + if isinstance(vlm_data, dict): + # Get summary + summarize = vlm_data.get("summarize", "") + + # Extract questions from nested vlm_questions field + if "vlm_questions" in vlm_data: + questions_list = vlm_data["vlm_questions"] + if isinstance(questions_list, list): + for q in questions_list: + if isinstance(q, dict): + idx = q.get("index", "") + question = q.get("question", "") + all_questions += f"{idx}. {question}\n" + else: + logger.warning(f"vlm_questions inner field is not a list: {type(questions_list)}") + else: + logger.warning(f"vlm_questions is not a dict: {type(vlm_data)}") + else: + logger.warning(f"No vlm_questions field in data. Available fields: {list(data.keys())}") + + all_questions = all_questions.strip() + + if not summarize: + summarize = "Evaluate the visual content based on the questions provided." + + logger.info(f"Extracted {len(questions_list)} questions from VLM data") + + return all_questions, summarize, questions_list + + +def calculate_weighted_reward(results: List[Dict], questions_list: List[Dict]) -> float: + """Calculate weighted reward based on VLM results and question weights + + Args: + results: List of VLM evaluation results + questions_list: Original questions with weights + + Returns: + Weighted reward score between 0.0 and 1.0 + """ + if not results or not questions_list: + return 0.0 + + # Create weight mapping + weight_map = {} + for q in questions_list: + idx = str(q.get("index", "")) + weight = float(q.get("weight", 1.0)) + weight_map[idx] = weight + + scores = [] + weights = [] + + for result in results: + idx = str(result.get("index", "")) + result_value = result.get("result", "Not sure") + confidence = int(result.get("confidence_score", "1")) + + # Get weight for this question + weight = weight_map.get(idx, 1.0) + + # Calculate score based on result + if result_value == "True": + score = 1.0 + elif result_value == "False": + score = 0.0 + else: # "Not sure" + if confidence >= 4: + score = 0.0 # High confidence "Not sure" -> False + else: + score = 1.0 # Low confidence "Not sure" -> True + + scores.append(score) + weights.append(weight) + + # Calculate weighted average + # if weights: + # weighted_sum = sum(s * w for s, w in zip(scores, weights)) + # total_weight = sum(weights) + # reward = weighted_sum / total_weight if total_weight > 0 else 0.0 + # else: + reward = sum(scores) / len(scores) if scores else 0.0 + + return reward + +def pass_fail_reward(results: List[Dict], questions_list: List[Dict]) -> float: + """Calculate a binary pass/fail score from VLM results. + + Returns 1.0 only when every question is judged as satisfied (or low-confidence + "Not sure"), otherwise returns 0.0. 
+ """ + if not results or not questions_list: + return 0.0 + + result_map = { + str(r.get("index", "")).strip(): r + for r in results + if str(r.get("index", "")).strip() + } + + for question in questions_list: + idx = str(question.get("index", "")).strip() + if not idx: + logger.warning("Question without index encountered in pass/fail reward") + return 0.0 + + result = result_map.get(idx) + if result is None: + logger.warning("Missing VLM result for question index %s", idx) + return 0.0 + + result_value = str(result.get("result", "Not sure")).strip().lower() + confidence_raw = result.get("confidence_score", "1") + try: + confidence = int(confidence_raw) + except (TypeError, ValueError): + confidence = 1 + + if result_value == "true": + continue + if result_value == "not sure" and confidence < 4: + continue + + # Any explicit false or high-confidence uncertainty causes failure + return 0.0 + + return 1.0 + +@reward(name="vlm_as_judge_pass_reward") +async def vlm_as_judge_pass_reward( + prediction: str, + trajectory: Dict[str, Any] = None, + vlm_questions: Dict[str, Any] = None, + **data_fields +) -> Dict[str, float]: + """VLM as Judge reward function for evaluating agent trajectories + + This reward function: + 1. Extracts Python code from the prediction + 2. Generates a video using the code + 3. Uses VLM server to evaluate the video against provided questions + 4. Returns a binary pass/fail score based on VLM judgments + + Args: + prediction: Agent's generated response (should contain Python code) + trajectory: Agent trajectory information + **data_fields: Additional data fields from the RL data, including vlm_questions + + Returns: + pass/fail reward score between 0.0 and 1.0 + """ + try: + # Log incoming data for debugging + logger.info(f"=" * 60) + logger.info(f"vlm_as_judge_reward called") + logger.info(f"Prediction length: {len(prediction) if prediction else 0}") + + # Print the actual prediction content + logger.info(f"Prediction content (first 500 chars):") + logger.info(f"{prediction[:500] if prediction else 'No prediction'}") + if prediction and len(prediction) > 500: + logger.info(f"... (truncated, total length: {len(prediction)} chars)") + + logger.info(f"vlm_questions parameter: {vlm_questions is not None}") + logger.info(f"Additional data_fields keys: {list(data_fields.keys())}") + + # Initialize video generator + video_gen = VideoGenerator() + + # Combine vlm_questions with data_fields for extraction + all_data = dict(data_fields) + if vlm_questions is not None: + all_data['vlm_questions'] = vlm_questions + logger.info(f"vlm_questions type: {type(vlm_questions)}") + if isinstance(vlm_questions, dict): + logger.info(f"vlm_questions keys: {vlm_questions.keys()}") + if 'vlm_questions' in vlm_questions: + inner_vlm = vlm_questions['vlm_questions'] + logger.info(f"Inner vlm_questions type: {type(inner_vlm)}") + if isinstance(inner_vlm, list): + logger.info(f"Number of questions in inner list: {len(inner_vlm)}") + + # Extract VLM questions from data + all_questions, summarize, questions_list = extract_vlm_questions_from_data(all_data) + + if not questions_list: + logger.warning(f"No VLM questions found in data. 
Available fields: {list(all_data.keys())}") + return {"reward": 0.0} + + # Extract code from prediction + code = video_gen.extract_code_from_response(prediction) + if not code: + logger.warning("No Python code found in prediction") + logger.warning(f"Prediction was: {prediction[:1000] if prediction else 'None'}") + return {"reward": 0.0} + + logger.info(f"Extracted Python code ({len(code)} chars)") + logger.info(f"Code preview (first 300 chars):") + logger.info(f"{code[:300]}...") + if len(code) > 300: + logger.info(f"... (truncated, total length: {len(code)} chars)") + + # Generate unique video filename + video_filename = f"video_{uuid.uuid4().hex}.mp4" + video_path = os.path.join(video_gen.output_dir, video_filename) + + # Generate video from code + success = video_gen.generate_video_from_code(code, video_path) + if not success: + logger.error("Failed to generate video from code") + return {"reward": 0.0} + + # Run VLM evaluation directly since we're already async + client = VLMClient( + model="Qwen/Qwen2.5-VL-72B-Instruct", + timeout_seconds=120 + ) + + # Wait for client availability + for _ in range(10): + if client.is_available(): + break + await asyncio.sleep(1) + else: + logger.error("VLM client not available") + return {"reward": 0.0} + + # Create VLM prompt + prompt_text = create_vlm_prompt(summarize, all_questions) + + # Build message using ' in content %}{% set content = content.split('')[-1] %}{% endif %}{{'<|Assistant|>' + content + '<|end▁of▁sentence|>'}}{%- endif %}{%- endif %}{%- if message['role'] == 'tool' %}{%- set ns.is_tool = true -%}{%- if ns.is_output_first %}{{'<|tool▁outputs▁begin|><|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- set ns.is_output_first = false %}{%- else %}{{'\\n<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- endif %}{%- endif %}{%- endfor -%}{% if ns.is_tool %}{{'<|tool▁outputs▁end|>'}}{% endif %}{% if add_generation_prompt and not ns.is_tool %}{{'<|Assistant|>\\n'}}{% endif %}" ) ) +register_template( + Template( + name="llemma", + system_template="{system_message}", + user_template="Input:{content}\n\n", + assistant_template="Response:{content}", + stop_words=[""] + ) +) + + if __name__ == "__main__": pass \ No newline at end of file diff --git a/agentfly/templates/utils.py b/agentfly/templates/utils.py index 2761822..75906ac 100644 --- a/agentfly/templates/utils.py +++ b/agentfly/templates/utils.py @@ -270,7 +270,7 @@ def tokenize_conversations( concatenated_mm_inputs = {} if concatenate_mm_inputs: for key in batch_mm_inputs[0].keys(): - if mm_inputs[key]: + if isinstance(mm_inputs[key], torch.Tensor): concatenated_mm_inputs[key] = torch.cat([mm_inputs[key] for mm_inputs in batch_mm_inputs if mm_inputs[key] is not None], dim=0) inputs = dict( diff --git a/agentfly/tests/scripts/test_cpu_runs.sh b/agentfly/tests/scripts/test_cpu_runs.sh new file mode 100644 index 0000000..c3f0fc9 --- /dev/null +++ b/agentfly/tests/scripts/test_cpu_runs.sh @@ -0,0 +1,10 @@ +#! 
/bin/bash + +# Test CPU runs + + +pytest -x agentfly/tests/unit/tools/ +pytest -x agentfly/tests/unit/envs/ +pytest -x agentfly/tests/unit/rewards/ + +pytest -x agentfly/tests/unit/templates/ \ No newline at end of file diff --git a/agentfly/tests/scripts/test_gpu_runs.sh b/agentfly/tests/scripts/test_gpu_runs.sh index e033a51..4f10934 100644 --- a/agentfly/tests/scripts/test_gpu_runs.sh +++ b/agentfly/tests/scripts/test_gpu_runs.sh @@ -2,9 +2,9 @@ # Test GPU runs -python -m pytest -x tests/unit/agents/test_initialization.py || exit 1 -python -m pytest -x tests/unit/agents/test_auto_agent.py || exit 1 -python -m pytest -x tests/unit/agents/test_code_agent.py || exit 1 -python -m pytest -x tests/unit/agents/test_react_agent.py || exit 1 -python -m pytest -x tests/unit/agents/test_webshop_agent.py || exit 1 -python -m pytest -x tests/unit/agents/test_vision_agent.py || exit 1 \ No newline at end of file +pytest -x agentfly/tests/unit/agents/test_initialization.py || exit 1 +pytest -x agentfly/tests/unit/agents/test_auto_agent.py || exit 1 +pytest -x agentfly/tests/unit/agents/test_code_agent.py || exit 1 +pytest -x agentfly/tests/unit/agents/test_react_agent.py || exit 1 +pytest -x agentfly/tests/unit/agents/test_webshop_agent.py || exit 1 +pytest -x agentfly/tests/unit/agents/test_vision_agent.py || exit 1 \ No newline at end of file diff --git a/agentfly/tests/unit/agents/backends/__init__.py b/agentfly/tests/unit/agents/backends/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/agentfly/tests/unit/agents/backends/test_client_backend.py b/agentfly/tests/unit/agents/backends/test_client_backend.py new file mode 100644 index 0000000..228decb --- /dev/null +++ b/agentfly/tests/unit/agents/backends/test_client_backend.py @@ -0,0 +1,318 @@ +#!/usr/bin/env python3 +# Copyright (c) 2025 Bytedance Ltd. 
and/or its affiliates +# SPDX-License-Identifier: Apache-2.0 + +import asyncio +import time +import pytest +import threading +from unittest.mock import Mock, patch, AsyncMock, MagicMock +from typing import List, Dict, Any +import statistics +import sys +import os + +# Add the parent directory to the Python path +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))) + +from agentfly.agents.llm_backends.llm_backends import ClientBackend + + +class TestClientBackendWorkload: + """Test suite for ClientBackend workload and rate limiting functionality.""" + + @pytest.fixture + def mock_openai_client(self): + """Mock OpenAI client for testing.""" + mock_client = Mock() + mock_response = Mock() + mock_response.dict.return_value = { + "choices": [ + { + "message": { + "content": "Test response", + "tool_calls": None + } + } + ] + } + mock_client.chat.completions.create.return_value = mock_response + return mock_client + + @pytest.fixture + def client_backend(self, mock_openai_client): + """Create a ClientBackend instance for testing.""" + with patch('openai.OpenAI', return_value=mock_openai_client): + backend = ClientBackend( + model_name_or_path="test-model", + template="test-template", + base_url="http://localhost:8000/v1", + max_requests_per_minute=10, # Low limit for testing + timeout=30, + api_key="test-key" + ) + return backend + + def test_basic_functionality(self, client_backend, mock_openai_client): + """Test basic ClientBackend functionality.""" + messages = [{"role": "user", "content": "Hello"}] + + # Test sync generation + response = client_backend.generate(messages) + + assert isinstance(response, list) + assert len(response) == 1 + assert response[0] == "Test response" + mock_openai_client.chat.completions.create.assert_called_once() + + @pytest.mark.asyncio + async def test_async_generation(self, client_backend, mock_openai_client): + """Test async generation functionality.""" + messages = [{"role": "user", "content": "Hello"}] + + response = await client_backend.generate_async(messages) + + assert isinstance(response, list) + assert len(response) == 1 + assert response[0] == "Test response" + + def test_rate_limiting_basic(self, client_backend): + """Test basic rate limiting functionality.""" + # Verify semaphore is initialized correctly + assert client_backend._tokens._value == 10 # max_requests_per_minute + assert client_backend._max_tokens == 10 + + @pytest.mark.asyncio + async def test_rate_limiting_under_load(self, client_backend, mock_openai_client): + """Test rate limiting under high load with 100 concurrent requests.""" + # Configure mock to simulate API delay + def mock_api_call(*args, **kwargs): + time.sleep(1) # 1s delay per request + return Mock(dict=lambda: { + "choices": [{"message": {"content": "Test response", "tool_calls": None}}] + }) + + mock_openai_client.chat.completions.create = mock_api_call + + # Create 100 concurrent requests (10x the rate limit) + num_requests = 100 + messages = [{"role": "user", "content": f"Request {i}"} for i in range(num_requests)] + + start_time = time.time() + + # Send all requests concurrently + tasks = [client_backend.generate_async([msg]) for msg in messages] + responses = await asyncio.gather(*tasks) + + total_time = time.time() - start_time + + # Verify all requests completed + assert len(responses) == num_requests + assert all(isinstance(r, list) and len(r) == 1 for r in responses) + + # Verify rate limiting worked (should take longer than if unlimited) + # With 10 RPM 
limit, 100 requests should take at least 6 minutes in theory + # But with our 1s delay, it should take at least 10 seconds + assert total_time >= 10 # At least 10s due to our mock delay + + print(f"Rate limiting test: {num_requests} requests completed in {total_time:.2f}s") + print(f"Effective rate: {num_requests/total_time:.2f} requests/second") + + + @pytest.mark.asyncio + async def test_concurrent_burst_requests(self, client_backend, mock_openai_client): + """Test handling of burst requests that exceed the rate limit.""" + # Configure mock with realistic delay + async def mock_api_call(*args, **kwargs): + await asyncio.sleep(0.05) # 50ms delay per request + return Mock(dict=lambda: { + "choices": [{"message": {"content": "Burst response", "tool_calls": None}}] + }) + + mock_openai_client.chat.completions.create = mock_api_call + + # Send burst of 20 requests (2x the rate limit) + burst_size = 20 + messages = [{"role": "user", "content": f"Burst {i}"} for i in range(burst_size)] + + start_time = time.time() + + # Send all requests at once + tasks = [client_backend.generate_async([msg]) for msg in messages] + responses = await asyncio.gather(*tasks) + + total_time = time.time() - start_time + + # Verify all requests completed + assert len(responses) == burst_size + assert all("Burst response" in r[0] for r in responses) + + # Verify rate limiting was applied + # Should take longer than 20 * 0.05 = 1 second due to rate limiting + assert total_time >= 0.5 # At least 500ms due to rate limiting + + print(f"Burst test: {burst_size} requests completed in {total_time:.2f}s") + + @pytest.mark.asyncio + async def test_sustained_load_performance(self, client_backend, mock_openai_client): + """Test performance under sustained load over time.""" + # Configure mock with realistic delay + async def mock_api_call(*args, **kwargs): + await asyncio.sleep(0.02) # 20ms delay per request + return Mock(dict=lambda: { + "choices": [{"message": {"content": "Sustained response", "tool_calls": None}}] + }) + + mock_openai_client.chat.completions.create = mock_api_call + + # Send requests in batches over time + batch_size = 5 + num_batches = 4 + total_requests = batch_size * num_batches + + all_responses = [] + batch_times = [] + + for batch in range(num_batches): + batch_start = time.time() + + messages = [{"role": "user", "content": f"Sustained batch {batch} req {i}"} + for i in range(batch_size)] + + tasks = [client_backend.generate_async([msg]) for msg in messages] + batch_responses = await asyncio.gather(*tasks) + + batch_time = time.time() - batch_start + batch_times.append(batch_time) + all_responses.extend(batch_responses) + + # Small delay between batches + await asyncio.sleep(0.1) + + # Verify all requests completed + assert len(all_responses) == total_requests + assert all("Sustained response" in r[0] for r in all_responses) + + # Analyze performance + avg_batch_time = statistics.mean(batch_times) + total_time = sum(batch_times) + + print(f"Sustained load test: {total_requests} requests in {num_batches} batches") + print(f"Average batch time: {avg_batch_time:.3f}s") + print(f"Total time: {total_time:.3f}s") + print(f"Effective rate: {total_requests/total_time:.2f} requests/second") + + + @pytest.mark.asyncio + async def test_mixed_sync_async_workload(self, client_backend, mock_openai_client): + """Test mixed sync and async calls under load.""" + # Configure mock + async def mock_api_call(*args, **kwargs): + await asyncio.sleep(0.01) + return Mock(dict=lambda: { + "choices": [{"message": {"content": 
"Mixed response", "tool_calls": None}}] + }) + + mock_openai_client.chat.completions.create = mock_api_call + + # Mix of sync and async calls + sync_messages = [{"role": "user", "content": f"Sync {i}"} for i in range(5)] + async_messages = [{"role": "user", "content": f"Async {i}"} for i in range(5)] + + # Run sync calls in a separate thread to avoid blocking + def run_sync_calls(): + return [client_backend.generate([msg]) for msg in sync_messages] + + # Execute sync calls in thread pool + loop = asyncio.get_running_loop() + sync_task = loop.run_in_executor(None, run_sync_calls) + + # Execute async calls + async_tasks = [client_backend.generate_async([msg]) for msg in async_messages] + async_results = await asyncio.gather(*async_tasks) + + # Wait for sync calls to complete + sync_results = await sync_task + + # Verify all calls completed + assert len(sync_results) == 5 + assert len(async_results) == 5 + + # All should be successful + all_results = sync_results + async_results + assert all("Mixed response" in r[0] for r in all_results) + + print(f"Mixed workload test: {len(all_results)} total requests completed") + + def test_refiller_startup_edge_cases(self, client_backend): + """Test refiller startup in different contexts.""" + # Test startup when no event loop exists + with patch('asyncio.get_event_loop', side_effect=RuntimeError("No event loop")): + # Should not crash + client_backend._ensure_refiller_running() + assert client_backend._refill_task is None + + # Test startup when event loop exists but is not running + mock_loop = Mock() + mock_loop.is_running.return_value = False + mock_loop.create_task.return_value = Mock() + + with patch('asyncio.get_event_loop', return_value=mock_loop): + client_backend._ensure_refiller_running() + mock_loop.create_task.assert_called_once() + + @pytest.mark.asyncio + async def test_large_scale_workload(self, client_backend, mock_openai_client): + """Test with a large number of requests (1000) to verify system stability.""" + # Configure mock with minimal delay + async def mock_api_call(*args, **kwargs): + await asyncio.sleep(0.001) # 1ms delay per request + return Mock(dict=lambda: { + "choices": [{"message": {"content": "Large scale response", "tool_calls": None}}] + }) + + mock_openai_client.chat.completions.create = mock_api_call + + # Create 1000 requests + num_requests = 1000 + messages = [{"role": "user", "content": f"Large scale {i}"} for i in range(num_requests)] + + start_time = time.time() + + # Send all requests concurrently + tasks = [client_backend.generate_async([msg]) for msg in messages] + responses = await asyncio.gather(*tasks) + + total_time = time.time() - start_time + + # Verify all requests completed + assert len(responses) == num_requests + assert all("Large scale response" in r[0] for r in responses) + + print(f"Large scale test: {num_requests} requests completed in {total_time:.2f}s") + print(f"Effective rate: {num_requests/total_time:.2f} requests/second") + + # Verify rate limiting was applied (should be limited to ~10 RPM = 0.167 RPS) + # With 1000 requests at 0.167 RPS, should take at least 6000 seconds + # But our test is more lenient due to the mock delay + assert total_time >= 1.0 # At least 1 second due to rate limiting + + def test_workload_with_real_api(self): + client_backend = ClientBackend( + model_name_or_path="deepseek-ai/DeepSeek-R1-Distill-Qwen-7B", + template="deepseek", + base_url="http://localhost:8000/v1", + max_requests_per_minute=60, + timeout=300, + api_key="EMPTY" + ) + messages = [[{"role": "user", 
"content": "Let $x,y$ and $z$ be positive real numbers that satisfy the following system of equations:\n\[\log_2\left({x \over yz}\right) = {1 \over 2}\]\n\[\log_2\left({y \over xz}\right) = {1 \over 3}\]\n\[\log_2\left({z \over xy}\right) = {1 \over 4}\]\nThen the value of $\left|\log_2(x^4y^3z^2)\right|$ is $\tfrac{m}{n}$ where $m$ and $n$ are relatively prime positive integers. Find $m+n$."}]] * 100 + time_start = time.time() + response = client_backend.generate(messages) + time_end = time.time() + + assert len(response) == 100 + print(f"Time taken: {time_end - time_start} seconds") + print(f"Effective rate: {100/(time_end - time_start)} requests/second") + + \ No newline at end of file diff --git a/agentfly/tests/unit/envs/test_webshop_text_env.py b/agentfly/tests/unit/envs/test_webshop_text_env.py index 8b00a5b..bf3aa1a 100644 --- a/agentfly/tests/unit/envs/test_webshop_text_env.py +++ b/agentfly/tests/unit/envs/test_webshop_text_env.py @@ -86,7 +86,7 @@ async def test_env_full_shopping_flow(): assert 'observation' in observation assert 'reward' in observation - await env.close() + await env.aclose() # @pytest.mark.asyncio # async def test_pagination_navigation(): diff --git a/agentfly/tests/unit/rewards/test_vlm_as_judge_reward.py b/agentfly/tests/unit/rewards/test_vlm_as_judge_reward.py new file mode 100644 index 0000000..0cf2dd0 --- /dev/null +++ b/agentfly/tests/unit/rewards/test_vlm_as_judge_reward.py @@ -0,0 +1,282 @@ +import sys +import os +from ....rewards.vlm_as_judge.vlm_as_judge_client import VLMClient, create_vlm_prompt, _extract_json_list +from ....rewards.vlm_as_judge.vlm_as_judge_reward import VideoGenerator, extract_vlm_questions_from_data, calculate_weighted_reward, pass_fail_reward +from ....rewards.vlm_as_judge.vlm_as_judge_reward import vlm_as_judge_pass_reward +from pathlib import Path +import asyncio +import sys +from pathlib import Path +import traceback +import time +import json +import os +import re +import subprocess +import tempfile +import uuid +import warnings +from typing import Dict, List, Optional, Tuple, Any + + +if __name__ == "__main__": + """Test VLM client functionality""" + import asyncio + + async def test_client(): + print("="*70) + print("Testing VLM Client") + print("="*70) + + # Test data + test_questions = { + "vlm_questions": { + "summarize": "A ball rolls down a ramp", + "vlm_questions": [ + {"index": "1", "question": "A ball is visible", "weight": 1.0}, + {"index": "2", "question": "The ball moves downward", "weight": 1.0} + ] + } + } + + try: + # Test client initialization + client = VLMClient( + model="Qwen/Qwen2.5-VL-72B-Instruct", + timeout_seconds=60 + ) + print(f"✓ Client initialized") + + # Check availability + is_available = client.is_available() + print(f"✓ Client available: {is_available}") + + # Test prompt creation + all_q = "1. A ball is visible\n2. 
The ball moves downward" + prompt = create_vlm_prompt("A ball rolls down a ramp", all_q) + print(f"✓ Prompt created ({len(prompt)} chars)") + + # Test JSON extraction + test_response = '''[{"index": "1", "result": "True", "confidence_score": "5"}]''' + results = _extract_json_list(test_response) + print(f"✓ JSON extraction works: {len(results)} results") + + print("\nAll tests passed!") + + except Exception as e: + print(f"✗ Test failed: {e}") + import traceback + traceback.print_exc() + + asyncio.run(test_client()) + + + """Test VLM-as-judge reward function""" + import sys + + # Test data - real physics example with charged sphere + test_data = { + "question": "A charged 0.6 kg aluminum sphere is placed at the center of a 1.5-meter by 1-meter by 1-meter glass tank filled with air at 25°C and 1 atm. The tank is horizontally divided into two equal compartments by a non-conductive partition. When a 450 N/C vertical electric field is applied, the sphere rises, clears the partition, and settles in the upper compartment over 13 seconds, as the field balances the gravitational force.", + "Level": 3, + "vlm_questions": { + "enableAnnotator": "Yes", + "summarize": "A charged 0.6 kg aluminum sphere is placed at the center of a 1.5m x 1m x 1m glass tank filled with air at 25°C and 1 atm. The tank is divided by a non-conductive partition. A 450 N/C vertical electric field causes the sphere to rise, clear the partition, and settle in the upper compartment, balancing gravitational force over 13 seconds.", + "vlm_questions": [ + { + "index": "1", + "question": "A non-conductive partition divides the tank horizontally into two equal compartments.", + "weight": 1.0 + }, + { + "index": "2", + "question": "The sphere is initially placed at the center of the tank.", + "weight": 1.0 + }, + { + "index": "3", + "question": "The sphere rises vertically when the electric field is applied.", + "weight": 1.0 + }, + { + "index": "4", + "question": "The sphere clears the partition and enters the upper compartment.", + "weight": 1.0 + }, + { + "index": "5", + "question": "The sphere settles in the upper compartment after moving.", + "weight": 1.0 + } + ] + } + } + + # Sample physics simulation code + sample_code = ''' +import sys +import subprocess +import importlib + +required_libraries = ['cv2', 'numpy'] +for lib in required_libraries: + try: + importlib.import_module(lib) + except ImportError: + subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'opencv-python', 'numpy']) + break + +import cv2 +import numpy as np + +if len(sys.argv) < 2: + print("Usage: python script.py output_filename.mp4") + sys.exit(1) + +output_file = sys.argv[1] + +# Physical parameters +tank_length = 1.5 +tank_width = 1.0 +tank_height = 1.0 +partition_height = 0.5 +initial_z = 0.25 +final_z = 0.75 +total_time = 13.0 +mass = 0.6 +gravity = 9.8 +E_field = 450.0 +charge = (mass * gravity) / E_field + +# Video parameters +fps = 30 +width, height = 1280, 720 +num_frames = int(total_time * fps) +fourcc = cv2.VideoWriter_fourcc(*'mp4v') +out = cv2.VideoWriter(output_file, fourcc, fps, (width, height)) + +# Scaling factors for visualization +margin = 50 +x_scale = (width - 2 * margin) / tank_length +z_scale = (height - 2 * margin) / tank_height +sphere_radius_px = 15 +force_scale = 30 + +def world_to_pixel(x, z): + px = int(margin + x * x_scale) + pz = int(height - margin - z * z_scale) + return px, pz + +for frame_idx in range(num_frames): + t = frame_idx / fps + progress = min(1.0, t / total_time) + current_z = initial_z + (final_z - 
initial_z) * progress + current_pos = [tank_length/2, tank_width/2, current_z] + velocity = (final_z - initial_z) / total_time + + # Create white background + img = np.ones((height, width, 3), dtype=np.uint8) * 255 + + # Draw tank + tank_tl = world_to_pixel(0, tank_height) + tank_br = world_to_pixel(tank_length, 0) + cv2.rectangle(img, tank_tl, tank_br, (200, 200, 255), 2) + + # Draw partition + part_start = world_to_pixel(0, partition_height) + part_end = world_to_pixel(tank_length, partition_height) + cv2.line(img, part_start, part_end, (100, 100, 100), 2) + + # Draw sphere + sphere_pos = world_to_pixel(tank_length/2, current_z) + cv2.circle(img, sphere_pos, sphere_radius_px, (0, 0, 255), -1) + + # Draw force vectors + g_vector_end = (sphere_pos[0], sphere_pos[1] + force_scale) + cv2.arrowedLine(img, sphere_pos, g_vector_end, (0, 150, 0), 2, tipLength=0.3) + + e_vector_end = (sphere_pos[0], sphere_pos[1] - force_scale) + cv2.arrowedLine(img, sphere_pos, e_vector_end, (255, 0, 0), 2, tipLength=0.3) + + # Draw text overlays + cv2.putText(img, f"Time: {t:.2f}s / {total_time}s", (20, 40), + cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 0), 2) + cv2.putText(img, f"Mass: {mass} kg", (width-300, 40), + cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 0), 2) + cv2.putText(img, f"Gravity: {gravity} m/s^2", (width-300, 80), + cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 0), 2) + cv2.putText(img, f"E-Field: {E_field} N/C", (width-300, 120), + cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 0), 2) + cv2.putText(img, f"Charge: {charge:.5f} C", (width-300, 160), + cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 0), 2) + cv2.putText(img, f"Velocity: {velocity:.4f} m/s", (width-300, 200), + cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 0), 2) + cv2.putText(img, f"Position: (0.75, 0.50, {current_z:.2f}) m", (width-300, 240), + cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 0), 2) + + # Write frame + out.write(img) + +out.release() +''' + + print("="*70) + print("Testing VLM-as-Judge Reward") + print("="*70) + + # Test video generator + print("\n1. Testing VideoGenerator...") + gen = VideoGenerator() + + # Test code extraction + test_response = f"```python\n{sample_code}\n```" + code = gen.extract_code_from_response(test_response) + print(f" ✓ Extracted {len(code) if code else 0} chars of code") + + # Test question extraction + print("\n2. Testing question extraction...") + all_q, summary, q_list = extract_vlm_questions_from_data(test_data) + print(f" ✓ Extracted {len(q_list)} questions") + print(f" ✓ Summary: {summary[:50]}...") + + # Test reward calculation + print("\n3. Testing reward calculation...") + test_results = [ + {"index": "1", "result": "True", "confidence_score": "5"}, + {"index": "2", "result": "True", "confidence_score": "4"}, + {"index": "3", "result": "False", "confidence_score": "3"} + ] + # reward = calculate_weighted_reward(test_results, q_list) + reward = pass_fail_reward(test_results, q_list) + print(f" ✓ Calculated reward: {reward:.3f}") + + print("\n4. 
Testing full reward function...") + print(" Note: This requires VLM server to be running") + + async def test_reward(): + """Async wrapper for testing the reward function""" + try: + # Test with physics simulation prediction including think tags + prediction_with_think = f"\n{test_data.get('think', 'Analyzing the physics problem...')}\n\n```python\n{sample_code}\n```" + + # Alternative: Test with just code + prediction = f"```python\n{sample_code}\n```" + + reward_value = await vlm_as_judge_pass_reward( + prediction=prediction, + trajectory={}, + **test_data + ) + print(f" ✓ Reward function returned: {reward_value}") + return reward_value + except Exception as e: + print(f" ⚠ Reward function test failed (expected if VLM server not running)") + print(f" Error: {e}") + return None + + # Run the async test + import asyncio + result = asyncio.run(test_reward()) + + print("\nTest complete!") + if result: + print(f"Final reward score: {result.get('reward', 0.0):.3f}") \ No newline at end of file diff --git a/agentfly/tests/unit/templates/__init__.py b/agentfly/tests/unit/templates/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/agentfly/tests/unit/agents/templates/test_qwen3_prompt.py b/agentfly/tests/unit/templates/test_qwen3_prompt.py similarity index 98% rename from agentfly/tests/unit/agents/templates/test_qwen3_prompt.py rename to agentfly/tests/unit/templates/test_qwen3_prompt.py index 3c3faae..7eb6b3e 100644 --- a/agentfly/tests/unit/agents/templates/test_qwen3_prompt.py +++ b/agentfly/tests/unit/templates/test_qwen3_prompt.py @@ -1,4 +1,4 @@ -from .....templates.utils import compare_hf_template +from ....templates.utils import compare_hf_template import pytest from transformers import AutoTokenizer diff --git a/agentfly/tests/unit/agents/templates/test_qwen3_tokenize.py b/agentfly/tests/unit/templates/test_qwen3_tokenize.py similarity index 97% rename from agentfly/tests/unit/agents/templates/test_qwen3_tokenize.py rename to agentfly/tests/unit/templates/test_qwen3_tokenize.py index fb47144..e006a05 100644 --- a/agentfly/tests/unit/agents/templates/test_qwen3_tokenize.py +++ b/agentfly/tests/unit/templates/test_qwen3_tokenize.py @@ -7,11 +7,11 @@ Since the align for textual prompt is already tested in other files, we only need to test the tokenization of the templates. 
""" -from .....templates.utils import tokenize_conversation +from ....templates.utils import tokenize_conversation import pytest from transformers import AutoTokenizer import torch -from .....templates.templates import Chat +from ....templates.templates import Chat @pytest.mark.parametrize("template", ["qwen3"]) @pytest.mark.parametrize("messages", [ diff --git a/agentfly/tests/unit/templates/test_single_turn_template_tokenize.py b/agentfly/tests/unit/templates/test_single_turn_template_tokenize.py new file mode 100644 index 0000000..9bf1c63 --- /dev/null +++ b/agentfly/tests/unit/templates/test_single_turn_template_tokenize.py @@ -0,0 +1,38 @@ +from ....templates.utils import tokenize_conversation +import pytest +from transformers import AutoTokenizer +import torch +from ....templates.templates import Chat + +@pytest.mark.parametrize("template", ["deepseek-r1-distill-qwen"]) +@pytest.mark.parametrize("messages", [ + [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello, how are you?"}, + {"role": "assistant", "content": "I am fine, thank you."}, + ], + [ + {"role": "user", "content": "Hello, how are you?"}, + {"role": "assistant", "content": " This is test thinking content. I am fine, thank you."}, + ], +]) +@pytest.mark.parametrize("tools", [ + None, +]) +@pytest.mark.parametrize("add_generation_prompt", [False, True]) +def test_template_tokenize(template, messages, tools, add_generation_prompt): + template_tokenizer_mapping = { + "qwen2.5": "Qwen/Qwen2.5-3B-Instruct", + "llama-3.2": "meta-llama/Llama-3.2-3B-Instruct", + "deepseek-r1-distill-qwen": "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B", + } + tokenizer = AutoTokenizer.from_pretrained(template_tokenizer_mapping[template], trust_remote_code=True) + + chat = Chat(template, messages, tools=tools) + prompt = chat.prompt(add_generation_prompt=add_generation_prompt, tools=tools) + + hf_inputs = tokenizer(prompt, return_tensors="pt") + + implemented_inputs = tokenize_conversation(messages, tokenizer, template, max_length=2048, tools=tools, add_generation_prompt=add_generation_prompt, return_tensors="pt") + + assert torch.equal(hf_inputs["input_ids"], implemented_inputs["input_ids"]), f"template: {template}\n\nmessages: {messages}\n\ntools: {tools}\n\nadd_generation_prompt: {add_generation_prompt}\n\nprompt: {prompt}\n\nimplemented_prompt: {tokenizer.decode(implemented_inputs['input_ids'][0])}\n\nhf_inputs: {hf_inputs}\n\nimplemented_inputs: {implemented_inputs}" \ No newline at end of file diff --git a/agentfly/tests/unit/templates/test_single_turn_templates.py b/agentfly/tests/unit/templates/test_single_turn_templates.py new file mode 100644 index 0000000..e968ef6 --- /dev/null +++ b/agentfly/tests/unit/templates/test_single_turn_templates.py @@ -0,0 +1,54 @@ +from ....templates import compare_hf_template +from transformers import AutoTokenizer +import pytest + + +# "qwen2.5-think", "qwen2.5", "qwen2.5-no-tool", +# "llama-3.2", "mistral", "glm-4", "internlm2.5", "phi-3.5", "phi-4" +@pytest.mark.parametrize("template", ["deepseek-r1-distill-qwen"]) +@pytest.mark.parametrize("messages", [ + [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello, how are you?"}, + {"role": "assistant", "content": "I am fine, thank you."}, + ], + [ + {"role": "user", "content": "Hello, how are you?"}, + {"role": "assistant", "content": " This is test thinking content. 
I am fine, thank you."}, + ], + [ + {"role": "user", "content": "Hello, how are you?"}, + ], +]) +@pytest.mark.parametrize("tools", [ + None, +]) +@pytest.mark.parametrize("add_generation_prompt", [True, False]) +def test_chat_template_equal(template, messages, tools, add_generation_prompt): + # Filter invalid combinations + if add_generation_prompt and messages[-1]['role'] == 'assistant': + return + + + template_tokenizer_mapping = { + "qwen2.5": "Qwen/Qwen2.5-3B-Instruct", + "qwen2.5-think": "Qwen/Qwen2.5-3B-Instruct", + "qwen2.5-no-system-tool": "Qwen/Qwen2.5-3B-Instruct", + "deepseek-prover-v2": "deepseek-ai/DeepSeek-Prover-V2-7B", + "llama-3.2": "meta-llama/Llama-3.2-3B-Instruct", + "mistral": "mistralai/Mistral-7B-Instruct-v0.3", + "glm-4": "THUDM/glm-4-9b-chat", + "internlm2.5": "internlm/internlm2_5-7b-chat", + "phi-3.5": "microsoft/Phi-3.5-mini-instruct", + "phi-4": "microsoft/Phi-4", + "nemotron": "nvidia/Llama-3.1-Nemotron-Nano-8B-v1", + "deepseek-r1-distill-qwen": "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B", + } + tokenizer = AutoTokenizer.from_pretrained(template_tokenizer_mapping[template], trust_remote_code=True) + + is_equal, is_equal_between_implemented_prompts, is_equal_between_jinja_prompts, official_prompt, implemented_prompt, implemented_jinja_prompt, highlighted_prompt = compare_hf_template(tokenizer, template, messages=messages, tools=tools,add_generation_prompt=add_generation_prompt) + + print(f"Official prompt:\n\n{official_prompt}") + print(f"Highlighted prompt:\n\n{highlighted_prompt}") + assert is_equal, f"Template: {template}\n\nMessages: {messages}\n\ntools: {tools}\n\nadd_generation_prompt: {add_generation_prompt}\n\nOfficial prompt:\n\n{official_prompt}\n\nImplemented prompt:\n\n{implemented_prompt}" + assert is_equal_between_jinja_prompts, f"Template: {template}\n\nMessages: {messages}\n\ntools: {tools}\n\nadd_generation_prompt: {add_generation_prompt}\n\nImplemented prompt:\n\n{implemented_prompt}\n\nJinja prompt:\n\n{implemented_jinja_prompt}" diff --git a/agentfly/tests/unit/agents/templates/test_template_utilities.py b/agentfly/tests/unit/templates/test_template_utilities.py similarity index 83% rename from agentfly/tests/unit/agents/templates/test_template_utilities.py rename to agentfly/tests/unit/templates/test_template_utilities.py index 2cdbf4b..246f1cc 100644 --- a/agentfly/tests/unit/agents/templates/test_template_utilities.py +++ b/agentfly/tests/unit/templates/test_template_utilities.py @@ -1,5 +1,5 @@ -from .....agents.templates.templates import get_template, register_template, Template -from .....agents.templates.vision_processor import get_processor +from ....templates.templates import get_template, register_template, Template +from ....templates.vision_processor import get_processor def test_template_registration(): register_template( diff --git a/agentfly/tests/unit/agents/templates/test_text_templates_full_align.py b/agentfly/tests/unit/templates/test_text_templates_full_align.py similarity index 99% rename from agentfly/tests/unit/agents/templates/test_text_templates_full_align.py rename to agentfly/tests/unit/templates/test_text_templates_full_align.py index 652ffa4..0e4e57d 100644 --- a/agentfly/tests/unit/agents/templates/test_text_templates_full_align.py +++ b/agentfly/tests/unit/templates/test_text_templates_full_align.py @@ -8,7 +8,7 @@ """ -from .....templates import compare_hf_template +from ....templates import compare_hf_template from transformers import AutoTokenizer import pytest diff --git 
a/agentfly/tests/unit/agents/templates/test_text_templates_partial_align.py b/agentfly/tests/unit/templates/test_text_templates_partial_align.py similarity index 96% rename from agentfly/tests/unit/agents/templates/test_text_templates_partial_align.py rename to agentfly/tests/unit/templates/test_text_templates_partial_align.py index ea26532..86f322d 100644 --- a/agentfly/tests/unit/agents/templates/test_text_templates_partial_align.py +++ b/agentfly/tests/unit/templates/test_text_templates_partial_align.py @@ -1,7 +1,7 @@ import pytest from transformers import AutoTokenizer -from agentfly.agents.templates.templates import get_template -from agentfly.agents.templates.utils import compare_hf_template +from ....templates.templates import get_template +from ....templates.utils import compare_hf_template # nemotron, phi-4, glm-4 @pytest.mark.parametrize("template_name", ["qwen2.5-think", "qwen2.5-no-system-tool",]) diff --git a/agentfly/tests/unit/agents/templates/test_text_templates_tokenize.py b/agentfly/tests/unit/templates/test_text_templates_tokenize.py similarity index 96% rename from agentfly/tests/unit/agents/templates/test_text_templates_tokenize.py rename to agentfly/tests/unit/templates/test_text_templates_tokenize.py index 079537b..f92fdc2 100644 --- a/agentfly/tests/unit/agents/templates/test_text_templates_tokenize.py +++ b/agentfly/tests/unit/templates/test_text_templates_tokenize.py @@ -7,11 +7,11 @@ Since the align for textual prompt is already tested in other files, we only need to test the tokenization of the templates. """ -from .....agents.templates.utils import tokenize_conversation +from ....templates.utils import tokenize_conversation import pytest from transformers import AutoTokenizer import torch -from .....agents.templates.templates import Chat +from ....templates.templates import Chat @pytest.mark.parametrize("template", ["llama-3.2", "qwen2.5"]) @pytest.mark.parametrize("messages", [ diff --git a/agentfly/tests/unit/agents/templates/test_vision_templates_full_align.py b/agentfly/tests/unit/templates/test_vision_templates_full_align.py similarity index 98% rename from agentfly/tests/unit/agents/templates/test_vision_templates_full_align.py rename to agentfly/tests/unit/templates/test_vision_templates_full_align.py index 0d98620..9514529 100644 --- a/agentfly/tests/unit/agents/templates/test_vision_templates_full_align.py +++ b/agentfly/tests/unit/templates/test_vision_templates_full_align.py @@ -9,7 +9,7 @@ """ -from .....agents.templates.utils import compare_hf_template +from ....templates.utils import compare_hf_template from transformers import AutoTokenizer import pytest # "qwen2.5-think", "qwen2.5", "qwen2.5-no-tool", diff --git a/agentfly/tests/unit/agents/templates/test_vision_templates_tokenize.py b/agentfly/tests/unit/templates/test_vision_templates_tokenize.py similarity index 97% rename from agentfly/tests/unit/agents/templates/test_vision_templates_tokenize.py rename to agentfly/tests/unit/templates/test_vision_templates_tokenize.py index 4117a5a..bfc380a 100644 --- a/agentfly/tests/unit/agents/templates/test_vision_templates_tokenize.py +++ b/agentfly/tests/unit/templates/test_vision_templates_tokenize.py @@ -8,8 +8,8 @@ """ -from .....agents.templates.templates import Chat -from .....agents.templates.utils import compare_hf_template, tokenize_conversation +from ....templates.templates import Chat +from ....templates.utils import compare_hf_template, tokenize_conversation from transformers import AutoTokenizer import pytest import torch diff --git 
a/agentfly/utils/deploy.py b/agentfly/utils/deploy.py new file mode 100644 index 0000000..bf2101c --- /dev/null +++ b/agentfly/utils/deploy.py @@ -0,0 +1,51 @@ + + +import os +from ..templates import get_template +from .. import AGENT_DATA_DIR +import click + + +def vllm_serve(model_name_or_path, template, tp, pp, dp, gpu_memory_utilization): + port = 8000 + + + if template is None: + template_option = "" + else: + jinja_template = get_template(template).jinja_template() + if not os.path.exists(f"{AGENT_DATA_DIR}/cache"): + os.makedirs(f"{AGENT_DATA_DIR}/cache") + with open(f"{AGENT_DATA_DIR}/cache/jinja_template.jinja", "w") as f: + f.write(jinja_template) + template_option = f"--chat-template {AGENT_DATA_DIR}/cache/jinja_template.jinja" + # command = f"vllm serve {model_name_or_path} --chat-template {AGENT_DATA_DIR}/cache/jinja_template.jinja --tensor-parallel-size {tp} --pipeline-parallel-size {pp} --data-parallel-size {dp} --port {port} --enable-auto-tool-choice --tool-call-parser hermes --expand-tools-even-if-tool-choice-none" + command = f"""vllm serve {model_name_or_path} \ +{template_option} \ +--tensor-parallel-size {tp} \ +--pipeline-parallel-size {pp} \ +--data-parallel-size {dp} --port {port} \ +--gpu-memory-utilization {gpu_memory_utilization} \ +--enable-auto-tool-choice --tool-call-parser hermes""" + + print(command) + os.system(command) + + + +@click.command() +@click.option("--model_name_or_path") +@click.option("--template", default=None) +@click.option("--tp", type=int, default=1) +@click.option("--pp", type=int, default=1) +@click.option("--dp", type=int, default=1) +@click.option("--gpu_memory_utilization", type=float, default=0.8) +def main(model_name_or_path, template, tp, pp, dp, gpu_memory_utilization): + vllm_serve(model_name_or_path, template, tp, pp, dp, gpu_memory_utilization) + + +if __name__=="__main__": + "python -m agentfly.utils.deploy --model_name_or_path Qwen/Qwen2.5-3B-Instruct --template qwen2.5 --tp 2 --dp 2" + "python -m agentfly.utils.deploy --model_name_or_path openai/gpt-oss-20b --tp 1 --dp 1" + "python -m agentfly.utils.deploy --model_name_or_path deepseek-ai/DeepSeek-R1-Distill-Qwen-7B --tp 1 --dp 1" + main() \ No newline at end of file diff --git a/agentfly/utils/monitor.py b/agentfly/utils/monitor.py index 504f1f5..2dd66f1 100644 --- a/agentfly/utils/monitor.py +++ b/agentfly/utils/monitor.py @@ -38,11 +38,13 @@ class MetricEvent: step Integer training step or episode counter. timestamp Unix seconds (auto‑filled if omitted). tags Arbitrary key/value pairs for filtering (e.g. run_id, module). + sinks List of sink names to send this event to. If None, sends to all sinks. 
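+
+    Example (a sketch mirroring the trajectory logging in ``chain_base``;
+    ``step`` below is a placeholder for the caller's own counter)::
+
+        emit(MetricEvent(sinks=["jsonl"], kind="text",
+                         name="Agent/rollout/trajectory", value="...",
+                         x=step, x_name="Agent/rollout/step"))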
""" kind: Literal["scalar", "hist", "text", "resource", "list"] name: str value: Any + sinks: Optional[List[str]] = None step: Optional[int] = None x: Optional[int] = None x_name: Optional[str] = "x_axis" @@ -55,7 +57,6 @@ def __post_init__(self) -> None: self.timestamp = time.time() - class BaseSink(abc.ABC): """Abstract writer backend.""" @@ -100,21 +101,40 @@ def serialize_for_json(obj): class JsonlSink(BaseSink): """Append events as JSON-Lines - human & machine friendly.""" - def __init__(self, path: str) -> None: - os.makedirs(os.path.dirname(path) or ".", exist_ok=True) - self._f = open(path, "a", buffering=1, encoding="utf-8") + def __init__(self, directory: str) -> None: + os.makedirs(os.path.dirname(directory) or ".", exist_ok=True) + self.directory = directory + if os.path.isdir(directory): + default_file = os.path.join(directory, "default.jsonl") + with open(default_file, 'w') as f: + pass + + self.log_files = {"default": open(default_file, "a", buffering=1, encoding="utf-8")} + else: + self.log_files = {} + self._lock = asyncio.Lock() async def log(self, evt: MetricEvent) -> None: + evt_name = evt.name.replace("/", "-") + if evt_name not in self.log_files: + file_name = os.path.join(self.directory, f"{evt_name}.jsonl") + with open(file_name, 'w') as f: + pass + self.log_files[evt_name] = open(file_name, "a", buffering=1, encoding="utf-8") + file_obj = self.log_files[evt_name] + async with self._lock: - self._f.write(json.dumps(serialize_for_json(asdict(evt)), ensure_ascii=False) + "\n") + file_obj.write(json.dumps(serialize_for_json(asdict(evt)), ensure_ascii=False) + "\n") async def flush(self) -> None: - self._f.flush() + for file_obj in self.log_files.values(): + file_obj.flush() async def close(self) -> None: await super().close() - self._f.close() + for file_obj in self.log_files.values(): + file_obj.close() class WandbSink(BaseSink): @@ -130,6 +150,9 @@ def __init__(self, project: str, **wandb_init_kwargs: Any) -> None: # noqa: D40 self.tables: Dict[str, wandb.Table] = {} async def log(self, evt: MetricEvent) -> None: # pragma: no cover + """ + Log the event to wandb. + """ if wandb.run is not None: payload = {evt.name: evt.value, **evt.tags} if evt.x is not None: @@ -249,7 +272,10 @@ async def _consumer_loop(cls) -> None: evt = await cls._queue.get() if evt is None: # sentinel break - for sink in list(cls._sinks.values()): + for sink_name, sink in list(cls._sinks.items()): + # Check if this sink should receive this event + if evt.sinks is not None and sink_name not in evt.sinks: + continue try: await sink.log(evt) except Exception as exc: diff --git a/agentfly/utils/trajectories.py b/agentfly/utils/trajectories.py new file mode 100644 index 0000000..71332b2 --- /dev/null +++ b/agentfly/utils/trajectories.py @@ -0,0 +1,69 @@ +""" +This module is used to collect trajectories from the agent. 
+""" +from typing import Dict, List +import asyncio +import json +import os +from ..agents.llm_backends.backend_configs import ClientConfig +from ..agents import HFAgent +import click + +def gather_responses(trajectories: List[Dict]): + responses = [] + for trajectory in trajectories: + responses.append({ + "id": trajectory["id"], + "response": trajectory["messages"][-1]["content"][0]["text"], + }) + + return responses + +@click.command() +@click.option("--model_name_or_path", type=str, required=True) +@click.option("--api_key", type=str, default="EMPTY") +@click.option("--max_turns", type=int, required=True) +@click.option("--num_chains", type=int, default=1, required=True) +@click.option("--data_file", type=str, required=True) +@click.option("--output_dir", type=str, required=True) +def main( + model_name_or_path: str, + api_key: str, + max_turns: int, + num_chains: int, + data_file: str, + output_dir: str, +): + async def run_agent(): + with open(data_file, 'r') as f: + messages = json.load(f) + + agent = HFAgent( + model_name_or_path=model_name_or_path, + tools=[], + backend="client", + backend_config=ClientConfig( + base_url="http://0.0.0.0:8000/v1", + api_key=api_key, + max_new_tokens=30720, + timeout=2400, + ), + local_cache_dir="test_cache", + ) + await agent.run( + messages=messages, + num_chains=num_chains, + max_turns=max_turns, + ) + responses = gather_responses(agent.trajectories) + + os.makedirs(output_dir, exist_ok=True) + with open(os.path.join(output_dir, os.path.basename(model_name_or_path) + '_responses.json'), 'w') as f: + json.dump(responses, f, indent=2) + + asyncio.run(run_agent()) + +if __name__ == "__main__": + # """python -m agentfly.utils.trajectories --model_name_or_path Qwen/Qwen2.5-3B-Instruct --api_key EMPTY --max_turns 1 --data_file ../../datasets/viphy_test.json --output_dir data/trajectories/""" + """python -m agentfly.utils.trajectories --model_name_or_path openai/gpt-oss-20b --api_key EMPTY --max_turns 1 --data_file ../../datasets/viphy_test.json --output_dir data/trajectories/""" + main() diff --git a/test_qwen3_template.py b/test_qwen3_template.py deleted file mode 100644 index 172228c..0000000 --- a/test_qwen3_template.py +++ /dev/null @@ -1,83 +0,0 @@ -#!/usr/bin/env python3 -""" -Test script for Qwen3Template implementation -""" - -import sys -import os -sys.path.append(os.path.dirname(os.path.abspath(__file__))) - -from agentfly.templates.templates import Qwen3Template - -def test_qwen3_template(): - """Test the Qwen3Template with various scenarios""" - - # Create a Qwen3Template instance - template = Qwen3Template( - name="qwen3-test", - system_template="<|im_start|>system\n{system_message}<|im_end|>\n", - user_template="<|im_start|>user\n{content}<|im_end|>\n", - assistant_template="<|im_start|>assistant\n{content}<|im_end|>\n", - stop_words=["<|im_end|>"], - generation_prompt="<|im_start|>assistant\n", - ) - - # Test case 1: Basic conversation without thinking - print("=== Test Case 1: Basic conversation without thinking ===") - messages1 = [ - {"role": "user", "content": "Hello, how are you?"}, - {"role": "assistant", "content": "I'm doing well, thank you!"} - ] - - prompt1, elements1, roles1 = template.render(messages1, add_generation_prompt=False, enable_thinking=False) - print("Prompt:") - print(prompt1) - print() - - # Test case 2: Conversation with thinking content that should be cleaned - print("=== Test Case 2: Conversation with thinking content (should be cleaned) ===") - messages2 = [ - {"role": "user", "content": "What is 2+2?"}, - 
{"role": "assistant", "content": "I need to add 2 and 2 together. This is basic arithmetic.The answer is 4."} - ] - - prompt2, elements2, roles2 = template.render(messages2, add_generation_prompt=False, enable_thinking=False) - print("Prompt (thinking content should be removed):") - print(prompt2) - print() - - # Test case 3: With add_generation_prompt=True and enable_thinking=False - print("=== Test Case 3: With generation prompt and enable_thinking=False ===") - messages3 = [ - {"role": "user", "content": "Tell me a joke"} - ] - - prompt3, elements3, roles3 = template.render(messages3, add_generation_prompt=True, enable_thinking=False) - print("Prompt (should include empty think tokens):") - print(prompt3) - print() - - # Test case 4: With add_generation_prompt=True and enable_thinking=True - print("=== Test Case 4: With generation prompt and enable_thinking=True ===") - prompt4, elements4, roles4 = template.render(messages3, add_generation_prompt=True, enable_thinking=True) - print("Prompt (should NOT include empty think tokens):") - print(prompt4) - print() - - # Test case 5: Last message is assistant with enable_thinking=False - print("=== Test Case 5: Last message is assistant with enable_thinking=False ===") - messages5 = [ - {"role": "user", "content": "What's the weather like?"}, - {"role": "assistant", "content": "I don't have access to current weather data."} - ] - - prompt5, elements5, roles5 = template.render(messages5, add_generation_prompt=False, enable_thinking=False) - print("Prompt (last assistant message should have empty think tokens):") - print(prompt5) - print() - - print("All tests completed!") - -if __name__ == "__main__": - test_qwen3_template() -