2 changes: 1 addition & 1 deletion examples/metrics/grafana.json
@@ -121,7 +121,7 @@
"uid": "${DS_PROMETHEUS}"
},
"editorMode": "code",
"expr": "rate(ucm:interval_lookup_hit_rates_sum{model_name=\"$model_name\"}[$__rate_interval])\n/\nrate(ucm:interval_lookup_hit_rates_count{model_name=\"$model_name\"}[$__rate_interval])",
"expr": "rate(ucm:interval_lookup_hit_rates_sum{model_name=\"$model_name\",worker_id=\"0\"}[$__rate_interval])\n/\nrate(ucm:interval_lookup_hit_rates_count{model_name=\"$model_name\",worker_id=\"0\"}[$__rate_interval])",
"hide": false,
"instant": false,
"legendFormat": "Average",
337 changes: 336 additions & 1 deletion ucm/integration/vllm/ucm_connector.py
@@ -1,18 +1,26 @@
import copy
import hashlib
import itertools
import os
import pickle
import time
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Callable, List, Optional
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Dict, List, Optional

import torch
import yaml
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1,
KVConnectorMetadata,
KVConnectorRole,
)
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
KVConnectorPromMetrics,
KVConnectorStats,
PromMetric,
PromMetricT,
)
from vllm.distributed.parallel_state import get_tp_group, get_world_group
from vllm.platforms import current_platform
from vllm.v1.core.sched.output import SchedulerOutput
@@ -647,6 +655,50 @@ def get_block_ids_with_load_errors(self) -> set[int]:
self._invalid_block_ids = set()
return res

def get_kv_connector_stats(self) -> Optional[KVConnectorStats]:
if not hasattr(self, "monitor") or self.monitor is None:
return None

# Get stats from monitor using get_stats_and_clear
try:
stats_dict = self.monitor.get_stats_and_clear("ConnStats")
if not stats_dict:
return None

# Convert monitor stats to UCMKVConnectorStats format
ucm_stats = UCMKVConnectorStats()
# Record stats for this worker
ucm_stats.record_stats(stats_dict, worker_rank=self.global_rank)

if not ucm_stats.is_empty():
return ucm_stats
return None
except Exception as e:
logger.warning(f"Failed to get stats from monitor: {e}")
return None

@classmethod
def build_kv_connector_stats(
cls, data: dict[str, Any] | None = None
) -> Optional[KVConnectorStats]:
return (
UCMKVConnectorStats(data=data)
if data is not None
else UCMKVConnectorStats()
)

@classmethod
def build_prom_metrics(
cls,
vllm_config: "VllmConfig",
metric_types: dict[type[PromMetric], type[PromMetricT]],
labelnames: list[str],
per_engine_labelvalues: dict[int, list[object]],
) -> Optional[KVConnectorPromMetrics]:
return UCMPromMetrics(
vllm_config, metric_types, labelnames, per_engine_labelvalues
)


class UCMLayerWiseConnector(UCMDirectConnector):
"""
@@ -897,3 +949,286 @@ def get_block_ids_with_load_errors(self) -> set[int]:
Empty set if no load errors occurred.
"""
return self.connector.get_block_ids_with_load_errors()

def get_kv_connector_stats(self) -> Optional[KVConnectorStats]:
return self.connector.get_kv_connector_stats()

@classmethod
def build_kv_connector_stats(
cls, data: dict[str, Any] | None = None
) -> Optional[KVConnectorStats]:
return UCMDirectConnector.build_kv_connector_stats(data)
Review comment (Contributor): Must we use specific class here?
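
One possible direction, sketched here as an untested reply (not part of the PR; `_stats_connector_cls` is a hypothetical attribute): resolve the delegate from a class attribute so a subclass can swap the stats implementation without naming UCMDirectConnector directly.

_stats_connector_cls: ClassVar[type] = UCMDirectConnector  # hypothetical, overridable in subclasses

@classmethod
def build_kv_connector_stats(
    cls, data: dict[str, Any] | None = None
) -> Optional[KVConnectorStats]:
    # Delegate to whichever connector class the wrapper is configured with.
    return cls._stats_connector_cls.build_kv_connector_stats(data)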


@classmethod
def build_prom_metrics(
cls,
vllm_config: VllmConfig,
metric_types: dict[type[PromMetric], type[PromMetricT]],
labelnames: list[str],
per_engine_labelvalues: dict[int, list[object]],
) -> Optional[KVConnectorPromMetrics]:
return UCMDirectConnector.build_prom_metrics(
vllm_config, metric_types, labelnames, per_engine_labelvalues
)


@dataclass
class UCMKVConnectorStats(KVConnectorStats):
"""Data structure: {worker_rank: {metric_name: [values]}}"""

def reset(self):
"""Reset stats for all workers"""
for worker_data in self.data.values():
worker_data.clear()
self.data.clear()

def record_stats(self, stats_dict: dict[str, Any], worker_rank: int | None = None):
"""Record stats from monitor data for a specific worker"""
worker_key = str(worker_rank) if worker_rank is not None else "0"

if worker_key not in self.data:
self.data[worker_key] = {}

worker_data = self.data[worker_key]

for key, value in stats_dict.items():
if key not in worker_data:
worker_data[key] = []

if isinstance(value, list):
worker_data[key].extend([float(v) for v in value])
else:
worker_data[key].append(float(value))

def is_empty(self) -> bool:
return all(
all(len(v) == 0 for v in worker_data.values())
for worker_data in self.data.values()
)

def aggregate(self, other: KVConnectorStats) -> KVConnectorStats:
Review comment (Contributor): The excessive nesting depth (five levels) can be simplified by extracting helper functions and streamlining conditional checks. The same applies to the reduce function.

"""Aggregate stats from another worker, preserving per-worker separation"""
if other.is_empty():
return self

assert isinstance(other, UCMKVConnectorStats), "Expected UCMKVConnectorStats"

for worker_rank, worker_data in other.data.items():
if worker_rank not in self.data:
self.data[worker_rank] = copy.deepcopy(worker_data)
continue

for metric_name, values in worker_data.items():
self.data[worker_rank].setdefault(metric_name, []).extend(values)

return self

def reduce(self) -> dict[str, int | float]:
"""Reduce the observations to representative values for CLI logging."""
result = {}
def is_count_metric(k: str) -> bool:
    return any(x in k for x in ["num", "requests", "blocks"])

def is_hit_rate_metric(k: str) -> bool:
    return "hit_rate" in k  # also matches "hit_rates"

for worker_rank, worker_data in sorted(self.data.items()):
for metric_name, values in worker_data.items():
if is_hit_rate_metric(metric_name) and worker_rank != "0":
continue

suffix = "(total)" if is_count_metric(metric_name) else "(avg)"
if is_hit_rate_metric(metric_name) and worker_rank == "0":
worker_key = f"{metric_name} {suffix}"
else:
worker_key = f"worker_{worker_rank}_{metric_name} {suffix}"

if not values:
result[worker_key] = 0 if is_count_metric(metric_name) else 0.0
elif is_count_metric(metric_name):
result[worker_key] = int(sum(values))
else:
result[worker_key] = round(sum(values) / len(values), 3)

return result
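
Along the lines of the review comment above, one possible flattening of reduce (an untested sketch, not part of the PR): hoist the per-metric reduction into a helper so the loop body stays shallow.

@staticmethod
def _reduce_values(is_count: bool, values: list[float]) -> int | float:
    # Empty series reduce to a typed zero; count metrics sum, others average.
    if not values:
        return 0 if is_count else 0.0
    if is_count:
        return int(sum(values))
    return round(sum(values) / len(values), 3)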


class UCMPromMetrics(KVConnectorPromMetrics):
"""
Prometheus metrics for UCM connector.
Records metrics from the connector monitor's data, as configured in metrics_configs.yaml.
"""

_config_cache: Dict[str, Dict[str, Any]] = {}

def __init__(
self,
vllm_config: VllmConfig,
metric_types: dict[type[PromMetric], type[PromMetricT]],
labelnames: list[str],
per_engine_labelvalues: dict[int, list[object]],
):
super().__init__(vllm_config, metric_types, labelnames, per_engine_labelvalues)

config = self._load_config(vllm_config)
prometheus_config = config.get("prometheus", {})

self.metric_mappings: Dict[str, Dict[str, Any]] = {}
self._init_metrics_from_config(
labelnames, prometheus_config, per_engine_labelvalues
)

@staticmethod
def _get_metrics_config_path(vllm_config: VllmConfig) -> str:
"""Get metrics_config_path from vllm_config"""
if vllm_config.kv_transfer_config is None:
return ""
try:
ucm_config = Config(vllm_config.kv_transfer_config)
return ucm_config.get_config().get("metrics_config_path", "")
except Exception:
return ""

def _load_config(self, vllm_config: VllmConfig) -> Dict[str, Any]:
"""Load configuration from YAML file with caching"""
metrics_config_path = self._get_metrics_config_path(vllm_config)
if not metrics_config_path:
return {}

if metrics_config_path in self._config_cache:
return self._config_cache[metrics_config_path]

try:
with open(metrics_config_path, "r") as f:
config = yaml.safe_load(f)
if config:
self._config_cache[metrics_config_path] = config
return config or {}
except FileNotFoundError:
return {}
except yaml.YAMLError as e:
logger.error(f"Error parsing YAML config file {metrics_config_path}: {e}")
return {}

def _init_metrics_from_config(
self,
labelnames: list[str],
prometheus_config: Dict[str, Any],
per_engine_labelvalues: dict[int, list[object]],
):
enabled = prometheus_config.get("enabled_metrics", {})
metric_prefix = prometheus_config.get("metric_prefix", "ucm:")

extended_labelnames = (
labelnames + ["worker_id"] if "worker_id" not in labelnames else labelnames
)
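
# Every metric gains a worker_id label; dashboards can then select a single
# worker, cf. the grafana.json change above that filters on worker_id="0".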

metric_types_config = {
"counter": (
self._counter_cls,
enabled.get("counters", True),
prometheus_config.get("counters", []),
),
"gauge": (
self._gauge_cls,
enabled.get("gauges", True),
prometheus_config.get("gauges", []),
),
"histogram": (
self._histogram_cls,
enabled.get("histograms", True),
prometheus_config.get("histograms", []),
),
}

for metric_type, (
metric_cls,
is_enabled,
metrics_list,
) in metric_types_config.items():
if not is_enabled:
continue

for metric_cfg in metrics_list:
if not (name := metric_cfg.get("name")):
continue

doc = metric_cfg.get("documentation", "")
prometheus_name = f"{metric_prefix}{name}"
attr_name = f"{metric_type}_{name}"

metric_kwargs = {
"name": prometheus_name,
"documentation": doc,
"labelnames": extended_labelnames,
}
if metric_type == "gauge":
metric_kwargs["multiprocess_mode"] = metric_cfg.get(
"multiprocess_mode", "live"
)
elif metric_type == "histogram":
metric_kwargs["buckets"] = metric_cfg.get("buckets", [])

metric = metric_cls(**metric_kwargs)
setattr(self, attr_name, metric)
self.metric_mappings[name] = {
"type": metric_type,
"attr": attr_name,
"extended_labelnames": extended_labelnames,
}

def observe(self, transfer_stats_data: dict[str, Any], engine_idx: int = 0):
"""Record transfer statistics to Prometheus metrics based on configuration."""
if transfer_stats_data and isinstance(transfer_stats_data, dict):
first_key = next(iter(transfer_stats_data.keys()), None)
if first_key and isinstance(transfer_stats_data[first_key], dict):
for worker_rank, worker_stats in sorted(transfer_stats_data.items()):
self._observe_worker_stats(worker_stats, engine_idx, worker_rank)
return

# Fallback: treat as single worker (backward compatibility)
self._observe_worker_stats(transfer_stats_data, engine_idx, "aggregated")

def _observe_worker_stats(
self, worker_stats: dict[str, Any], engine_idx: int, worker_rank: str
):
"""Helper method to observe stats for a specific worker"""
if not worker_stats:
return

base_labelvalues = self._per_engine_labelvalues.get(engine_idx, [])
extended_labelvalues = list(base_labelvalues) + [str(worker_rank)]

for stat_name, value in worker_stats.items():
try:
if stat_name not in self.metric_mappings:
continue

metric_mapped = self.metric_mappings[stat_name]
base_metric = getattr(self, metric_mapped["attr"], None)
if base_metric is None:
logger.warning(f"Metric {stat_name} not initialized.")
continue

metric_type = metric_mapped["type"]
extended_labelnames = metric_mapped.get("extended_labelnames", [])

if "worker_id" in extended_labelnames:
per_engine_metric = base_metric.labels(*extended_labelvalues)
else:
per_engine_metric = base_metric.labels(*base_labelvalues)

values = value if isinstance(value, list) else [value]

if metric_type == "counter":
for val in values:
if val >= 0:
per_engine_metric.inc(int(val))
elif metric_type == "gauge":
if values:
per_engine_metric.set(float(values[-1]))
elif metric_type == "histogram":
for val in values:
per_engine_metric.observe(float(val))

except Exception as e:
logger.warning(
f"Failed to log metric {stat_name} for worker {worker_rank}: {e}"
)