Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion scenarios/debate/Dockerfile.adk-debate-judge
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ RUN \
--mount=type=cache,target=/home/agentbeats/.cache/uv,uid=1000 \
uv sync --locked

COPY scenarios scenarios
COPY scenarios/debate scenarios/debate

ENTRYPOINT ["uv", "run", "scenarios/debate/adk_debate_judge.py"]
CMD ["--host", "0.0.0.0"]
Expand Down
2 changes: 1 addition & 1 deletion scenarios/debate/Dockerfile.debate-judge
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ RUN \
--mount=type=cache,target=/home/agentbeats/.cache/uv,uid=1000 \
uv sync --locked

COPY scenarios scenarios
COPY scenarios/debate scenarios/debate

ENTRYPOINT ["uv", "run", "scenarios/debate/debate_judge.py"]
CMD ["--host", "0.0.0.0"]
Expand Down
2 changes: 1 addition & 1 deletion scenarios/debate/Dockerfile.debater
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ RUN \
--mount=type=cache,target=/home/agentbeats/.cache/uv,uid=1000 \
uv sync --locked

COPY scenarios scenarios
COPY scenarios/debate scenarios/debate

ENTRYPOINT ["uv", "run", "scenarios/debate/debater.py"]
CMD ["--host", "0.0.0.0"]
Expand Down
2 changes: 1 addition & 1 deletion scenarios/tau2/Dockerfile.tau2-agent
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ RUN \
--mount=type=cache,target=/home/agentbeats/.cache/uv,uid=1000 \
uv sync --locked

COPY scenarios scenarios
COPY scenarios/tau2 scenarios/tau2

ENTRYPOINT ["uv", "run", "scenarios/tau2/tau2_agent.py"]
CMD ["--host", "0.0.0.0"]
Expand Down
30 changes: 30 additions & 0 deletions scenarios/tau2/Dockerfile.tau2-env
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# https://github.com/meta-pytorch/OpenEnv/blob/fb169f8c660df722f538160b3ce636de3312a756/src/envs/README.md

# Accept base image as build argument for CI/CD flexibility
ARG BASE_IMAGE=openenv-base:latest
FROM ${BASE_IMAGE}

# Install dependencies
COPY scenarios/tau2/requirements.txt /tmp/requirements.txt
RUN \
apt-get update && \
apt-get install -y git && \
pip install --no-cache-dir -r /tmp/requirements.txt && rm /tmp/requirements.txt

# Download tau2 data
RUN git clone --depth 1 --filter=blob:none --sparse https://github.com/sierra-research/tau2-bench.git /app/scenarios/tau2/tau2-bench && \
cd /app/scenarios/tau2/tau2-bench && \
git sparse-checkout set data

ENV TAU2_DATA_DIR=/app/scenarios/tau2/tau2-bench/data

# Copy environment code
COPY scenarios/tau2/ /app/scenarios/tau2/

# Health check
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
CMD curl -f http://localhost:8000/health || exit 1

# Run server
# https://github.com/meta-pytorch/OpenEnv/issues/244
CMD ["python", "-m", "uvicorn", "tau2_server:app", "--host", "0.0.0.0", "--port", "8000", "--app-dir", "scenarios/tau2"]
13 changes: 1 addition & 12 deletions scenarios/tau2/Dockerfile.tau2-evaluator
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,7 @@ RUN \
--mount=type=cache,target=/home/agentbeats/.cache/uv,uid=1000 \
uv pip install "tau2 @ git+https://github.com/sierra-research/tau2-bench.git"

# Download tau2 data
USER root
RUN apt-get update && apt-get install -y git && rm -rf /var/lib/apt/lists/*
USER agentbeats

RUN git clone --depth 1 --filter=blob:none --sparse https://github.com/sierra-research/tau2-bench.git /home/agentbeats/tau2-bench && \
cd /home/agentbeats/tau2-bench && \
git sparse-checkout set data

ENV TAU2_DATA_DIR=/home/agentbeats/tau2-bench/data

COPY scenarios scenarios
COPY scenarios/tau2 scenarios/tau2

ENTRYPOINT ["uv", "run", "scenarios/tau2/tau2_evaluator.py"]
CMD ["--host", "0.0.0.0"]
Expand Down
2 changes: 2 additions & 0 deletions scenarios/tau2/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
openenv-core>=0.1.1
tau2 @ git+https://github.com/sierra-research/tau2-bench.git
10 changes: 10 additions & 0 deletions scenarios/tau2/scenario.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,16 @@ role = "agent"
endpoint = "http://127.0.0.1:9019"
cmd = "python scenarios/tau2/tau2_agent.py --host 127.0.0.1 --port 9019"

[[environments]]
name = "tau2"
endpoint = "http://127.0.0.1:8000"
image = "tau2-env:latest"
publishes = ["127.0.0.1:8000:8000"]
[environments.env]
TAU2_DOMAIN = "airline"
TAU2_TASK_ID = "0"
TAU2_ENV_ARGS_JSON = "{}"

[config]
domain = "airline"
num_tasks = 3
25 changes: 25 additions & 0 deletions scenarios/tau2/tau2_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from typing import Any

from openenv_core.client_types import StepResult
from openenv_core.http_env_client import HTTPEnvClient

from tau2_models import Tau2Action, Tau2Observation, Tau2State


# https://github.com/meta-pytorch/OpenEnv/blob/fb169f8c660df722f538160b3ce636de3312a756/src/envs/README.md


class Tau2Env(HTTPEnvClient[Tau2Action, Tau2Observation]):
    """HTTP client for the tau2 OpenEnv server.

    Implements the (de)serialization hooks that HTTPEnvClient needs to talk
    to the server created by ``create_fastapi_app`` in tau2_server.py.
    """

    def _step_payload(self, action: Tau2Action) -> dict[str, Any]:
        """Serialize *action* into the JSON body of a step request."""
        return {"action": action.action}

    def _parse_result(self, payload: dict[str, Any]) -> StepResult[Tau2Observation]:
        """Deserialize a step/reset response payload into a StepResult."""
        observation = Tau2Observation(**payload["observation"])
        reward = payload.get("reward")
        done = payload.get("done", False)
        return StepResult(observation=observation, reward=reward, done=done)

    def _parse_state(self, payload: dict[str, Any]) -> Tau2State:
        """Deserialize a state response payload into a Tau2State."""
        return Tau2State(**payload)
55 changes: 55 additions & 0 deletions scenarios/tau2/tau2_env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from dataclasses import dataclass, field
from typing import Any
import uuid

import gymnasium as gym

from openenv_core.env_server import Action, Environment, Observation, State
from tau2.gym import TAU_BENCH_ENV_ID, register_gym_agent

from tau2_models import Tau2Action, Tau2Observation, Tau2State


# https://github.com/sierra-research/tau2-bench/blob/main/src/tau2/gym/README.md
# https://github.com/meta-pytorch/OpenEnv/blob/fb169f8c660df722f538160b3ce636de3312a756/src/envs/README.md


register_gym_agent()


class Tau2Environment(Environment):
    """OpenEnv Environment wrapping a tau2-bench gymnasium environment.

    Bridges the OpenEnv server protocol (reset/step/state) to the gym env
    registered under TAU_BENCH_ENV_ID by ``register_gym_agent()``.
    """

    def __init__(
        self,
        domain: str,
        task_id: str,
        env_args: dict[str, Any],  # forwarded as keyword args to gym.make (must be a mapping — it is **-unpacked below)
    ):
        super().__init__()
        self._state = Tau2State()
        # Observations and actions are both plain strings in the tau2 gym env.
        self._gym_env: gym.Env[str, str] = gym.make(
            TAU_BENCH_ENV_ID,
            domain=domain,
            task_id=task_id,
            **env_args,
        )

    def reset(self) -> Tau2Observation:
        """Start a fresh episode; returns the initial observation.

        A new episode_id is minted and the gym env's ``info`` dict is cached
        on the state so clients can read it via the state endpoint.
        """
        self._state = Tau2State(episode_id=str(uuid.uuid4()))
        observation, info = self._gym_env.reset()
        self._state.info = info
        return Tau2Observation(observation=observation)

    def step(self, action: Action) -> Tau2Observation:
        """Advance the episode by one action; returns the next observation.

        The gym env's terminated/truncated flags are collapsed into ``done``,
        and the latest ``info`` dict is cached on the state.
        """
        assert isinstance(action, Tau2Action)
        self._state.step_count += 1
        observation, reward, terminated, truncated, info = self._gym_env.step(action.action)
        self._state.info = info
        return Tau2Observation(
            observation=observation,
            done=terminated or truncated,
            reward=float(reward),
        )

    @property
    def state(self) -> Tau2State:
        """Current episode state (episode_id, step_count, cached info)."""
        return self._state
69 changes: 29 additions & 40 deletions scenarios/tau2/tau2_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import time
from typing import Any, Optional

import gymnasium as gym
import uvicorn
from dotenv import load_dotenv

Expand All @@ -40,17 +39,16 @@

from tau2.data_model.simulation import RewardInfo
from tau2.environment.tool import Tool
from tau2.gym import TAU_BENCH_ENV_ID, register_gym_agent
from tau2.run import get_tasks

from tau2_client import Tau2Env
from tau2_models import Tau2Action

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("tau2_evaluator")

RESPOND_ACTION_NAME = "respond"

# Register tau-bench gym environments
register_gym_agent()


def tools_to_str(tools: list[Tool]) -> str:
"""Convert tau-bench tools to JSON schema format."""
Expand Down Expand Up @@ -93,19 +91,19 @@ def validate_request(self, request: EvalRequest) -> tuple[bool, str]:
return False, f"Missing config keys: {missing_config_keys}"
return True, "ok"

async def run_eval(self, req: EvalRequest, updater: TaskUpdater) -> None:
logger.info(f"Starting tau2 evaluation: {req}")
async def run_eval(self, request: EvalRequest, updater: TaskUpdater) -> None:
logger.info(f"Starting tau2 evaluation: {request}")
start_time = time.time()

domain = req.config["domain"]
task_ids = req.config.get("task_ids", None)
num_tasks = req.config.get("num_tasks", None)
max_steps = req.config.get("max_steps", 200)
user_llm = req.config.get("user_llm", "openai/gpt-4o")
user_llm_args = req.config.get("user_llm_args", {"temperature": 0.0})
domain = request.config["domain"]
task_ids = request.config.get("task_ids", None)
num_tasks = request.config.get("num_tasks", None)
max_steps = request.config.get("max_steps", 200)
user_llm = request.config.get("user_llm", "openai/gpt-4o")
user_llm_args = request.config.get("user_llm_args", {"temperature": 0.0})

# Get the purple agent URL
agent_url = str(req.participants["agent"])
agent_url = str(request.participants["agent"])

# Get task IDs
resolved_task_ids = get_task_ids(domain, task_ids, num_tasks)
Expand Down Expand Up @@ -146,7 +144,7 @@ async def run_eval(self, req: EvalRequest, updater: TaskUpdater) -> None:
num_completed = len(metrics["tasks"])
pass_rate = (total_reward / num_completed * 100) if num_completed > 0 else 0

result_data = {
result_data: dict[str, Any] = {
"domain": domain,
"score": total_reward,
"max_score": num_completed,
Expand Down Expand Up @@ -188,35 +186,26 @@ async def _run_single_task(
task_id: str,
max_steps: int,
user_llm: str,
user_llm_args: dict,
user_llm_args: dict[Any, Any],
) -> float:
"""Run a single tau-bench task and return the reward."""

env = gym.make(
TAU_BENCH_ENV_ID,
domain=domain,
task_id=task_id,
max_steps=max_steps,
user_llm=user_llm,
user_llm_args=user_llm_args,
all_messages_as_observation=False,
)
env = Tau2Env("http://localhost:8000")

terminated = False
observation, info = env.reset()
observation_sr = env.reset()

# Build the initial task description for the purple agent
task_description = self._build_task_prompt(info, observation)
task_description = self._build_task_prompt(env.state.info, observation_sr.observation.observation)

# Start a new conversation with the purple agent
next_message = task_description
is_first_message = True

while not terminated:
while not observation_sr.done:
logger.debug(f"Sending to purple agent: {next_message[:200]}...")

# Send message to purple agent
response = await self._tool_provider.talk_to_agent(
response: str = await self._tool_provider.talk_to_agent(
message=next_message,
url=agent_url,
new_conversation=is_first_message,
Expand All @@ -227,28 +216,28 @@ async def _run_single_task(

# Parse the purple agent's action
try:
action = self._parse_agent_response(response)
action = Tau2Action(action=self._parse_agent_response(response))
except Exception as e:
logger.error(f"Failed to parse agent response: {e}")
# When parsing fails, respond with error as plain text (not a tool call)
action = "I encountered an error processing the request."
action = Tau2Action(action="I encountered an error processing the request.")

# Step the environment with either a JSON string (tool call) or plain text (user response)
observation, reward, terminated, truncated, info = env.step(action)
logger.debug(f"Environment step: reward={reward}, terminated={terminated}")
observation_sr = env.step(action)
logger.debug(f"Environment step: reward={observation_sr.reward}, done={observation_sr.done}")

if terminated:
if observation_sr.done:
break

next_message = observation
next_message = observation_sr.observation.observation

# Extract final reward
if info.get("reward_info"):
reward_info = RewardInfo.model_validate_json(info["reward_info"])
if env.state.info.get("reward_info"):
reward_info = RewardInfo.model_validate_json(env.state.info["reward_info"])
return reward_info.reward
return float(reward)
return 0. if observation_sr.reward is None else float(observation_sr.reward)

def _build_task_prompt(self, info: dict, observation: str) -> str:
def _build_task_prompt(self, info: dict[Any, Any], observation: str) -> str:
"""Build the initial task prompt for the purple agent."""
return f"""
{info["policy"]}
Expand Down
19 changes: 19 additions & 0 deletions scenarios/tau2/tau2_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from dataclasses import dataclass, field
from typing import Any

from openenv_core.env_server import Action, Observation, State


@dataclass
class Tau2Action(Action):
    # Raw agent output forwarded to the gym env: either plain user-facing
    # text or a JSON-encoded tool call (see tau2_evaluator's step handling).
    action: str


@dataclass
class Tau2Observation(Observation):
    # Observation text emitted by the tau2 gym environment for the agent.
    observation: str


@dataclass
class Tau2State(State):
    """Server-side episode state; ``info`` caches the latest gym info dict."""

    # Use the plain ``dict`` constructor as the factory; the annotation alone
    # carries the type. Calling the generic alias ``dict[str, Any]`` as a
    # factory works at runtime but is unidiomatic.
    info: dict[str, Any] = field(default_factory=dict)
17 changes: 17 additions & 0 deletions scenarios/tau2/tau2_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import os
import json
from openenv_core.env_server import create_fastapi_app

from tau2_models import Tau2Action, Tau2Observation
from tau2_env import Tau2Environment


# https://github.com/meta-pytorch/OpenEnv/blob/fb169f8c660df722f538160b3ce636de3312a756/src/envs/README.md


# Build the environment once at import time, configured via container env
# vars (defaults mirror the [environments.env] table in scenario.toml).
env = Tau2Environment(
    domain=os.environ.get("TAU2_DOMAIN", "airline"),
    task_id=os.environ.get("TAU2_TASK_ID", "0"),
    env_args=json.loads(os.environ.get("TAU2_ENV_ARGS_JSON", "{}")),
)
# ASGI app implementing the OpenEnv HTTP protocol; served by uvicorn
# (see the CMD in Dockerfile.tau2-env).
app = create_fastapi_app(env, Tau2Action, Tau2Observation)
Loading