From ba77c62b3395a367ac76514d4fb40ae6cb4676ab Mon Sep 17 00:00:00 2001 From: Yuvraj Singh Date: Mon, 15 Dec 2025 01:58:45 +0530 Subject: [PATCH 1/2] feat: Add MCP integration --- README.md | 24 ++ pyproject.toml | 1 + src/dnet/api/http_api.py | 16 +- src/dnet/api/mcp_handler.py | 475 ++++++++++++++++++++++++++++++++++++ 4 files changed, 515 insertions(+), 1 deletion(-) create mode 100644 src/dnet/api/mcp_handler.py diff --git a/README.md b/README.md index 82be6d7f..8106fbe2 100644 --- a/README.md +++ b/README.md @@ -236,6 +236,30 @@ curl -X POST http://localhost:8080/v1/chat/completions \ }' ``` +#### MCP Integration + +dnet exposes an MCP server at `/mcp` for use with Claude Desktop, Cursor, and other MCP clients. + +Add this to your MCP config: + +```json +{ + "mcpServers": { + "dnet": { + "command": "npx", + "args": [ + "-y", + "mcp-remote@latest", + "http://localhost:8080/mcp", + "--allow-http" + ] + } + } +} +``` + +Available tools: `chat_completion`, `load_model`, `unload_model`, `list_models`, `get_status`, `get_cluster_details`. + #### Devices You can get the list of discoverable devices with: diff --git a/pyproject.toml b/pyproject.toml index 33ba5f4e..3b397bc8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,6 +38,7 @@ dependencies = [ "dnet-p2p @ file://${PROJECT_ROOT}/lib/dnet-p2p/bindings/py", "rich>=13.0.0", "psutil>=5.9.0", + "fastmcp", ] [project.optional-dependencies] diff --git a/src/dnet/api/http_api.py b/src/dnet/api/http_api.py index 1035d00f..3dd690d3 100644 --- a/src/dnet/api/http_api.py +++ b/src/dnet/api/http_api.py @@ -25,6 +25,7 @@ from .inference import InferenceManager from .model_manager import ModelManager from dnet_p2p import DnetDeviceProperties +from .mcp_handler import create_mcp_server class HTTPServer: @@ -41,9 +42,22 @@ def __init__( self.inference_manager = inference_manager self.model_manager = model_manager self.node_id = node_id - self.app = FastAPI() self.http_server: Optional[asyncio.Task] = None + # Create MCP server first to get lifespan + mcp = create_mcp_server( + inference_manager, model_manager, cluster_manager + ) + # Use path='/' since we're mounting at /mcp, so final path will be /mcp/ + mcp_app = mcp.http_app(path="/") + + # Create FastAPI app with MCP lifespan + self.app = FastAPI(lifespan=mcp_app.lifespan) + + # Mount MCP server as ASGI app + self.app.mount("/mcp", mcp_app) + + async def start(self, shutdown_trigger: Any = lambda: asyncio.Future()) -> None: await self._setup_routes() diff --git a/src/dnet/api/mcp_handler.py b/src/dnet/api/mcp_handler.py new file mode 100644 index 00000000..61208171 --- /dev/null +++ b/src/dnet/api/mcp_handler.py @@ -0,0 +1,475 @@ +from collections import defaultdict +from fastmcp import FastMCP, Context +from fastmcp.server.middleware.error_handling import ErrorHandlingMiddleware +from starlette.responses import JSONResponse +from pydantic import ValidationError +from .models import ( + ChatRequestModel, + ChatMessage, + APILoadModelRequest, +) +from .inference import InferenceManager +from .model_manager import ModelManager +from .cluster import ClusterManager +from dnet.utils.logger import logger +from dnet.utils.model import get_model_config_json +from distilp.profiler import profile_model + + +class McpError(Exception): + """Custom MCP error with JSON-RPC 2.0 error codes. 
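+    Raised by tools, e.g. raise McpError(-32000, "No model loaded", data={"action": "load_model"}).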
+ - -32700: Parse error + - -32600: Invalid request + - -32601: Method not found + - -32602: Invalid params + - -32603: Internal error + - -32000 to -32099: Server error (implementation-specific) + - -32000: Service unavailable (used when no model is loaded) + - -32001: Request Timeout + - -32002: Resource not found + - -32800: Request cancelled + - -32801: Content too large +""" + + def __init__(self, code: int, message: str, data: dict | None = None): + self.code = code + self.message = message + self.data = data or {} + super().__init__(self.message) + + +def create_mcp_server( + inference_manager: InferenceManager, + model_manager: ModelManager, + cluster_manager: ClusterManager, +) -> FastMCP: + """Create and configure the MCP server for dnet.""" + + mcp = FastMCP("dnet") + mcp.add_middleware(ErrorHandlingMiddleware()) + @mcp.custom_route("/mcp-health", methods=["GET"]) + async def mcp_health_check(request): + """Health check endpoint for MCP server.""" + return JSONResponse({ + "status": "healthy", + "service": "dnet-mcp", + "model_loaded": model_manager.current_model_id is not None, + "model": model_manager.current_model_id, + "topology_configured": cluster_manager.current_topology is not None, + "shards_discovered": len(cluster_manager.shards) if cluster_manager.shards else 0, + }) + + @mcp.tool() + async def chat_completion( + messages: list[dict[str, str]], + model: str | None = None, + temperature: float = 1.0, + max_tokens: int = 2000, + top_p: float = 1.0, + top_k: int = -1, + stop: str | list[str] | None = None, + repetition_penalty: float = 1.0, + ctx: Context | None = None, + ) -> str: + """Generate text using distributed LLM inference. + Args: + messages: Array of message objects with 'role' and 'content' fields. + Each message should be a dict like: {"role": "user", "content": "Hello"} + model: Model name (optional, uses currently loaded model if not specified) + temperature: Sampling temperature (0-2), default is 1.0 + max_tokens: Maximum tokens to generate, default is 2000 + top_p: Nucleus sampling parameter (0-1), default is 1.0 + top_k: Top-k sampling parameter (-1 for disabled), default is -1 + stop: Stop sequences (string or list), default is None + repetition_penalty: Repetition penalty (>=0), default is 1.0 + """ + + if ctx: + await ctx.info("Starting inference...") + + if not model_manager.current_model_id: + raise McpError( + -32000, + "No model loaded. 
Please load a model first using load_model tool.", + data={"action": "load_model"} + ) + + model_id = model or model_manager.current_model_id + stops = [stop] if isinstance(stop, str) else (stop or []) + + try: + msgs = [ + ChatMessage(**msg) if isinstance(msg, dict) else msg + for msg in messages + ] + req = ChatRequestModel( + messages=msgs, + model=model_id, + temperature=temperature, + max_tokens=max_tokens, + top_p=top_p, + top_k=top_k, + stop=stops, + repetition_penalty=repetition_penalty, + stream=False, + ) + result = await inference_manager.chat_completions(req) + except ValidationError as e: + raise McpError( + -32602, + f"Invalid request parameters: {str(e)}", + data={"validation_errors": str(e)} + ) + except Exception as e: + logger.exception("Error in chat_completion: %s", e) + raise McpError( + -32603, + f"Inference failed: {str(e)}", + data={"model": model_id, "original_error": type(e).__name__} + ) + + if not result.choices or not result.choices[0].message: + raise McpError(-32603, "No content generated", data={"model": model_id}) + + text = result.choices[0].message.content or "" + if ctx: + await ctx.info("Inference completed successfully") + + return text + + @mcp.tool() + async def load_model( + model: str, + kv_bits: str = "8bit", + seq_len: int = 4096, + ctx: Context | None = None, + ) -> str: + """Load a model for distributed inference across the cluster. + + If a different model is already loaded, both models will stay in memory (old model + is not automatically unloaded). If the same model is already loaded, returns early. + Automatically prepares topology and discovers devices if needed. + + Args: + model: Model ID from catalog + kv_bits: KV cache quantization + seq_len: Sequence length + """ + try: + req = APILoadModelRequest( + model=model, + kv_bits=kv_bits, + seq_len=seq_len, + ) + if ctx: + await ctx.info(f"Starting to load model: {req.model}") + + if model_manager.current_model_id == req.model: + return f"Model '{req.model}' is already loaded." + + topology = cluster_manager.current_topology + if topology is None or topology.model != req.model: + if ctx: + await ctx.info("Preparing topology ...") + + await cluster_manager.scan_devices() + if not cluster_manager.shards: + raise McpError( + -32002, + "No shards discovered. Check shard connectivity.", + data={"action": "check_shard_connectivity"} + ) + + if ctx: + await ctx.info("Profiling cluster performance") + + model_config = get_model_config_json(req.model) + embedding_size = int(model_config["hidden_size"]) + num_layers = int(model_config["num_hidden_layers"]) + + batch_sizes = [1] + profiles = await cluster_manager.profile_cluster( + req.model, embedding_size, 2, batch_sizes + ) + if not profiles: + raise McpError( + -32603, + "Failed to collect device profiles. 
Check shard connectivity.", + data={ + "step": "profiling", + "shards_count": len(cluster_manager.shards) if cluster_manager.shards else 0 + } + ) + + if ctx: + await ctx.info("Computing optimal layer distribution") + + model_profile_split = profile_model( + repo_id=req.model, + batch_sizes=batch_sizes, + sequence_length=req.seq_len, + ) + model_profile = model_profile_split.to_model_profile() + + topology = await cluster_manager.solve_topology( + profiles, model_profile, req.model, num_layers, req.kv_bits + ) + cluster_manager.current_topology = topology + + if ctx: + await ctx.info("Topology prepared") + + if ctx: + await ctx.info("Loading model layers across shards...") + api_props = await cluster_manager.discovery.async_get_own_properties() + response = await model_manager.load_model( + topology, api_props, inference_manager.grpc_port + ) + + if not response.success: + error_msg = response.message or "Model loading failed" + shard_errors = [ + {"instance": s.instance, "message": s.message} + for s in response.shard_statuses + if not s.success + ] + raise McpError( + -32603, + f"Model loading failed: {error_msg}. " + f"{len(shard_errors)}/{len(response.shard_statuses)} shards failed.", + data={ + "model": req.model, + "shard_errors": shard_errors, + "failed_shards": len(shard_errors), + "total_shards": len(response.shard_statuses) + } + ) + + if topology.devices: + first_shard = topology.devices[0] + await inference_manager.connect_to_ring( + first_shard.local_ip, first_shard.shard_port, api_props.local_ip + ) + + if ctx: + await ctx.info(f"Model {req.model} loaded successfully across {len(response.shard_statuses)} shards") + + success_count = len([s for s in response.shard_statuses if s.success]) + return f"Model '{req.model}' loaded successfully. Loaded on {success_count}/{len(response.shard_statuses)} shards." + + except ValidationError as e: + raise McpError( + -32602, + f"Invalid load_model parameters: {str(e)}", + data={"validation_errors": str(e)} + ) + except McpError: + raise + except Exception as e: + logger.exception("Error in load_model: %s", e) + if ctx: + await ctx.error(f"Failed to load model: {str(e)}") + raise McpError( + -32603, + f"Failed to load model '{req.model}': {str(e)}", + data={"model": req.model, "original_error": type(e).__name__} + ) + + @mcp.tool() + async def unload_model(ctx: Context | None = None) -> str: + """Unload the currently loaded model to free memory. + Unloads the model from all shards and clears the topology. If no model is loaded, returns early. + """ + if not model_manager.current_model_id: + return "No model is currently loaded." + + model_name = model_manager.current_model_id + if ctx: + await ctx.info(f"Unloading model: {model_name}") + + await cluster_manager.scan_devices() + shards = cluster_manager.shards + response = await model_manager.unload_model(shards) + + if response.success: + cluster_manager.current_topology = None + if ctx: + await ctx.info("Model unloaded successfully") + return f"Model '{model_name}' unloaded successfully from all shards." 
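+        # Any shard that failed to unload is reported below via McpError(-32603) with per-shard details.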
+ else: + shard_errors = [ + {"instance": s.instance, "message": s.message} + for s in response.shard_statuses + if not s.success + ] + raise McpError( + -32603, + "Model unloading failed", + data={ + "model": model_name, + "shard_errors": shard_errors, + "failed_shards": len(shard_errors), + "total_shards": len(response.shard_statuses) + } + ) + + # Resources (for MCP protocol compliance) + @mcp.resource("mcp://dnet/models") + async def get_available_models() -> str: + """List of models available in dnet catalog, organized by family and quantization.""" + return await _get_available_models_data() + + @mcp.resource("mcp://dnet/status") + async def get_model_status() -> str: + """Currently loaded model and cluster status information.""" + return await _get_model_status_data() + + @mcp.resource("mcp://dnet/cluster") + async def get_cluster_info() -> str: + """Detailed cluster information including devices and topology.""" + return await _get_cluster_info_data() + + # Tools that wrap resources (for Claude Desktop compatibility) + @mcp.tool() + async def list_models() -> str: + """List all available models in the dnet catalog. + + Returns a formatted list of models organized by family and quantization. + Use this to see what models you can load. + """ + return await _get_available_models_data() + + @mcp.tool() + async def get_status() -> str: + """Get the current status of dnet including loaded model, topology, and cluster information. + + Returns detailed status about: + - Currently loaded model (if any) + - Topology configuration + - Discovered shards in the cluster + """ + return await _get_model_status_data() + + @mcp.tool() + async def get_cluster_details() -> str: + """Get detailed cluster information including shard details and topology breakdown. 
+ + Returns comprehensive information about: + - All discovered shards with their IPs and ports + - Current topology configuration + - Layer assignments across devices + """ + return await _get_cluster_info_data() + + + async def _get_available_models_data() -> str: + models_by_family = defaultdict(list) + for model in model_manager.available_models: + models_by_family[model.alias].append(model) + + output_lines = ["Available Models in dnet Catalog:\n"] + output_lines.append("=" * 60) + + for family_name in sorted(models_by_family.keys()): + models = sorted(models_by_family[family_name], key=lambda m: m.id) + output_lines.append(f"\n{family_name.upper()}") + output_lines.append("-" * 60) + + by_quant = defaultdict(list) + for model in models: + by_quant[model.quantization].append(model) + + for quant in ["bf16", "fp16", "8bit", "4bit"]: + if quant in by_quant: + quant_models = by_quant[quant] + quant_display = { + "bf16": "BF16 (Full precision)", + "fp16": "FP16 (Full precision)", + "8bit": "8-bit quantized", + "4bit": "4-bit quantized (smallest)", + }.get(quant, quant) + output_lines.append(f" {quant_display}:") + for model in quant_models: + output_lines.append(f" - {model.id}") + + output_lines.append("\n" + "=" * 60) + output_lines.append(f"\nTotal: {len(model_manager.available_models)} models") + output_lines.append("\nTo load a model, use the load_model tool with the full model ID.") + + return "\n".join(output_lines) + + async def _get_model_status_data() -> str: + status_lines = ["dnet Status"] + status_lines.append("=" * 60) + + if model_manager.current_model_id: + status_lines.append(f"\n Model Loaded: {model_manager.current_model_id}") + else: + status_lines.append("\n No Model Loaded") + + topology = cluster_manager.current_topology + if topology: + status_lines.append(f"\n Topology:\n Model: {topology.model}\n Devices: {len(topology.devices)}\n Layers: {topology.num_layers}\n KV Cache: {topology.kv_bits}") + + if topology.assignments: + status_lines.append(f"\n Layer Distribution:") + for assignment in topology.assignments: + layers_str = ", ".join( + f"{r[0]}-{r[-1]}" if len(r) > 1 else str(r[0]) + for r in assignment.layers + ) + status_lines.append( + f" {assignment.instance}: layers [{layers_str}]" + ) + else: + status_lines.append("\n Topology: Not configured") + + shards = cluster_manager.shards + if shards: + shard_names = ", ".join(sorted(shards.keys())) + status_lines.append(f"\n Cluster:\n Discovered Shards: {len(shards)}\n Shard Names: {shard_names}") + else: + status_lines.append("\n Cluster: No shards discovered") + + status_lines.append("\n" + "=" * 60) + + return "\n".join(status_lines) + + async def _get_cluster_info_data() -> str: + output_lines = ["dnet Cluster Information"] + output_lines.append("=" * 60) + + shards = cluster_manager.shards + if shards: + output_lines.append(f"\n Shards ({len(shards)}):") + for name, props in sorted(shards.items()): + output_lines.append(f"\n {name}:\n IP: {props.local_ip}\n HTTP Port: {props.server_port}\n gRPC Port: {props.shard_port}\n Manager: {'Yes' if props.is_manager else 'No'}\n Busy: {'Yes' if props.is_busy else 'No'}") + else: + output_lines.append("\n No shards discovered") + + topology = cluster_manager.current_topology + if topology: + output_lines.append(f"\n Topology:\n Model: {topology.model}\n Total Layers: {topology.num_layers}\n KV Cache Bits: {topology.kv_bits}\n Devices: {len(topology.devices)}") + + if topology.assignments: + output_lines.append(f"\n Layer Assignments:") + for assignment in 
topology.assignments: + layers_flat = [ + layer + for round_layers in assignment.layers + for layer in round_layers + ] + layers_str = ", ".join(map(str, sorted(layers_flat))) + output_lines.append( + f" {assignment.instance}: [{layers_str}] " + f"(window={assignment.window_size}, " + f"next={assignment.next_instance or 'N/A'})" + ) + else: + output_lines.append("\n No topology configured") + + output_lines.append("\n" + "=" * 60) + + return "\n".join(output_lines) + + return mcp From 93b3930b7d87c63881c9d6a217f0e59d2342cc7a Mon Sep 17 00:00:00 2001 From: Yuvraj Singh Date: Fri, 19 Dec 2025 21:36:42 +0530 Subject: [PATCH 2/2] Add MCP integration improvements and refactoring --- pyproject.toml | 5 +- src/dnet/api/http_api.py | 86 ++---- src/dnet/api/load_helpers.py | 98 ++++++ src/dnet/api/mcp_handler.py | 360 ++++++++-------------- tests/integration/test_mcp_integration.py | 289 +++++++++++++++++ tests/subsystems/test_api_http_server.py | 6 +- 6 files changed, 550 insertions(+), 294 deletions(-) create mode 100644 src/dnet/api/load_helpers.py create mode 100644 tests/integration/test_mcp_integration.py diff --git a/pyproject.toml b/pyproject.toml index 3b397bc8..1723ff32 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,7 +38,7 @@ dependencies = [ "dnet-p2p @ file://${PROJECT_ROOT}/lib/dnet-p2p/bindings/py", "rich>=13.0.0", "psutil>=5.9.0", - "fastmcp", + "fastmcp==2.13.0", ] [project.optional-dependencies] @@ -48,6 +48,7 @@ cuda = ["mlx[cuda]"] dev = [ "openai>=2.6.0", # for OpenAI compatibility tests "pytest>=8.4.2", + "pytest-asyncio>=0.24.0", "mypy>=1.3.0", # Type checking "ruff>=0.0.285", "types-psutil>=7.1.3", @@ -68,6 +69,7 @@ python_files = ["test_*.py", "*_test.py"] testpaths = ["tests"] python_functions = ["test_"] log_cli = true +asyncio_mode = "auto" markers = [ "api: tests for API node components (HTTP, gRPC, managers)", "shard: tests for Shard node components (HTTP, gRPC, runtime, policies, ring)", @@ -80,6 +82,7 @@ markers = [ "core: tests for core memory/cache/utils not tied to api/shard", "e2e: integration tests requiring live servers or multiple components", "integration: model catalog integration tests for CI (manual trigger)", + "mcp: tests for MCP handler tools and server integration", ] [tool.ruff] diff --git a/src/dnet/api/http_api.py b/src/dnet/api/http_api.py index 3dd690d3..34d8e7c1 100644 --- a/src/dnet/api/http_api.py +++ b/src/dnet/api/http_api.py @@ -1,6 +1,5 @@ from typing import Optional, Any, List import asyncio -import os from hypercorn import Config from hypercorn.utils import LifespanFailureError import hypercorn.asyncio as aio_hypercorn @@ -26,6 +25,11 @@ from .model_manager import ModelManager from dnet_p2p import DnetDeviceProperties from .mcp_handler import create_mcp_server +from .load_helpers import ( + _prepare_topology_core, + _load_model_core, + _unload_model_core, +) class HTTPServer: @@ -45,19 +49,16 @@ def __init__( self.http_server: Optional[asyncio.Task] = None # Create MCP server first to get lifespan - mcp = create_mcp_server( - inference_manager, model_manager, cluster_manager - ) + mcp = create_mcp_server(inference_manager, model_manager, cluster_manager) # Use path='/' since we're mounting at /mcp, so final path will be /mcp/ mcp_app = mcp.http_app(path="/") - + # Create FastAPI app with MCP lifespan self.app = FastAPI(lifespan=mcp_app.lifespan) - + # Mount MCP server as ASGI app self.app.mount("/mcp", mcp_app) - async def start(self, shutdown_trigger: Any = lambda: asyncio.Future()) -> None: await self._setup_routes() @@ -166,59 
+167,27 @@ async def load_model(self, req: APILoadModelRequest) -> APILoadModelResponse: ), ) - model_config = get_model_config_json(req.model) - embedding_size = int(model_config["hidden_size"]) - num_layers = int(model_config["num_hidden_layers"]) - - await self.cluster_manager.scan_devices() - batch_sizes = [1] - profiles = await self.cluster_manager.profile_cluster( - req.model, embedding_size, 2, batch_sizes - ) - if not profiles: - return APILoadModelResponse( - model=req.model, - success=False, - shard_statuses=[], - message="No profiles collected", + try: + topology = await _prepare_topology_core( + self.cluster_manager, req.model, req.kv_bits, req.seq_len ) - - model_profile_split = profile_model( - repo_id=req.model, - batch_sizes=batch_sizes, - sequence_length=req.seq_len, - ) - model_profile = model_profile_split.to_model_profile() - topology = await self.cluster_manager.solve_topology( - profiles, model_profile, req.model, num_layers, req.kv_bits - ) + except RuntimeError as e: + if "No profiles collected" in str(e): + return APILoadModelResponse( + model=req.model, + success=False, + shard_statuses=[], + message="No profiles collected", + ) + raise self.cluster_manager.current_topology = topology - api_props = await self.cluster_manager.discovery.async_get_own_properties() - grpc_port = int(self.inference_manager.grpc_port) - - # Callback address shards should use for SendToken. - # In static discovery / cloud setups, discovery may report 127.0.0.1 which is not usable. - api_callback_addr = (os.getenv("DNET_API_CALLBACK_ADDR") or "").strip() - if not api_callback_addr: - api_callback_addr = f"{api_props.local_ip}:{grpc_port}" - if api_props.local_ip in ("127.0.0.1", "localhost"): - logger.warning( - "API callback address is loopback (%s). Remote shards will fail to SendToken. 
" - "Set DNET_API_CALLBACK_ADDR to a reachable host:port.", - api_callback_addr, - ) - response = await self.model_manager.load_model( + response = await _load_model_core( + self.cluster_manager, + self.model_manager, + self.inference_manager, topology, - api_props, - self.inference_manager.grpc_port, - api_callback_address=api_callback_addr, ) - if response.success: - first_shard = topology.devices[0] - await self.inference_manager.connect_to_ring( - first_shard.local_ip, first_shard.shard_port, api_callback_addr - ) return response except Exception as e: @@ -231,12 +200,7 @@ async def load_model(self, req: APILoadModelRequest) -> APILoadModelResponse: ) async def unload_model(self) -> UnloadModelResponse: - await self.cluster_manager.scan_devices() - shards = self.cluster_manager.shards - response = await self.model_manager.unload_model(shards) - if response.success: - self.cluster_manager.current_topology = None - return response + return await _unload_model_core(self.cluster_manager, self.model_manager) async def get_devices(self) -> JSONResponse: devices = await self.cluster_manager.discovery.async_get_properties() diff --git a/src/dnet/api/load_helpers.py b/src/dnet/api/load_helpers.py new file mode 100644 index 00000000..73cf666f --- /dev/null +++ b/src/dnet/api/load_helpers.py @@ -0,0 +1,98 @@ +import os +from dnet.utils.logger import logger +from dnet.utils.model import get_model_config_json +from distilp.profiler import profile_model +from dnet.core.types.topology import TopologyInfo +from .models import APILoadModelResponse, UnloadModelResponse + + +async def get_api_callback_address( + cluster_manager, + grpc_port: int | str, +) -> str: + api_props = await cluster_manager.discovery.async_get_own_properties() + grpc_port_int = int(grpc_port) + api_callback_addr = (os.getenv("DNET_API_CALLBACK_ADDR") or "").strip() + if not api_callback_addr: + api_callback_addr = f"{api_props.local_ip}:{grpc_port_int}" + if api_props.local_ip in ("127.0.0.1", "localhost"): + logger.warning( + "API callback address is loopback (%s). Remote shards will fail to SendToken. 
" + "Set DNET_API_CALLBACK_ADDR to a reachable host:port.", + api_callback_addr, + ) + return api_callback_addr + + +async def _prepare_topology_core( + cluster_manager, + model: str, + kv_bits: str, + seq_len: int, + progress_callback=None, +) -> TopologyInfo: + model_config = get_model_config_json(model) + embedding_size = int(model_config["hidden_size"]) + num_layers = int(model_config["num_hidden_layers"]) + + await cluster_manager.scan_devices() + if progress_callback: + await progress_callback("Profiling cluster performance") + batch_sizes = [1] + profiles = await cluster_manager.profile_cluster( + model, embedding_size, 2, batch_sizes + ) + if not profiles: + raise RuntimeError("No profiles collected") + + if progress_callback: + await progress_callback("Computing optimal layer distribution") + model_profile_split = profile_model( + repo_id=model, + batch_sizes=batch_sizes, + sequence_length=seq_len, + ) + model_profile = model_profile_split.to_model_profile() + + topology = await cluster_manager.solve_topology( + profiles, model_profile, model, num_layers, kv_bits + ) + return topology + + +async def _load_model_core( + cluster_manager, + model_manager, + inference_manager, + topology: TopologyInfo, +) -> APILoadModelResponse: + api_props = await cluster_manager.discovery.async_get_own_properties() + grpc_port = int(inference_manager.grpc_port) + + api_callback_addr = await get_api_callback_address( + cluster_manager, inference_manager.grpc_port + ) + response = await model_manager.load_model( + topology, + api_props, + grpc_port, + api_callback_address=api_callback_addr, + ) + if response.success: + first_shard = topology.devices[0] + await inference_manager.connect_to_ring( + first_shard.local_ip, first_shard.shard_port, api_callback_addr + ) + return response + + +async def _unload_model_core( + cluster_manager, + model_manager, +) -> UnloadModelResponse: + await cluster_manager.scan_devices() + shards = cluster_manager.shards + response = await model_manager.unload_model(shards) + if response.success: + cluster_manager.current_topology = None + return response diff --git a/src/dnet/api/mcp_handler.py b/src/dnet/api/mcp_handler.py index 61208171..a7a5b46f 100644 --- a/src/dnet/api/mcp_handler.py +++ b/src/dnet/api/mcp_handler.py @@ -1,4 +1,4 @@ -from collections import defaultdict +import json from fastmcp import FastMCP, Context from fastmcp.server.middleware.error_handling import ErrorHandlingMiddleware from starlette.responses import JSONResponse @@ -12,12 +12,15 @@ from .model_manager import ModelManager from .cluster import ClusterManager from dnet.utils.logger import logger -from dnet.utils.model import get_model_config_json -from distilp.profiler import profile_model +from .load_helpers import ( + _prepare_topology_core, + _load_model_core, + _unload_model_core, +) class McpError(Exception): - """Custom MCP error with JSON-RPC 2.0 error codes. + """Custom MCP error with JSON-RPC 2.0 error codes. 
- -32700: Parse error - -32600: Invalid request - -32601: Method not found @@ -29,8 +32,8 @@ class McpError(Exception): - -32002: Resource not found - -32800: Request cancelled - -32801: Content too large -""" - + """ + def __init__(self, code: int, message: str, data: dict | None = None): self.code = code self.message = message @@ -47,17 +50,22 @@ def create_mcp_server( mcp = FastMCP("dnet") mcp.add_middleware(ErrorHandlingMiddleware()) + @mcp.custom_route("/mcp-health", methods=["GET"]) async def mcp_health_check(request): """Health check endpoint for MCP server.""" - return JSONResponse({ - "status": "healthy", - "service": "dnet-mcp", - "model_loaded": model_manager.current_model_id is not None, - "model": model_manager.current_model_id, - "topology_configured": cluster_manager.current_topology is not None, - "shards_discovered": len(cluster_manager.shards) if cluster_manager.shards else 0, - }) + return JSONResponse( + { + "status": "healthy", + "service": "dnet-mcp", + "model_loaded": model_manager.current_model_id is not None, + "model": model_manager.current_model_id, + "topology_configured": cluster_manager.current_topology is not None, + "shards_discovered": len(cluster_manager.shards) + if cluster_manager.shards + else 0, + } + ) @mcp.tool() async def chat_completion( @@ -83,7 +91,7 @@ async def chat_completion( stop: Stop sequences (string or list), default is None repetition_penalty: Repetition penalty (>=0), default is 1.0 """ - + if ctx: await ctx.info("Starting inference...") @@ -91,7 +99,7 @@ async def chat_completion( raise McpError( -32000, "No model loaded. Please load a model first using load_model tool.", - data={"action": "load_model"} + data={"action": "load_model"}, ) model_id = model or model_manager.current_model_id @@ -99,8 +107,7 @@ async def chat_completion( try: msgs = [ - ChatMessage(**msg) if isinstance(msg, dict) else msg - for msg in messages + ChatMessage(**msg) if isinstance(msg, dict) else msg for msg in messages ] req = ChatRequestModel( messages=msgs, @@ -110,27 +117,27 @@ async def chat_completion( top_p=top_p, top_k=top_k, stop=stops, - repetition_penalty=repetition_penalty, + repetition_penalty=repetition_penalty, stream=False, - ) + ) result = await inference_manager.chat_completions(req) except ValidationError as e: raise McpError( -32602, f"Invalid request parameters: {str(e)}", - data={"validation_errors": str(e)} + data={"validation_errors": str(e)}, ) except Exception as e: logger.exception("Error in chat_completion: %s", e) raise McpError( -32603, f"Inference failed: {str(e)}", - data={"model": model_id, "original_error": type(e).__name__} + data={"model": model_id, "original_error": type(e).__name__}, ) if not result.choices or not result.choices[0].message: raise McpError(-32603, "No content generated", data={"model": model_id}) - + text = result.choices[0].message.content or "" if ctx: await ctx.info("Inference completed successfully") @@ -152,78 +159,54 @@ async def load_model( Args: model: Model ID from catalog - kv_bits: KV cache quantization - seq_len: Sequence length + kv_bits: KV cache quantization mode for the model's KV cache, the default is "8bit". + seq_len: Maximum sequence length (in tokens). defaults to 4096. """ try: req = APILoadModelRequest( model=model, - kv_bits=kv_bits, + kv_bits=kv_bits, seq_len=seq_len, ) if ctx: await ctx.info(f"Starting to load model: {req.model}") - if model_manager.current_model_id == req.model: - return f"Model '{req.model}' is already loaded." 
- topology = cluster_manager.current_topology - if topology is None or topology.model != req.model: + if topology is None: if ctx: await ctx.info("Preparing topology ...") - - await cluster_manager.scan_devices() - if not cluster_manager.shards: - raise McpError( - -32002, - "No shards discovered. Check shard connectivity.", - data={"action": "check_shard_connectivity"} + try: + topology = await _prepare_topology_core( + cluster_manager, + req.model, + req.kv_bits, + req.seq_len, + progress_callback=ctx.info if ctx else None, ) - - if ctx: - await ctx.info("Profiling cluster performance") - - model_config = get_model_config_json(req.model) - embedding_size = int(model_config["hidden_size"]) - num_layers = int(model_config["num_hidden_layers"]) - - batch_sizes = [1] - profiles = await cluster_manager.profile_cluster( - req.model, embedding_size, 2, batch_sizes - ) - if not profiles: - raise McpError( - -32603, - "Failed to collect device profiles. Check shard connectivity.", - data={ - "step": "profiling", - "shards_count": len(cluster_manager.shards) if cluster_manager.shards else 0 - } - ) - - if ctx: - await ctx.info("Computing optimal layer distribution") - - model_profile_split = profile_model( - repo_id=req.model, - batch_sizes=batch_sizes, - sequence_length=req.seq_len, - ) - model_profile = model_profile_split.to_model_profile() - - topology = await cluster_manager.solve_topology( - profiles, model_profile, req.model, num_layers, req.kv_bits - ) + except RuntimeError as e: + if "No profiles collected" in str(e): + raise McpError( + -32603, + "Failed to collect device profiles. Check shard connectivity.", + data={ + "step": "profiling", + "shards_count": len(cluster_manager.shards) + if cluster_manager.shards + else 0, + }, + ) + raise cluster_manager.current_topology = topology - if ctx: await ctx.info("Topology prepared") if ctx: await ctx.info("Loading model layers across shards...") - api_props = await cluster_manager.discovery.async_get_own_properties() - response = await model_manager.load_model( - topology, api_props, inference_manager.grpc_port + response = await _load_model_core( + cluster_manager, + model_manager, + inference_manager, + topology, ) if not response.success: @@ -235,24 +218,19 @@ async def load_model( ] raise McpError( -32603, - f"Model loading failed: {error_msg}. " - f"{len(shard_errors)}/{len(response.shard_statuses)} shards failed.", + f"Model loading failed: {error_msg}. {len(shard_errors)}/{len(response.shard_statuses)} shards failed.", data={ "model": req.model, "shard_errors": shard_errors, "failed_shards": len(shard_errors), - "total_shards": len(response.shard_statuses) - } - ) - - if topology.devices: - first_shard = topology.devices[0] - await inference_manager.connect_to_ring( - first_shard.local_ip, first_shard.shard_port, api_props.local_ip + "total_shards": len(response.shard_statuses), + }, ) if ctx: - await ctx.info(f"Model {req.model} loaded successfully across {len(response.shard_statuses)} shards") + await ctx.info( + f"Model {req.model} loaded successfully across {len(response.shard_statuses)} shards" + ) success_count = len([s for s in response.shard_statuses if s.success]) return f"Model '{req.model}' loaded successfully. Loaded on {success_count}/{len(response.shard_statuses)} shards." 
@@ -261,7 +239,7 @@ async def load_model( raise McpError( -32602, f"Invalid load_model parameters: {str(e)}", - data={"validation_errors": str(e)} + data={"validation_errors": str(e)}, ) except McpError: raise @@ -271,8 +249,8 @@ async def load_model( await ctx.error(f"Failed to load model: {str(e)}") raise McpError( -32603, - f"Failed to load model '{req.model}': {str(e)}", - data={"model": req.model, "original_error": type(e).__name__} + f"Failed to load model '{model}': {str(e)}", + data={"model": model, "original_error": type(e).__name__}, ) @mcp.tool() @@ -287,12 +265,9 @@ async def unload_model(ctx: Context | None = None) -> str: if ctx: await ctx.info(f"Unloading model: {model_name}") - await cluster_manager.scan_devices() - shards = cluster_manager.shards - response = await model_manager.unload_model(shards) + response = await _unload_model_core(cluster_manager, model_manager) if response.success: - cluster_manager.current_topology = None if ctx: await ctx.info("Model unloaded successfully") return f"Model '{model_name}' unloaded successfully from all shards." @@ -309,167 +284,94 @@ async def unload_model(ctx: Context | None = None) -> str: "model": model_name, "shard_errors": shard_errors, "failed_shards": len(shard_errors), - "total_shards": len(response.shard_statuses) - } + "total_shards": len(response.shard_statuses), + }, ) # Resources (for MCP protocol compliance) @mcp.resource("mcp://dnet/models") - async def get_available_models() -> str: - """List of models available in dnet catalog, organized by family and quantization.""" - return await _get_available_models_data() + def get_available_models() -> str: + """List of models available in dnet catalog as JSON.""" + return _get_available_models_data() @mcp.resource("mcp://dnet/status") - async def get_model_status() -> str: - """Currently loaded model and cluster status information.""" - return await _get_model_status_data() + def get_model_status() -> str: + """Currently loaded model and cluster status as JSON.""" + return _get_model_status_data() @mcp.resource("mcp://dnet/cluster") - async def get_cluster_info() -> str: - """Detailed cluster information including devices and topology.""" - return await _get_cluster_info_data() + def get_cluster_info() -> str: + """Cluster information including devices and topology as JSON.""" + return _get_cluster_info_data() # Tools that wrap resources (for Claude Desktop compatibility) @mcp.tool() - async def list_models() -> str: + def list_models() -> str: """List all available models in the dnet catalog. - Returns a formatted list of models organized by family and quantization. - Use this to see what models you can load. + Returns JSON with model IDs, aliases, and quantization info. """ - return await _get_available_models_data() + return _get_available_models_data() @mcp.tool() - async def get_status() -> str: - """Get the current status of dnet including loaded model, topology, and cluster information. + def get_status() -> str: + """Get the current status of dnet. - Returns detailed status about: - - Currently loaded model (if any) - - Topology configuration - - Discovered shards in the cluster + Returns JSON with loaded model, topology, and shard count. """ - return await _get_model_status_data() + return _get_model_status_data() @mcp.tool() - async def get_cluster_details() -> str: - """Get detailed cluster information including shard details and topology breakdown. + def get_cluster_details() -> str: + """Get detailed cluster information. 
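+        Wraps the mcp://dnet/cluster resource as a callable tool, for MCP clients that only use tools.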
- Returns comprehensive information about: - - All discovered shards with their IPs and ports - - Current topology configuration - - Layer assignments across devices + Returns JSON with all discovered shards and current topology. """ - return await _get_cluster_info_data() - - - async def _get_available_models_data() -> str: - models_by_family = defaultdict(list) - for model in model_manager.available_models: - models_by_family[model.alias].append(model) - - output_lines = ["Available Models in dnet Catalog:\n"] - output_lines.append("=" * 60) - - for family_name in sorted(models_by_family.keys()): - models = sorted(models_by_family[family_name], key=lambda m: m.id) - output_lines.append(f"\n{family_name.upper()}") - output_lines.append("-" * 60) - - by_quant = defaultdict(list) - for model in models: - by_quant[model.quantization].append(model) - - for quant in ["bf16", "fp16", "8bit", "4bit"]: - if quant in by_quant: - quant_models = by_quant[quant] - quant_display = { - "bf16": "BF16 (Full precision)", - "fp16": "FP16 (Full precision)", - "8bit": "8-bit quantized", - "4bit": "4-bit quantized (smallest)", - }.get(quant, quant) - output_lines.append(f" {quant_display}:") - for model in quant_models: - output_lines.append(f" - {model.id}") - - output_lines.append("\n" + "=" * 60) - output_lines.append(f"\nTotal: {len(model_manager.available_models)} models") - output_lines.append("\nTo load a model, use the load_model tool with the full model ID.") - - return "\n".join(output_lines) - - async def _get_model_status_data() -> str: - status_lines = ["dnet Status"] - status_lines.append("=" * 60) - - if model_manager.current_model_id: - status_lines.append(f"\n Model Loaded: {model_manager.current_model_id}") - else: - status_lines.append("\n No Model Loaded") + return _get_cluster_info_data() + + def _get_available_models_data() -> str: + """Return available models as JSON (same format as /v1/models endpoint).""" + return json.dumps( + { + "object": "list", + "data": [m.model_dump() for m in model_manager.available_models], + } + ) + def _get_model_status_data() -> str: + """Return current status as JSON.""" topology = cluster_manager.current_topology - if topology: - status_lines.append(f"\n Topology:\n Model: {topology.model}\n Devices: {len(topology.devices)}\n Layers: {topology.num_layers}\n KV Cache: {topology.kv_bits}") - - if topology.assignments: - status_lines.append(f"\n Layer Distribution:") - for assignment in topology.assignments: - layers_str = ", ".join( - f"{r[0]}-{r[-1]}" if len(r) > 1 else str(r[0]) - for r in assignment.layers - ) - status_lines.append( - f" {assignment.instance}: layers [{layers_str}]" - ) - else: - status_lines.append("\n Topology: Not configured") - - shards = cluster_manager.shards - if shards: - shard_names = ", ".join(sorted(shards.keys())) - status_lines.append(f"\n Cluster:\n Discovered Shards: {len(shards)}\n Shard Names: {shard_names}") - else: - status_lines.append("\n Cluster: No shards discovered") - - status_lines.append("\n" + "=" * 60) - - return "\n".join(status_lines) - - async def _get_cluster_info_data() -> str: - output_lines = ["dnet Cluster Information"] - output_lines.append("=" * 60) + return json.dumps( + { + "model_loaded": model_manager.current_model_id, + "topology": topology.model_dump() if topology else None, + "shards_discovered": len(cluster_manager.shards) + if cluster_manager.shards + else 0, + } + ) + def _get_cluster_info_data() -> str: + """Return cluster information as JSON (same format as /v1/devices endpoint).""" 
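+        # Illustrative shape: {"devices": {name: {instance, local_ip, server_port, shard_port, is_manager, is_busy}}, "topology": dict | None}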
shards = cluster_manager.shards - if shards: - output_lines.append(f"\n Shards ({len(shards)}):") - for name, props in sorted(shards.items()): - output_lines.append(f"\n {name}:\n IP: {props.local_ip}\n HTTP Port: {props.server_port}\n gRPC Port: {props.shard_port}\n Manager: {'Yes' if props.is_manager else 'No'}\n Busy: {'Yes' if props.is_busy else 'No'}") - else: - output_lines.append("\n No shards discovered") - topology = cluster_manager.current_topology - if topology: - output_lines.append(f"\n Topology:\n Model: {topology.model}\n Total Layers: {topology.num_layers}\n KV Cache Bits: {topology.kv_bits}\n Devices: {len(topology.devices)}") - - if topology.assignments: - output_lines.append(f"\n Layer Assignments:") - for assignment in topology.assignments: - layers_flat = [ - layer - for round_layers in assignment.layers - for layer in round_layers - ] - layers_str = ", ".join(map(str, sorted(layers_flat))) - output_lines.append( - f" {assignment.instance}: [{layers_str}] " - f"(window={assignment.window_size}, " - f"next={assignment.next_instance or 'N/A'})" - ) - else: - output_lines.append("\n No topology configured") - - output_lines.append("\n" + "=" * 60) - - return "\n".join(output_lines) + return json.dumps( + { + "devices": { + name: { + "instance": props.instance, + "local_ip": props.local_ip, + "server_port": props.server_port, + "shard_port": props.shard_port, + "is_manager": props.is_manager, + "is_busy": props.is_busy, + } + for name, props in shards.items() + } + if shards + else {}, + "topology": topology.model_dump() if topology else None, + } + ) return mcp diff --git a/tests/integration/test_mcp_integration.py b/tests/integration/test_mcp_integration.py new file mode 100644 index 00000000..2abdf48d --- /dev/null +++ b/tests/integration/test_mcp_integration.py @@ -0,0 +1,289 @@ +"""Integration tests for MCP server. + +These tests validate that MCP tools work end-to-end through HTTP endpoint. 
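+They talk to the FastMCP streamable HTTP transport that the API node mounts at /mcp.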
+ +Usage (with servers running): + uv run pytest tests/integration/test_mcp_integration.py -v -x + +Usage (standalone - starts servers automatically): + uv run pytest tests/integration/test_mcp_integration.py -v -x --start-servers + +Usage (in CI - expects servers started externally): + uv run pytest tests/integration/test_mcp_integration.py -m integration -v -x +""" + +import json +import logging +import os +import signal +import subprocess +import sys +import time +from typing import Any, Generator + +import pytest +import requests + +from dnet.api.catalog import get_ci_test_models + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +API_HTTP_PORT = 8080 +API_GRPC_PORT = 58080 +SHARD_HTTP_PORT = 8081 +SHARD_GRPC_PORT = 58081 +BASE_URL = f"http://localhost:{API_HTTP_PORT}" +MCP_URL = f"{BASE_URL}/mcp" + +# Timeouts +HEALTH_CHECK_TIMEOUT = 60 # seconds to wait for servers to start +MODEL_LOAD_TIMEOUT = 300 # seconds to wait for model loading +INFERENCE_TIMEOUT = 120 # seconds for inference + + +def wait_for_health(url: str, timeout: float = HEALTH_CHECK_TIMEOUT) -> bool: + deadline = time.time() + timeout + while time.time() < deadline: + try: + resp = requests.get(f"{url}/health", timeout=2) + if resp.status_code == 200: + return True + except requests.RequestException: + pass + time.sleep(0.5) + return False + + +@pytest.fixture(scope="module") +def servers(start_servers_flag) -> Generator[None, None, None]: + procs: list[subprocess.Popen] = [] + + if start_servers_flag: + shard_cmd = [ + sys.executable, + "-m", + "cli.shard", + "--http-port", + str(SHARD_HTTP_PORT), + "--grpc-port", + str(SHARD_GRPC_PORT), + ] + shard_proc = subprocess.Popen( + shard_cmd, + cwd=os.path.dirname(os.path.dirname(os.path.dirname(__file__))), + env={**os.environ, "PYTHONPATH": "src"}, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + procs.append(shard_proc) + + if not wait_for_health(f"http://localhost:{SHARD_HTTP_PORT}", timeout=30): + shard_proc.terminate() + try: + shard_proc.wait(timeout=5) + except subprocess.TimeoutExpired: + shard_proc.kill() + shard_proc.wait() + pytest.skip(f"Shard server not healthy at port {SHARD_HTTP_PORT}") + + api_cmd = [ + sys.executable, + "-m", + "cli.api", + "--http-port", + str(API_HTTP_PORT), + "--grpc-port", + str(API_GRPC_PORT), + ] + api_proc = subprocess.Popen( + api_cmd, + cwd=os.path.dirname(os.path.dirname(os.path.dirname(__file__))), + env={**os.environ, "PYTHONPATH": "src"}, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + procs.append(api_proc) + + if not wait_for_health(BASE_URL): + for p in procs: + p.terminate() + try: + p.wait(timeout=5) + except subprocess.TimeoutExpired: + p.kill() + p.wait() + pytest.skip(f"Server not healthy at {BASE_URL}/health") + + # When starting servers automatically, wait for P2P discovery to find shards + # This is needed because MCP's load_model will try to profile immediately + if start_servers_flag: + if not wait_for_shards_discovered(BASE_URL, timeout=30): + for p in procs: + p.terminate() + try: + p.wait(timeout=5) + except subprocess.TimeoutExpired: + p.kill() + p.wait() + pytest.skip("Shards not discovered within timeout") + + yield + + for p in procs: + p.send_signal(signal.SIGTERM) + try: + p.wait(timeout=10) + except subprocess.TimeoutExpired: + p.kill() + p.wait() + + +def mcp_call_tool( + tool_name: str, arguments: dict[str, Any], timeout: float | None = None +) -> Any: + """Call an MCP tool via HTTP transport. 
+ + Args: + tool_name: Name of the MCP tool to call + arguments: Arguments to pass to the tool + timeout: Optional timeout in seconds. If None, no timeout is applied. + """ + try: + from fastmcp.client import Client + from fastmcp.client.transports import StreamableHttpTransport + except ImportError: + pytest.skip("fastmcp not available") + + async def _call(): + async with Client(transport=StreamableHttpTransport(MCP_URL)) as client: + return await client.call_tool(name=tool_name, arguments=arguments) + + import asyncio + + if timeout is not None: + return asyncio.run(asyncio.wait_for(_call(), timeout=timeout)) + else: + return asyncio.run(_call()) + + +def wait_for_shards_discovered(base_url: str, timeout: float = 30) -> bool: + """Wait for at least one shard to be discovered via P2P discovery.""" + deadline = time.time() + timeout + while time.time() < deadline: + try: + resp = requests.get(f"{base_url}/v1/devices", timeout=2) + if resp.status_code == 200: + data = resp.json() + devices = data.get("devices", {}) + # Check if we have any non-manager devices (shards) + shard_count = sum( + 1 + for props in devices.values() + if not props.get("is_manager", False) + ) + if shard_count > 0: + return True + except requests.RequestException: + pass + time.sleep(0.5) + return False + + +def prepare_and_load_model_mcp(model_id: str) -> None: + """Prepare topology and load model via MCP. + + MCP's load_model already handles topology preparation internally if needed. + """ + result = mcp_call_tool( + "load_model", {"model": model_id}, timeout=MODEL_LOAD_TIMEOUT + ) + assert result.data is not None + assert "loaded successfully" in result.data.lower() + + +def unload_model_mcp() -> None: + """Unload the current model via MCP. + + Logs a warning if unloading fails, as this is a best-effort cleanup. + """ + try: + mcp_call_tool("unload_model", {}) + except Exception as e: + logger.warning(f"Failed to unload model via MCP (best effort): {e}") + + +CI_TEST_MODELS = get_ci_test_models() + + +@pytest.mark.integration +def test_mcp_health_check(servers): + resp = requests.get(f"{MCP_URL}/mcp-health", timeout=HEALTH_CHECK_TIMEOUT) + resp.raise_for_status() + data = resp.json() + assert data["status"] == "healthy" + assert data["service"] == "dnet-mcp" + + +@pytest.mark.integration +def test_mcp_list_models(servers): + result = mcp_call_tool("list_models", {}) + assert result.data is not None + assert isinstance(result.data, str) + data = json.loads(result.data) + assert "object" in data + assert data["object"] == "list" + assert "data" in data + assert isinstance(data["data"], list) + assert len(data["data"]) > 0 + + +@pytest.mark.integration +def test_mcp_get_status_no_model(servers): + result = mcp_call_tool("get_status", {}) + assert result.data is not None + assert isinstance(result.data, str) + data = json.loads(result.data) + assert "model_loaded" in data + assert "shards_discovered" in data + + +@pytest.mark.integration +@pytest.mark.parametrize( + "model", + CI_TEST_MODELS[:1], + ids=[m["alias"] for m in CI_TEST_MODELS[:1]], +) +def test_mcp_load_and_chat(servers, model: dict[str, Any]): + model_id = model["id"] + try: + prepare_and_load_model_mcp(model_id) + + result = mcp_call_tool( + "chat_completion", + { + "messages": [ + {"role": "user", "content": "What is 2+2? 
Reply briefly."} + ], + "max_tokens": 50, + "temperature": 0.1, + }, + timeout=INFERENCE_TIMEOUT, + ) + assert result.data is not None + assert isinstance(result.data, str) + assert len(result.data) > 0 + assert result.data.strip() + + finally: + unload_model_mcp() + + +@pytest.mark.integration +def test_mcp_get_cluster_details(servers): + result = mcp_call_tool("get_cluster_details", {}) + assert result.data is not None + assert isinstance(result.data, str) + data = json.loads(result.data) + assert "devices" in data + assert "topology" in data diff --git a/tests/subsystems/test_api_http_server.py b/tests/subsystems/test_api_http_server.py index 15eaeb52..b9f9ef90 100644 --- a/tests/subsystems/test_api_http_server.py +++ b/tests/subsystems/test_api_http_server.py @@ -409,7 +409,7 @@ def test_load_model_bootstrap_profiles_empty_returns_failure(monkeypatch): srv = _create_server(cm, im, mm) monkeypatch.setattr( - "dnet.api.http_api.get_model_config_json", + "dnet.api.load_helpers.get_model_config_json", lambda m: {"hidden_size": 8, "num_hidden_layers": 4}, raising=True, ) @@ -437,7 +437,7 @@ def test_load_model_bootstrap_success_connects(monkeypatch): srv = _create_server(cm, im, mm) monkeypatch.setattr( - "dnet.api.http_api.get_model_config_json", + "dnet.api.load_helpers.get_model_config_json", lambda m: {"hidden_size": 16, "num_hidden_layers": 6}, raising=True, ) @@ -450,7 +450,7 @@ async def _prof(model_id, emb, maxb, batches): from tests.fakes import FakeModelProfile as _MP2 monkeypatch.setattr( - "dnet.api.http_api.profile_model", + "dnet.api.load_helpers.profile_model", lambda repo_id, batch_sizes, sequence_length: _MP2(), raising=True, )