|
10 | 10 | Agent, |
11 | 11 | AuthenticateResponse, |
12 | 12 | Client, |
| 13 | + CreateTerminalResponse, |
13 | 14 | InitializeResponse, |
14 | 15 | LoadSessionResponse, |
15 | 16 | NewSessionResponse, |
|
24 | 25 | update_agent_message_text, |
25 | 26 | update_tool_call, |
26 | 27 | ) |
| 28 | +from acp.core import AgentSideConnection, ClientSideConnection |
27 | 29 | from acp.schema import ( |
28 | 30 | AgentMessageChunk, |
29 | 31 | AllowedOutcome, |
30 | 32 | AudioContentBlock, |
31 | 33 | ClientCapabilities, |
32 | 34 | DeniedOutcome, |
33 | 35 | EmbeddedResourceContentBlock, |
| 36 | + EnvVariable, |
34 | 37 | HttpMcpServer, |
35 | 38 | ImageContentBlock, |
36 | 39 | Implementation, |
@@ -130,6 +133,56 @@ async def test_session_notifications_flow(connect, client): |
130 | 133 | assert client.notifications[0].session_id == "sess" |
131 | 134 |
|
132 | 135 |
|
| 136 | +@pytest.mark.asyncio |
| 137 | +async def test_on_connect_create_terminal_handle(server): |
| 138 | + class _TerminalAgent(Agent): |
| 139 | + __test__ = False |
| 140 | + |
| 141 | + def __init__(self) -> None: |
| 142 | + self._conn: Client | None = None |
| 143 | + self.handle_id: str | None = None |
| 144 | + |
| 145 | + def on_connect(self, conn: Client) -> None: |
| 146 | + self._conn = conn |
| 147 | + |
| 148 | + async def prompt( |
| 149 | + self, |
| 150 | + prompt: list[TextContentBlock], |
| 151 | + session_id: str, |
| 152 | + **kwargs: Any, |
| 153 | + ) -> PromptResponse: |
| 154 | + assert self._conn is not None |
| 155 | + handle = await self._conn.create_terminal(command="echo", session_id=session_id) |
| 156 | + self.handle_id = handle.terminal_id |
| 157 | + return PromptResponse(stop_reason="end_turn") |
| 158 | + |
| 159 | + class _TerminalClient(TestClient): |
| 160 | + __test__ = False |
| 161 | + |
| 162 | + async def create_terminal( |
| 163 | + self, |
| 164 | + command: str, |
| 165 | + session_id: str, |
| 166 | + args: list[str] | None = None, |
| 167 | + cwd: str | None = None, |
| 168 | + env: list[EnvVariable] | None = None, |
| 169 | + output_byte_limit: int | None = None, |
| 170 | + **kwargs: Any, |
| 171 | + ) -> CreateTerminalResponse: |
| 172 | + return CreateTerminalResponse(terminal_id="term-123") |
| 173 | + |
| 174 | + agent = _TerminalAgent() |
| 175 | + client = _TerminalClient() |
| 176 | + agent_conn = AgentSideConnection(agent, server.server_writer, server.server_reader, listening=True) |
| 177 | + client_conn = ClientSideConnection(client, server.client_writer, server.client_reader) |
| 178 | + |
| 179 | + await client_conn.prompt(session_id="sess", prompt=[TextContentBlock(type="text", text="start")]) |
| 180 | + assert agent.handle_id == "term-123" |
| 181 | + |
| 182 | + await client_conn.close() |
| 183 | + await agent_conn.close() |
| 184 | + |
| 185 | + |
133 | 186 | @pytest.mark.asyncio |
134 | 187 | async def test_concurrent_reads(connect, client): |
135 | 188 | for i in range(5): |
|
0 commit comments