17 commits
9cc99be
Added NanoChat evaluation modules under benchmarks/language_models.
Jasmine-Yuting-Zhang Oct 27, 2025
06e666a
Added evaluation script for running NanoChat benchmark on HuggingFace…
Jasmine-Yuting-Zhang Oct 27, 2025
473d350
Added missing explanation of --max_per_task default value.
Jasmine-Yuting-Zhang Oct 27, 2025
f703348
Moved NanoChat benchmark from benchmarks/language_models to plato/ben…
Jasmine-Yuting-Zhang Oct 28, 2025
08dc54e
Cleaned up unused code from nanochat.
Jasmine-Yuting-Zhang Oct 28, 2025
a978beb
Added abstract eval_model() to TestingStrategy.
Jasmine-Yuting-Zhang Oct 28, 2025
c9db3e0
Added eval_model() to DefaultTestingStrategy.
Jasmine-Yuting-Zhang Oct 28, 2025
34ec582
Added benchmark result save/load utilities.
Jasmine-Yuting-Zhang Oct 28, 2025
956922f
Implemented benchmark evaluation pipeline with multiprocessing.
Jasmine-Yuting-Zhang Oct 28, 2025
103c665
Added registry for benchmark.
Jasmine-Yuting-Zhang Oct 28, 2025
c3c5020
Added base class for evaluating trained models.
Jasmine-Yuting-Zhang Oct 28, 2025
ed4b025
Added CORE benchmark implementation for language models.
Jasmine-Yuting-Zhang Oct 28, 2025
d8cf214
Added helper functions for CORE benchmark implementation.
Jasmine-Yuting-Zhang Oct 28, 2025
aaac39b
Added benchmark evaluation support in fedavg.py.
Jasmine-Yuting-Zhang Oct 28, 2025
0a36f74
Added benchmark configuration support in config.py.
Jasmine-Yuting-Zhang Oct 28, 2025
e60d99b
Added support for split learning benchmark evaluation.
Jasmine-Yuting-Zhang Oct 28, 2025
ec1c1ba
Reformatted code using Ruff.
Jasmine-Yuting-Zhang Oct 28, 2025
@@ -162,6 +162,49 @@ def test_model(self, model, config, testset, sampler, context):
        tester.log_metrics("eval", metrics)
        return metrics["eval_accuracy"]

    def eval_model(self, model, config, benchmark, sampler, context):
        """
        Evaluate the model using the benchmark specified in the configuration.

        This is a specialized implementation for HuggingFace-based models.

        Arguments:
            model: The model to evaluate
            config: Testing configuration dictionary
            benchmark: Benchmark instance (e.g., from plato.benchmarks.registry.get())
            sampler: Optional data sampler (not used for the CORE benchmark)
            context: Training context

        Returns:
            Benchmark results dictionary containing:
            - 'results': per-task accuracies (for CORE)
            - 'centered_results': normalized scores (for CORE)
            - 'core_metric': overall benchmark score (for CORE)
        """

        if hasattr(model, "copy_weight"):
            model.copy_weight()

        # Get the base model if available
        base_model = model.base_model if hasattr(model, "base_model") else model

        # Set the model to eval mode and move it to the device
        base_model.to(context.device)
        base_model.eval()

        if hasattr(benchmark, "model"):
            benchmark.model = base_model
        if hasattr(benchmark, "device"):
            benchmark.device = context.device
        if hasattr(benchmark, "tokenizer") and self.tokenizer is not None:
            benchmark.tokenizer = self.tokenizer

        # Use the benchmark's evaluate() method, which returns a dict of metrics
        results = benchmark.evaluate()

        return results


# ============================================================================
# Custom Callbacks for LLM Split Learning
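A minimal sketch of how this new eval_model() hook could be driven from the split-learning example. The benchmark lookup follows the docstring's plato.benchmarks.registry.get() example, but the exact get() signature and the strategy, model, and context objects are assumptions for illustration, not code from this pull request.

# Hypothetical driver code: "strategy", "model", and "context" come from the trainer
# and are not defined in this diff; registry usage follows the docstring example.
from plato.benchmarks import registry

benchmark = registry.get()        # e.g., the CORE benchmark selected by [benchmark] in the config
results = strategy.eval_model(
    model=model,                  # the (possibly split) HuggingFace model being trained
    config={},                    # testing configuration dictionary
    benchmark=benchmark,
    sampler=None,                 # not used by the CORE benchmark
    context=context,              # supplies context.device
)
print(results["core_metric"])     # overall score; per-task accuracies are in results["results"]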
@@ -40,13 +40,18 @@ random_seed = 1
# IID, biased, or sharded?
sampler = "iid"

[benchmark]
type = "core" # Benchmark type (from registry)
max_per_task = 16 # Limit samples per task for faster evaluation
random_seed = 1

[trainer]

# The type of the trainer
type = "split_learning"

# The maximum number of training rounds
rounds = 100000
rounds = 10

# The machine learning model
model_type = "huggingface"
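A sketch of how the new [benchmark] section might be consumed, assuming the config.py change exposes it through Plato's usual Config() singleton with attribute access; the getattr fallbacks are illustrative defaults, not values taken from the diff.

from plato.config import Config

# Read the benchmark settings added above, if the [benchmark] section is present.
if hasattr(Config(), "benchmark"):
    benchmark_type = Config().benchmark.type                          # "core"
    max_per_task = getattr(Config().benchmark, "max_per_task", None)  # per-task sample cap, 16 above
    random_seed = getattr(Config().benchmark, "random_seed", 1)
else:
    benchmark_type = None  # no [benchmark] section: skip benchmark evaluation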
4 changes: 1 addition & 3 deletions examples/unlearning/fedunlearning/fedunlearning_server.py
@@ -43,9 +43,7 @@ async def aggregate_deltas(self, updates, deltas_received, context):

        if not filtered_pairs:
            if self._fallback_to_original:
                return await super().aggregate_deltas(
                    updates, deltas_received, context
                )
                return await super().aggregate_deltas(updates, deltas_received, context)

            zero_delta = self._zero_delta(
                context, deltas_received[0] if deltas_received else None
Empty file added plato/benchmarks/__init__.py
136 changes: 136 additions & 0 deletions plato/benchmarks/base.py
@@ -0,0 +1,136 @@
"""
Base class for benchmarks evaluating trained models.
"""

from typing import Any
from abc import ABC, abstractmethod
import gzip
import logging
import os
import sys
import tarfile
import zipfile
from pathlib import Path
from urllib.parse import urlparse
import requests
import contextlib, time


class Benchmark(ABC):
"""Base class for model benchmarks."""

def __init__(self):
"""
Initialize the benchmark.
"""
super().__init__()

@abstractmethod
def evaluate(self) -> dict[str, Any]:
"""
Evaluate the model on benchmark tasks.

evaluate() returns evaluation results.

Returns:
Dictionary of evaluation metrics

Example:
>>> results = benchmark.evaluate()
>>> print(results)
{'task1_accuracy': 0.85, 'overall': 0.875}
"""
pass

@abstractmethod
def get_formatted_result(self) -> str:
pass

# Borrowed from plato/datasources/base.py
@staticmethod
@contextlib.contextmanager
def _download_guard(data_path: str):
"""Serialise dataset downloads to avoid concurrent corruption."""
os.makedirs(data_path, exist_ok=True)
lock_file = os.path.join(data_path, ".download.lock")
lock_fd = None
waited = False

try:
while True:
try:
lock_fd = os.open(lock_file, os.O_CREAT | os.O_EXCL | os.O_RDWR)
break
except FileExistsError:
if not waited:
logging.info(
"Another process is preparing the dataset at %s. Waiting.",
data_path,
)
waited = True
time.sleep(1)
yield
finally:
if lock_fd is not None:
os.close(lock_fd)
try:
os.remove(lock_file)
except FileNotFoundError:
pass

@staticmethod
def download(url, data_path):
"""Download a dataset from a URL if it is not already available."""
url_parse = urlparse(url)
file_name = os.path.join(data_path, url_parse.path.split("/")[-1])
os.makedirs(data_path, exist_ok=True)
sentinel = Path(f"{file_name}.complete")

if sentinel.exists():
return

with Benchmark._download_guard(data_path):
if sentinel.exists():
return

logging.info("Downloading %s.", url)

res = requests.get(url, stream=True, timeout=60)
total_size = int(res.headers.get("Content-Length", 0))
downloaded_size = 0

with open(file_name, "wb+") as file:
for chunk in res.iter_content(chunk_size=1024):
if not chunk:
continue
downloaded_size += len(chunk)
file.write(chunk)
file.flush()
if total_size:
sys.stdout.write(f"\r{100 * downloaded_size / total_size:.1f}%")
sys.stdout.flush()
if total_size:
sys.stdout.write("\n")

# Unzip the compressed file just downloaded
logging.info("Decompressing the dataset downloaded.")
name, suffix = os.path.splitext(file_name)

if file_name.endswith("tar.gz"):
with tarfile.open(file_name, "r:gz") as tar:
tar.extractall(data_path)
os.remove(file_name)
elif suffix == ".zip":
logging.info("Extracting %s to %s.", file_name, data_path)
with zipfile.ZipFile(file_name, "r") as zip_ref:
zip_ref.extractall(data_path)
elif suffix == ".gz":
with gzip.open(file_name, "rb") as zipped_file:
with open(name, "wb") as unzipped_file:
unzipped_file.write(zipped_file.read())
os.remove(file_name)
else:
logging.info("Unknown compressed file type for %s.", file_name)
sys.exit()

sentinel.touch()
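A minimal sketch of a concrete Benchmark subclass against the base class above. The task name, dataset URL, and scoring logic are placeholders for illustration only, not code or data from this pull request; a real benchmark (such as CORE) would run the attached model on its tasks.

# Hypothetical subclass for illustration only; the URL, task name, and
# scoring below are placeholders rather than code from this pull request.
from typing import Any

from plato.benchmarks.base import Benchmark


class ToyBenchmark(Benchmark):
    """Scores a model on a single downloaded task file."""

    def __init__(self, model=None, device="cpu", data_path="./data/toy_benchmark"):
        super().__init__()
        self.model = model
        self.device = device
        self.data_path = data_path
        self._results: dict[str, Any] = {}

    def evaluate(self) -> dict[str, Any]:
        # Fetch and unpack the task data once; Benchmark.download() is a no-op
        # when the ".complete" sentinel file already exists.
        Benchmark.download("https://example.com/toy_task.zip", self.data_path)

        # Placeholder scoring: a real benchmark would run self.model on the task.
        self._results = {"toy_task_accuracy": 0.0, "overall": 0.0}
        return self._results

    def get_formatted_result(self) -> str:
        return ", ".join(f"{key}: {value:.3f}" for key, value in self._results.items())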