diff --git a/examples/metrics/grafana.json b/examples/metrics/grafana.json
index 72d175971..6cef8182b 100644
--- a/examples/metrics/grafana.json
+++ b/examples/metrics/grafana.json
@@ -121,7 +121,7 @@
             "uid": "${DS_PROMETHEUS}"
           },
           "editorMode": "code",
-          "expr": "rate(ucm:interval_lookup_hit_rates_sum{model_name=\"$model_name\"}[$__rate_interval])\n/\nrate(ucm:interval_lookup_hit_rates_count{model_name=\"$model_name\"}[$__rate_interval])",
+          "expr": "rate(ucm:interval_lookup_hit_rates_sum{model_name=\"$model_name\",worker_id=\"0\"}[$__rate_interval])\n/\nrate(ucm:interval_lookup_hit_rates_count{model_name=\"$model_name\",worker_id=\"0\"}[$__rate_interval])",
           "hide": false,
           "instant": false,
           "legendFormat": "Average",
diff --git a/ucm/integration/vllm/ucm_connector.py b/ucm/integration/vllm/ucm_connector.py
index 66216a255..ed4dfc5a2 100644
--- a/ucm/integration/vllm/ucm_connector.py
+++ b/ucm/integration/vllm/ucm_connector.py
@@ -1,18 +1,26 @@
+import copy
 import hashlib
 import itertools
 import os
 import pickle
 import time
 from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, Callable, List, Optional
+from typing import TYPE_CHECKING, Any, Callable, ClassVar, Dict, List, Optional
 
 import torch
+import yaml
 from vllm.config import VllmConfig
 from vllm.distributed.kv_transfer.kv_connector.v1.base import (
     KVConnectorBase_V1,
     KVConnectorMetadata,
     KVConnectorRole,
 )
+from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
+    KVConnectorPromMetrics,
+    KVConnectorStats,
+    PromMetric,
+    PromMetricT,
+)
 from vllm.distributed.parallel_state import get_tp_group, get_world_group
 from vllm.platforms import current_platform
 from vllm.v1.core.sched.output import SchedulerOutput
@@ -647,6 +655,50 @@ def get_block_ids_with_load_errors(self) -> set[int]:
         self._invalid_block_ids = set()
         return res
 
+    def get_kv_connector_stats(self) -> Optional[KVConnectorStats]:
+        if not hasattr(self, "monitor") or self.monitor is None:
+            return None
+
+        # Get stats from monitor using get_stats_and_clear
+        try:
+            stats_dict = self.monitor.get_stats_and_clear("ConnStats")
+            if not stats_dict:
+                return None
+
+            # Convert monitor stats to UCMKVConnectorStats format
+            ucm_stats = UCMKVConnectorStats()
+            # Record stats for this worker
+            ucm_stats.record_stats(stats_dict, worker_rank=self.global_rank)
+
+            if not ucm_stats.is_empty():
+                return ucm_stats
+            return None
+        except Exception as e:
+            logger.warning(f"Failed to get stats from monitor: {e}")
+            return None
+
+    @classmethod
+    def build_kv_connector_stats(
+        cls, data: dict[str, Any] | None = None
+    ) -> Optional[KVConnectorStats]:
+        return (
+            UCMKVConnectorStats(data=data)
+            if data is not None
+            else UCMKVConnectorStats()
+        )
+
+    @classmethod
+    def build_prom_metrics(
+        cls,
+        vllm_config: "VllmConfig",
+        metric_types: dict[type[PromMetric], type[PromMetricT]],
+        labelnames: list[str],
+        per_engine_labelvalues: dict[int, list[object]],
+    ) -> Optional[KVConnectorPromMetrics]:
+        return UCMPromMetrics(
+            vllm_config, metric_types, labelnames, per_engine_labelvalues
+        )
+
 
 class UCMLayerWiseConnector(UCMDirectConnector):
     """
@@ -897,3 +949,286 @@ def get_block_ids_with_load_errors(self) -> set[int]:
             Empty set if no load errors occurred.
""" return self.connector.get_block_ids_with_load_errors() + + def get_kv_connector_stats(self) -> Optional[KVConnectorStats]: + return self.connector.get_kv_connector_stats() + + @classmethod + def build_kv_connector_stats( + cls, data: dict[str, Any] | None = None + ) -> Optional[KVConnectorStats]: + return UCMDirectConnector.build_kv_connector_stats(data) + + @classmethod + def build_prom_metrics( + cls, + vllm_config: VllmConfig, + metric_types: dict[type[PromMetric], type[PromMetricT]], + labelnames: list[str], + per_engine_labelvalues: dict[int, list[object]], + ) -> Optional[KVConnectorPromMetrics]: + return UCMDirectConnector.build_prom_metrics( + vllm_config, metric_types, labelnames, per_engine_labelvalues + ) + + +@dataclass +class UCMKVConnectorStats(KVConnectorStats): + """Data structure: {worker_rank: {metric_name: [values]}}""" + + def reset(self): + """Reset stats for all workers""" + for worker_data in self.data.values(): + worker_data.clear() + self.data.clear() + + def record_stats(self, stats_dict: dict[str, Any], worker_rank: int | None = None): + """Record stats from monitor data for a specific worker""" + worker_key = str(worker_rank) if worker_rank is not None else "0" + + if worker_key not in self.data: + self.data[worker_key] = {} + + worker_data = self.data[worker_key] + + for key, value in stats_dict.items(): + if key not in worker_data: + worker_data[key] = [] + + if isinstance(value, list): + worker_data[key].extend([float(v) for v in value]) + else: + worker_data[key].append(float(value)) + + def is_empty(self) -> bool: + return all( + all(len(v) == 0 for v in worker_data.values()) + for worker_data in self.data.values() + ) + + def aggregate(self, other: KVConnectorStats) -> KVConnectorStats: + """Aggregate stats from another worker, preserving per-worker separation""" + if other.is_empty(): + return self + + assert isinstance(other, UCMKVConnectorStats), "Expected UCMKVConnectorStats" + + for worker_rank, worker_data in other.data.items(): + if worker_rank not in self.data: + self.data[worker_rank] = copy.deepcopy(worker_data) + continue + + for metric_name, values in worker_data.items(): + self.data[worker_rank].setdefault(metric_name, []).extend(values) + + return self + + def reduce(self) -> dict[str, int | float]: + """Reduce the observations to representative values for CLI logging.""" + result = {} + is_count_metric = lambda k: any(x in k for x in ["num", "requests", "blocks"]) + is_hit_rate_metric = lambda k: "hit_rate" in k or "hit_rates" in k + + for worker_rank, worker_data in sorted(self.data.items()): + for metric_name, values in worker_data.items(): + if is_hit_rate_metric(metric_name) and worker_rank != "0": + continue + + suffix = "(total)" if is_count_metric(metric_name) else "(avg)" + if is_hit_rate_metric(metric_name) and worker_rank == "0": + worker_key = f"{metric_name} {suffix}" + else: + worker_key = f"worker_{worker_rank}_{metric_name} {suffix}" + + if not values: + result[worker_key] = 0 if is_count_metric(metric_name) else 0.0 + elif is_count_metric(metric_name): + result[worker_key] = int(sum(values)) + else: + result[worker_key] = round(sum(values) / len(values), 3) + + return result + + +class UCMPromMetrics(KVConnectorPromMetrics): + """ + Prometheus metrics for UCM connector. + Records metrics from self.monitor data based on metrics_configs.yaml configuration. 
+ """ + + _config_cache: Dict[str, Dict[str, Any]] = {} + + def __init__( + self, + vllm_config: VllmConfig, + metric_types: dict[type[PromMetric], type[PromMetricT]], + labelnames: list[str], + per_engine_labelvalues: dict[int, list[object]], + ): + super().__init__(vllm_config, metric_types, labelnames, per_engine_labelvalues) + + config = self._load_config(vllm_config) + prometheus_config = config.get("prometheus", {}) + + self.metric_mappings: Dict[str, Dict[str, Any]] = {} + self._init_metrics_from_config( + labelnames, prometheus_config, per_engine_labelvalues + ) + + @staticmethod + def _get_metrics_config_path(vllm_config: VllmConfig) -> str: + """Get metrics_config_path from vllm_config""" + if vllm_config.kv_transfer_config is None: + return "" + try: + ucm_config = Config(vllm_config.kv_transfer_config) + return ucm_config.get_config().get("metrics_config_path", "") + except Exception: + return "" + + def _load_config(self, vllm_config: VllmConfig) -> Dict[str, Any]: + """Load configuration from YAML file with caching""" + metrics_config_path = self._get_metrics_config_path(vllm_config) + if not metrics_config_path: + return {} + + if metrics_config_path in self._config_cache: + return self._config_cache[metrics_config_path] + + try: + with open(metrics_config_path, "r") as f: + config = yaml.safe_load(f) + if config: + self._config_cache[metrics_config_path] = config + return config or {} + except FileNotFoundError: + return {} + except yaml.YAMLError as e: + logger.error(f"Error parsing YAML config file {metrics_config_path}: {e}") + return {} + + def _init_metrics_from_config( + self, + labelnames: list[str], + prometheus_config: Dict[str, Any], + per_engine_labelvalues: dict[int, list[object]], + ): + enabled = prometheus_config.get("enabled_metrics", {}) + metric_prefix = prometheus_config.get("metric_prefix", "ucm:") + + extended_labelnames = ( + labelnames + ["worker_id"] if "worker_id" not in labelnames else labelnames + ) + + metric_types_config = { + "counter": ( + self._counter_cls, + enabled.get("counters", True), + prometheus_config.get("counters", []), + ), + "gauge": ( + self._gauge_cls, + enabled.get("gauges", True), + prometheus_config.get("gauges", []), + ), + "histogram": ( + self._histogram_cls, + enabled.get("histograms", True), + prometheus_config.get("histograms", []), + ), + } + + for metric_type, ( + metric_cls, + is_enabled, + metrics_list, + ) in metric_types_config.items(): + if not is_enabled: + continue + + for metric_cfg in metrics_list: + if not (name := metric_cfg.get("name")): + continue + + doc = metric_cfg.get("documentation", "") + prometheus_name = f"{metric_prefix}{name}" + attr_name = f"{metric_type}_{name}" + + metric_kwargs = { + "name": prometheus_name, + "documentation": doc, + "labelnames": extended_labelnames, + } + if metric_type == "gauge": + metric_kwargs["multiprocess_mode"] = metric_cfg.get( + "multiprocess_mode", "live" + ) + elif metric_type == "histogram": + metric_kwargs["buckets"] = metric_cfg.get("buckets", []) + + metric = metric_cls(**metric_kwargs) + setattr(self, attr_name, metric) + self.metric_mappings[name] = { + "type": metric_type, + "attr": attr_name, + "extended_labelnames": extended_labelnames, + } + + def observe(self, transfer_stats_data: dict[str, Any], engine_idx: int = 0): + """Record transfer statistics to Prometheus metrics based on configuration.""" + if transfer_stats_data and isinstance(transfer_stats_data, dict): + first_key = next(iter(transfer_stats_data.keys()), None) + if first_key and 
+        if transfer_stats_data and isinstance(transfer_stats_data, dict):
+            first_key = next(iter(transfer_stats_data.keys()), None)
+            if first_key and isinstance(transfer_stats_data[first_key], dict):
+                for worker_rank, worker_stats in sorted(transfer_stats_data.items()):
+                    self._observe_worker_stats(worker_stats, engine_idx, worker_rank)
+                return
+
+        # Fallback: treat as single worker (backward compatibility)
+        self._observe_worker_stats(transfer_stats_data, engine_idx, "aggregated")
+
+    def _observe_worker_stats(
+        self, worker_stats: dict[str, Any], engine_idx: int, worker_rank: str
+    ):
+        """Helper method to observe stats for a specific worker"""
+        if not worker_stats:
+            return
+
+        base_labelvalues = self._per_engine_labelvalues.get(engine_idx, [])
+        extended_labelvalues = list(base_labelvalues) + [str(worker_rank)]
+
+        for stat_name, value in worker_stats.items():
+            try:
+                if stat_name not in self.metric_mappings:
+                    continue
+
+                metric_mapped = self.metric_mappings[stat_name]
+                base_metric = getattr(self, metric_mapped["attr"], None)
+                if base_metric is None:
+                    logger.warning(f"Metric {stat_name} not initialized.")
+                    continue
+
+                metric_type = metric_mapped["type"]
+                extended_labelnames = metric_mapped.get("extended_labelnames", [])
+
+                if "worker_id" in extended_labelnames:
+                    per_engine_metric = base_metric.labels(*extended_labelvalues)
+                else:
+                    per_engine_metric = base_metric.labels(*base_labelvalues)
+
+                values = value if isinstance(value, list) else [value]
+
+                if metric_type == "counter":
+                    for val in values:
+                        if val >= 0:
+                            per_engine_metric.inc(int(val))
+                elif metric_type == "gauge":
+                    if values:
+                        per_engine_metric.set(float(values[-1]))
+                elif metric_type == "histogram":
+                    for val in values:
+                        per_engine_metric.observe(float(val))
+
+            except Exception as e:
+                logger.warning(
+                    f"Failed to log metric {stat_name} for worker {worker_rank}: {e}"
+                )