diff --git a/examples/split_learning/llm_split_learning/split_learning_trainer.py b/examples/split_learning/llm_split_learning/split_learning_trainer.py index b921eba66..348d6a34b 100644 --- a/examples/split_learning/llm_split_learning/split_learning_trainer.py +++ b/examples/split_learning/llm_split_learning/split_learning_trainer.py @@ -162,6 +162,49 @@ def test_model(self, model, config, testset, sampler, context): tester.log_metrics("eval", metrics) return metrics["eval_accuracy"] + def eval_model(self, model, config, benchmark, sampler, context): + """ + Evaluate the model using the benchmark specified in the configuration. + + This is a specialized implementation for HuggingFace-based models. + + Arguments: + model: The model to evaluate + config: Testing configuration dictionary + benchmark: Benchmark instance (e.g., from plato.benchmarks.registry.get()) + sampler: Optional data sampler (not used for CORE benchmark) + context: Training context + + Returns: + Benchmark results dictionary containing: + - 'results': per-task accuracies (for CORE) + - 'centered_results': normalized scores (for CORE) + - 'core_metric': overall benchmark score (for CORE) + """ + + if hasattr(model, "copy_weight"): + model.copy_weight() + + # Get base model if available + base_model = model.base_model if hasattr(model, "base_model") else model + + # Set model to eval mode and move to device + base_model.to(context.device) + base_model.eval() + + if hasattr(benchmark, "model"): + benchmark.model = base_model + if hasattr(benchmark, "device"): + benchmark.device = context.device + if hasattr(benchmark, "tokenizer") and self.tokenizer is not None: + benchmark.tokenizer = self.tokenizer + + # Use benchmark's evaluate method to get results + # benchmark.evaluate() returns dict with metrics + results = benchmark.evaluate() + + return results + # ============================================================================ # Custom Callbacks for LLM Split Learning diff --git a/examples/split_learning/llm_split_learning/split_learning_wikitext2_gpt2.toml b/examples/split_learning/llm_split_learning/split_learning_wikitext2_gpt2.toml index cd068c09c..88c67e48b 100644 --- a/examples/split_learning/llm_split_learning/split_learning_wikitext2_gpt2.toml +++ b/examples/split_learning/llm_split_learning/split_learning_wikitext2_gpt2.toml @@ -40,13 +40,18 @@ random_seed = 1 # IID, biased, or sharded? 
sampler = "iid" +[benchmark] +type = "core" # Benchmark type (from registry) +max_per_task = 16 # Limit samples per task for faster evaluation +random_seed = 1 + [trainer] # The type of the trainer type = "split_learning" # The maximum number of training rounds -rounds = 100000 +rounds = 10 # The machine learning model model_type = "huggingface" diff --git a/examples/unlearning/fedunlearning/fedunlearning_server.py b/examples/unlearning/fedunlearning/fedunlearning_server.py index 8938b6549..6d6ade576 100644 --- a/examples/unlearning/fedunlearning/fedunlearning_server.py +++ b/examples/unlearning/fedunlearning/fedunlearning_server.py @@ -43,9 +43,7 @@ async def aggregate_deltas(self, updates, deltas_received, context): if not filtered_pairs: if self._fallback_to_original: - return await super().aggregate_deltas( - updates, deltas_received, context - ) + return await super().aggregate_deltas(updates, deltas_received, context) zero_delta = self._zero_delta( context, deltas_received[0] if deltas_received else None diff --git a/plato/benchmarks/__init__.py b/plato/benchmarks/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/plato/benchmarks/base.py b/plato/benchmarks/base.py new file mode 100644 index 000000000..3978723ce --- /dev/null +++ b/plato/benchmarks/base.py @@ -0,0 +1,136 @@ +""" +Base class for benchmarks evaluating trained models. +""" + +from typing import Any +from abc import ABC, abstractmethod +import gzip +import logging +import os +import sys +import tarfile +import zipfile +from pathlib import Path +from urllib.parse import urlparse +import requests +import contextlib, time + + +class Benchmark(ABC): + """Base class for model benchmarks.""" + + def __init__(self): + """ + Initialize the benchmark. + """ + super().__init__() + + @abstractmethod + def evaluate(self) -> dict[str, Any]: + """ + Evaluate the model on benchmark tasks. + + evaluate() returns evaluation results. + + Returns: + Dictionary of evaluation metrics + + Example: + >>> results = benchmark.evaluate() + >>> print(results) + {'task1_accuracy': 0.85, 'overall': 0.875} + """ + pass + + @abstractmethod + def get_formatted_result(self) -> str: + pass + + # Borrowed from plato/datasources/base.py + @staticmethod + @contextlib.contextmanager + def _download_guard(data_path: str): + """Serialise dataset downloads to avoid concurrent corruption.""" + os.makedirs(data_path, exist_ok=True) + lock_file = os.path.join(data_path, ".download.lock") + lock_fd = None + waited = False + + try: + while True: + try: + lock_fd = os.open(lock_file, os.O_CREAT | os.O_EXCL | os.O_RDWR) + break + except FileExistsError: + if not waited: + logging.info( + "Another process is preparing the dataset at %s. 
Waiting.", + data_path, + ) + waited = True + time.sleep(1) + yield + finally: + if lock_fd is not None: + os.close(lock_fd) + try: + os.remove(lock_file) + except FileNotFoundError: + pass + + @staticmethod + def download(url, data_path): + """Download a dataset from a URL if it is not already available.""" + url_parse = urlparse(url) + file_name = os.path.join(data_path, url_parse.path.split("/")[-1]) + os.makedirs(data_path, exist_ok=True) + sentinel = Path(f"{file_name}.complete") + + if sentinel.exists(): + return + + with Benchmark._download_guard(data_path): + if sentinel.exists(): + return + + logging.info("Downloading %s.", url) + + res = requests.get(url, stream=True, timeout=60) + total_size = int(res.headers.get("Content-Length", 0)) + downloaded_size = 0 + + with open(file_name, "wb+") as file: + for chunk in res.iter_content(chunk_size=1024): + if not chunk: + continue + downloaded_size += len(chunk) + file.write(chunk) + file.flush() + if total_size: + sys.stdout.write(f"\r{100 * downloaded_size / total_size:.1f}%") + sys.stdout.flush() + if total_size: + sys.stdout.write("\n") + + # Unzip the compressed file just downloaded + logging.info("Decompressing the dataset downloaded.") + name, suffix = os.path.splitext(file_name) + + if file_name.endswith("tar.gz"): + with tarfile.open(file_name, "r:gz") as tar: + tar.extractall(data_path) + os.remove(file_name) + elif suffix == ".zip": + logging.info("Extracting %s to %s.", file_name, data_path) + with zipfile.ZipFile(file_name, "r") as zip_ref: + zip_ref.extractall(data_path) + elif suffix == ".gz": + with gzip.open(file_name, "rb") as zipped_file: + with open(name, "wb") as unzipped_file: + unzipped_file.write(zipped_file.read()) + os.remove(file_name) + else: + logging.info("Unknown compressed file type for %s.", file_name) + sys.exit() + + sentinel.touch() diff --git a/plato/benchmarks/core.py b/plato/benchmarks/core.py new file mode 100644 index 000000000..b2af7ef53 --- /dev/null +++ b/plato/benchmarks/core.py @@ -0,0 +1,187 @@ +""" +CORE benchmark implementation for evaluating language models. +Borrowed and adapted from: https://github.com/karpathy/nanochat +""" + +import json +import logging +import os +import random +import time +from typing import Any + +import pandas as pd +import torch +import yaml + +from plato.benchmarks import base +from plato.benchmarks.core_helpers import core +from plato.config import Config + + +class Benchmark(base.Benchmark): + """ + CORE benchmark - evaluates language models on the CORE suite. + """ + + def __init__(self): + """ + Initialize CORE benchmark -- load benchmark tasks and data. + """ + super().__init__() + + # These will be set externally before evaluate() is called + self.model = None + self.device = None + self.tokenizer = None + + # Get configuration specific to CORE benchmark + self.random_seed = getattr(Config().benchmark, "random_seed", 24) + self.max_per_task = getattr(Config().benchmark, "max_per_task", -1) + + # Load benchmark tasks and datasets + self._load_benchmark_data() + + def _load_benchmark_data(self): + """ + Load CORE benchmark tasks and evaluation data. + + Downloads the evaluation bundle if not already present, then loads + task configurations and data files. + """ + # Get base directory and ensure eval_bundle is downloaded + benchmark_base_dir = Config.params["benchmark_path"] + + # Download eval_bundle if not present + if not os.path.exists(benchmark_base_dir): + logging.info("CORE evaluation bundle not found. 
Downloading...") + eval_bundle_url = ( + "https://karpathy-public.s3.us-west-2.amazonaws.com/eval_bundle.zip" + ) + Benchmark.download(eval_bundle_url, benchmark_base_dir) + + # Load benchmark configuration + eval_bundle_dir = os.path.join(benchmark_base_dir, "eval_bundle") + config_path = os.path.join(eval_bundle_dir, "core.yaml") + self.eval_meta_data_path = os.path.join(eval_bundle_dir, "eval_meta_data.csv") + self.data_base_path = os.path.join(eval_bundle_dir, "eval_data") + + with open(config_path, "r") as f: + config = yaml.safe_load(f) + + self.tasks = config["icl_tasks"] + self.eval_metadata = pd.read_csv(self.eval_meta_data_path) + + def evaluate(self) -> dict[str, Any]: + """ + Evaluate the model on all CORE tasks. + + Returns: + Dictionary containing: + - 'results': per-task accuracies + - 'centered_results': normalized scores + - 'core_metric': overall CORE score + """ + + if self.model is None: + raise RuntimeError("Trainer has no model - cannot run benchmark") + + if self.tokenizer is None: + raise RuntimeError("Trainer has no tokenizer - cannot run benchmark") + + results = {} + centered_results = {} + + # Set model to eval mode + self.model.eval() + + with torch.no_grad(): + for task in self.tasks: + start_time = time.time() + label = task["label"] + + task_meta = { + "task_type": task["icl_task_type"], + "dataset_uri": task["dataset_uri"], + "num_fewshot": task["num_fewshot"][0], + "continuation_delimiter": task.get("continuation_delimiter", " "), + } + + logging.info( + "Evaluating task: %s (%d-shot, type: %s)", + label, + task_meta["num_fewshot"], + task_meta["task_type"], + ) + + # Load data for this task (matching evaluate_model.py pattern) + data_path = os.path.join(self.data_base_path, task_meta["dataset_uri"]) + with open(data_path, "r") as f: + data = [json.loads(line.strip()) for line in f] + + # Shuffle the data for reproducibility (matching evaluate_model.py) + shuffle_rng = random.Random(self.random_seed) + shuffle_rng.shuffle(data) + + # Crop data if max_per_task is specified + if self.max_per_task > 0: + data = data[: self.max_per_task] + + # Run evaluation using existing core_eval logic + accuracy = core.evaluate_task( + self.model, # Model in CUDA memory from trainer + self.tokenizer, # Tokenizer from trainer + data, + self.device, + task_meta, + ) + + results[label] = accuracy + + # Compute centered result (normalized by random baseline) + row = self.eval_metadata[self.eval_metadata["Eval Task"] == label] + random_baseline = row["Random baseline"].values[0] + centered = (accuracy - 0.01 * random_baseline) / ( + 1.0 - 0.01 * random_baseline + ) + centered_results[label] = centered + + elapsed = time.time() - start_time + logging.info( + "accuracy: %.4f | centered: %.4f | time: %.2fs", + accuracy, + centered, + elapsed, + ) + + # Compute overall CORE metric + core_metric = sum(centered_results.values()) / len(centered_results) + + return { + "results": results, + "centered_results": centered_results, + "core_metric": core_metric, + } + + def get_formatted_result(self, evaluation_result: dict[str, Any]) -> str: + """ + Format the evaluation results for display. + + Args: + evaluation_result: The dictionary returned by the evaluate() method. + Returns: + A formatted string summarizing the results. 
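+
+            The layout follows the f-strings below: a header row
+            ("Task, Accuracy, Centered"), one comma-separated row per task,
+            and a final "Overall CORE Metric" row.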
+ """ + results = evaluation_result["results"] + centered_results = evaluation_result["centered_results"] + core_metric = evaluation_result["core_metric"] + + result_lines = [f"{'Task':<35}, {'Accuracy':<10}, {'Centered':<10}"] + for task, acc in results.items(): + centered = centered_results[task] + result_lines.append(f"{task:<35}, {acc:<10.6f}, {centered:<10.6f}") + result_lines.append( + f"{'Overall CORE Metric':<35}, {'':<10}, {core_metric:<10.6f}\n" + ) + + return "\n".join(result_lines) diff --git a/plato/benchmarks/core_helpers/core.py b/plato/benchmarks/core_helpers/core.py new file mode 100644 index 000000000..d767dd49d --- /dev/null +++ b/plato/benchmarks/core_helpers/core.py @@ -0,0 +1,281 @@ +""" +Borrowed and adapted from: https://github.com/karpathy/nanochat + +Functions for evaluating the CORE metric, as described in the DCLM paper. +https://arxiv.org/abs/2406.11794 + +TODOs: +- All tasks ~match except for squad. We get 31% reference is 37%. Figure out why. +""" + +import random + +from jinja2 import Template +import torch + +from plato.benchmarks.core_helpers.tokenizer import UniversalHuggingFaceTokenizer + + +def render_prompts_mc(item, continuation_delimiter, fewshot_examples=None): + """Render complete prompts for a multiple choice question""" + template_str = """ +{%- for example in fewshot_examples -%} +{{ example.query }}{{ continuation_delimiter }}{{ example.choices[example.gold] }} + +{% endfor -%} +{{ item.query }}{{ continuation_delimiter }}{{ choice }}""".strip() + template = Template(template_str) + fewshot_examples = fewshot_examples or [] + context = { + "fewshot_examples": fewshot_examples, + "continuation_delimiter": continuation_delimiter, + "item": item, + } + prompts = [template.render(choice=choice, **context) for choice in item["choices"]] + return prompts + + +def render_prompts_schema(item, continuation_delimiter, fewshot_examples=None): + """Render complete prompts for a schema question""" + template_str = """ +{%- for example in fewshot_examples -%} +{{ example.context_options[example.gold] }}{{ continuation_delimiter }}{{ example.continuation }} + +{% endfor -%} +{{ context }}{{ continuation_delimiter }}{{ item.continuation }}""".strip() + template = Template(template_str) + fewshot_examples = fewshot_examples or [] + context = { + "fewshot_examples": fewshot_examples, + "continuation_delimiter": continuation_delimiter, + "item": item, + } + prompts = [ + template.render(context=context_option, **context) + for context_option in item["context_options"] + ] + return prompts + + +def render_prompts_lm(item, continuation_delimiter, fewshot_examples=None): + """ + Render complete prompt for a language modeling task. + Notice that we manually trim the context in the template, + which in some datasets seems to have trailing whitespace (which we don't want). 
+ """ + template_str = """ +{%- for example in fewshot_examples -%} +{{ example.context | trim }}{{ continuation_delimiter }}{{ example.continuation }} + +{% endfor -%} +{{ item.context | trim }}{{ continuation_delimiter }}{% if include_continuation %}{{ item.continuation }}{% endif %}""".strip() + template = Template(template_str) + fewshot_examples = fewshot_examples or [] + context = { + "fewshot_examples": fewshot_examples, + "continuation_delimiter": continuation_delimiter, + "item": item, + } + # Return two prompts: without and with the continuation + prompt_without = template.render(include_continuation=False, **context) + prompt_with = template.render(include_continuation=True, **context) + # Due to the way the data seems to be stored, I think I need to strip in the case of LM here. + # Otherwise we may get trailing whitespaces in prompt_without (which get absorbed into the next + # token in prompt_with), meaning we don't get a nice and clean prefix in the token space + # to detect the final continuation. Tokenizers... + prompt_without = prompt_without.strip() + return [prompt_without, prompt_with] + + +def find_common_length(token_sequences, direction="left"): + """ + Find the length of the common prefix or suffix across token sequences + - direction: 'left' for prefix, 'right' for suffix + """ + min_len = min(len(seq) for seq in token_sequences) + indices = {"left": range(min_len), "right": range(-1, -min_len - 1, -1)}[direction] + # Find the first position where the token sequences differ + for i, idx in enumerate(indices): + token = token_sequences[0][idx] + if not all(seq[idx] == token for seq in token_sequences): + return i + return min_len + + +def stack_sequences(tokens, pad_token_id): + """Stack up a list of token sequences, pad to longest on the right""" + bsz, seq_len = len(tokens), max(len(x) for x in tokens) + input_ids = torch.full((bsz, seq_len), pad_token_id, dtype=torch.long) + for i, x in enumerate(tokens): + input_ids[i, : len(x)] = torch.tensor(x, dtype=torch.long) + return input_ids + + +def batch_sequences_mc(tokenizer, prompts): + # In multiple choice, contexts are the same but the continuation is different (common prefix) + tokens = tokenizer(prompts, prepend=tokenizer.get_bos_token_id()) + # figure out the start and end of each continuation + answer_start_idx = find_common_length(tokens, direction="left") + start_indices = [answer_start_idx] * len(prompts) + end_indices = [len(x) for x in tokens] + return tokens, start_indices, end_indices + + +def batch_sequences_schema(tokenizer, prompts): + # In schema tasks, contexts vary but continuation is the same (common suffix) + tokens = tokenizer(prompts, prepend=tokenizer.get_bos_token_id()) + # figure out the start and end of each context + suffix_length = find_common_length(tokens, direction="right") + end_indices = [len(x) for x in tokens] + start_indices = [ei - suffix_length for ei in end_indices] + return tokens, start_indices, end_indices + + +def batch_sequences_lm(tokenizer, prompts): + # In LM tasks, we have two prompts: without and with continuation + tokens = tokenizer(prompts, prepend=tokenizer.get_bos_token_id()) + tokens_without, tokens_with = tokens + start_idx, end_idx = len(tokens_without), len(tokens_with) + assert start_idx < end_idx, ( + "prompt without is supposed to be a prefix of prompt with" + ) + assert tokens_without == tokens_with[:start_idx], ( + "prompt without is supposed to be a prefix of prompt with" + ) + # we only need the with continuation prompt in the LM task, i.e. 
batch size of 1 + return [tokens_with], [start_idx], [end_idx] + + +@torch.no_grad() +def forward_model(model, input_ids): + """ + Take BxT tensor of token ids, return BxT tensor of losses and argmax predictions. + The last column of losses is set to nan because we don't have autoregressive targets there. + """ + batch_size, seq_len = input_ids.size() + outputs = model(input_ids) + + # Extract logits from model output (handles both raw tensors and HuggingFace output objects) + if hasattr(outputs, "logits"): + logits = outputs.logits + else: + logits = outputs + + # Roll the tensor to the left by one position to get the (autoregressive) target ids + target_ids = torch.roll(input_ids, shifts=-1, dims=1) + # Calculate cross entropy at all positions + losses = torch.nn.functional.cross_entropy( + logits.view(batch_size * seq_len, -1), + target_ids.view(batch_size * seq_len), + reduction="none", + ).view(batch_size, seq_len) + # Set the last column to be nan because there is no autoregressive loss there + losses[:, -1] = float("nan") + # Get the argmax predictions at each position + predictions = logits.argmax(dim=-1) + return losses, predictions + + +@torch.no_grad() +def evaluate_example(idx, model, tokenizer, data, device, task_meta): + """Evaluate a single example, return True if correct, False otherwise""" + item = data[idx] + task_type = task_meta["task_type"] + num_fewshot = task_meta["num_fewshot"] + continuation_delimiter = task_meta["continuation_delimiter"] + + # Sample few-shot examples (excluding current item) + fewshot_examples = [] + if num_fewshot > 0: + rng = random.Random(1234 + idx) + available_indices = [i for i in range(len(data)) if i != idx] + fewshot_indices = rng.sample(available_indices, num_fewshot) + fewshot_examples = [data[i] for i in fewshot_indices] + + # Render prompts and batch sequences based on task type + if task_type == "multiple_choice": + prompts = render_prompts_mc(item, continuation_delimiter, fewshot_examples) + tokens, start_idxs, end_idxs = batch_sequences_mc(tokenizer, prompts) + elif task_type == "schema": + prompts = render_prompts_schema(item, continuation_delimiter, fewshot_examples) + tokens, start_idxs, end_idxs = batch_sequences_schema(tokenizer, prompts) + elif task_type == "language_modeling": + prompts = render_prompts_lm(item, continuation_delimiter, fewshot_examples) + tokens, start_idxs, end_idxs = batch_sequences_lm(tokenizer, prompts) + else: + raise ValueError(f"Unsupported task type: {task_type}") + + # Some models can't forward sequences beyond a certain length (e.g. 
GPT-2) + # In these cases, we have to truncate sequences to max length and adjust the indices + max_tokens = None + if hasattr(model, "max_seq_len") and model.max_seq_len is not None: + max_tokens = model.max_seq_len + elif hasattr(model, "config"): + # For HuggingFace models, check common config attributes + if hasattr(model.config, "n_positions"): + max_tokens = model.config.n_positions + elif hasattr(model.config, "max_position_embeddings"): + max_tokens = model.config.max_position_embeddings + else: + max_tokens = 1024 # default to 1024 (GPT-2) if no info available + + if max_tokens is not None: + new_tokens, new_start_idxs, new_end_idxs = [], [], [] + for t, s, e in zip(tokens, start_idxs, end_idxs): + if len(t) > max_tokens: + num_to_crop = len(t) - max_tokens + new_tokens.append(t[-max_tokens:]) # take the last max_tokens tokens + new_start_idxs.append(s - num_to_crop) # shift the indices down + new_end_idxs.append(e - num_to_crop) + assert s - num_to_crop >= 0, "this should never happen right?" + assert e - num_to_crop >= 0, "this should never happen right?" + else: + new_tokens.append(t) # keep unchanged + new_start_idxs.append(s) + new_end_idxs.append(e) + tokens, start_idxs, end_idxs = new_tokens, new_start_idxs, new_end_idxs + + # Stack up all the sequences into a batch + pad_token_id = tokenizer.get_bos_token_id() # use BOS as pad token is ok + input_ids = stack_sequences(tokens, pad_token_id) + input_ids = input_ids.to(device) + + # Forward the model, get the autoregressive loss and argmax prediction at each token + losses, predictions = forward_model(model, input_ids) + + # See if the losses/predictions come out correctly + if task_type == "language_modeling": + # language modeling task is currently always batch size 1 + si = start_idxs[0] + ei = end_idxs[0] + # predictions[i] predict input_ids[i+1] autoregressively + predicted_tokens = predictions[0, si - 1 : ei - 1] + actual_tokens = input_ids[0, si:ei] + is_correct = torch.all(predicted_tokens == actual_tokens).item() + elif task_type in ["multiple_choice", "schema"]: + # For MC/schema: find the option with lowest average loss + mean_losses = [ + losses[i, si - 1 : ei - 1].mean().item() + for i, (si, ei) in enumerate(zip(start_idxs, end_idxs)) + ] + pred_idx = mean_losses.index(min(mean_losses)) + is_correct = pred_idx == item["gold"] + else: + raise ValueError(f"Unsupported task type: {task_type}") + + return is_correct + + +def evaluate_task(model, tokenizer, data, device, task_meta): + """ + This function is responsible for evaluating one task across many examples. + """ + # wrap tokenizer with Universal wrapper for compatibility + tokenizer = UniversalHuggingFaceTokenizer(tokenizer) + correct = torch.zeros(len(data), dtype=torch.float32, device=device) + for idx in range(len(data)): + is_correct = evaluate_example(idx, model, tokenizer, data, device, task_meta) + correct[idx] = float(is_correct) + # compute the mean + mean_correct = correct.mean().item() + return mean_correct diff --git a/plato/benchmarks/core_helpers/tokenizer.py b/plato/benchmarks/core_helpers/tokenizer.py new file mode 100644 index 000000000..05dc619ab --- /dev/null +++ b/plato/benchmarks/core_helpers/tokenizer.py @@ -0,0 +1,303 @@ +""" +BPE Tokenizer in the style of GPT-4. 
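+
+Within this patch, the CORE evaluation path (core_helpers/core.py) only uses the
+UniversalHuggingFaceTokenizer wrapper defined at the bottom of this file.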
+ +Two implementations are available: +1) HuggingFace Tokenizer that can do both training and inference but is really confusing +2) Universal Wrapper that can load any HuggingFace tokenizer (e.g., for GPT-2 which has slightly different tokenization rules than GPT-4) for inference only. +""" + +import os + +SPECIAL_TOKENS = [ + # every document begins with the Beginning of Sequence (BOS) token that delimits documents + "<|bos|>", + # tokens below are only used during finetuning to render Conversations into token ids + "<|user_start|>", # user messages + "<|user_end|>", + "<|assistant_start|>", # assistant messages + "<|assistant_end|>", + "<|python_start|>", # assistant invokes python REPL tool + "<|python_end|>", + "<|output_start|>", # python REPL outputs back to assistant + "<|output_end|>", +] + +# NOTE: this split pattern deviates from GPT-4 in that we use \p{N}{1,2} instead of \p{N}{1,3} +# I did this because I didn't want to "waste" too many tokens on numbers for smaller vocab sizes. +# I haven't validated that this is actually a good idea, TODO. +SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,2}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+""" + +# ----------------------------------------------------------------------------- +# Generic GPT-4-style tokenizer based on HuggingFace Tokenizer +from tokenizers import Tokenizer as HFTokenizer +from tokenizers import pre_tokenizers, decoders, Regex +from tokenizers.models import BPE +from tokenizers.trainers import BpeTrainer + + +class HuggingFaceTokenizer: + """Light wrapper around HuggingFace Tokenizer for some utilities""" + + def __init__(self, tokenizer): + self.tokenizer = tokenizer + + @classmethod + def from_pretrained(cls, hf_path): + # init from a HuggingFace pretrained tokenizer (e.g. "gpt2") + tokenizer = HFTokenizer.from_pretrained(hf_path) + return cls(tokenizer) + + @classmethod + def from_directory(cls, tokenizer_dir): + # init from a local directory on disk (e.g. "out/tokenizer") + tokenizer_path = os.path.join(tokenizer_dir, "tokenizer.json") + tokenizer = HFTokenizer.from_file(tokenizer_path) + return cls(tokenizer) + + @classmethod + def train_from_iterator(cls, text_iterator, vocab_size): + # train from an iterator of text + # Configure the HuggingFace Tokenizer + tokenizer = HFTokenizer( + BPE( + byte_fallback=True, # needed! + unk_token=None, + fuse_unk=False, + ) + ) + # Normalizer: None + tokenizer.normalizer = None + # Pre-tokenizer: GPT-4 style + # the regex pattern used by GPT-4 to split text into groups before BPE + # NOTE: The pattern was changed from \p{N}{1,3} to \p{N}{1,2} because I suspect it is harmful to + # very small models and smaller vocab sizes, because it is a little bit wasteful in the token space. + # (but I haven't validated this! TODO) + gpt4_split_regex = Regex( + SPLIT_PATTERN + ) # huggingface demands that you wrap it in Regex!! 
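+        # Split with the GPT-4-style regex first, then byte-level encode each piece;
+        # use_regex=False keeps ByteLevel from re-splitting what the custom regex has
+        # already isolated, and add_prefix_space=False avoids injecting a leading space.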
+ tokenizer.pre_tokenizer = pre_tokenizers.Sequence( + [ + pre_tokenizers.Split( + pattern=gpt4_split_regex, behavior="isolated", invert=False + ), + pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=False), + ] + ) + # Decoder: ByteLevel (it pairs together with the ByteLevel pre-tokenizer) + tokenizer.decoder = decoders.ByteLevel() + # Post-processor: None + tokenizer.post_processor = None + # Trainer: BPE + trainer = BpeTrainer( + vocab_size=vocab_size, + show_progress=True, + min_frequency=0, # no minimum frequency + initial_alphabet=pre_tokenizers.ByteLevel.alphabet(), + special_tokens=SPECIAL_TOKENS, + ) + # Kick off the training + tokenizer.train_from_iterator(text_iterator, trainer) + return cls(tokenizer) + + def get_vocab_size(self): + return self.tokenizer.get_vocab_size() + + def get_special_tokens(self): + special_tokens_map = self.tokenizer.get_added_tokens_decoder() + special_tokens = [w.content for w in special_tokens_map.values()] + return special_tokens + + def id_to_token(self, id): + return self.tokenizer.id_to_token(id) + + def _encode_one(self, text, prepend=None, append=None): + # encode a single string + # prepend/append can be either a string of a special token or a token id directly. + assert isinstance(text, str) + ids = [] + if prepend is not None: + prepend_id = ( + prepend if isinstance(prepend, int) else self.encode_special(prepend) + ) + ids.append(prepend_id) + ids.extend(self.tokenizer.encode(text, add_special_tokens=False).ids) + if append is not None: + append_id = ( + append if isinstance(append, int) else self.encode_special(append) + ) + ids.append(append_id) + return ids + + def encode_special(self, text): + # encode a single special token via exact match + return self.tokenizer.token_to_id(text) + + def get_bos_token_id(self): + bos = self.encode_special("<|bos|>") + return bos + + def encode(self, text, *args, **kwargs): + if isinstance(text, str): + return self._encode_one(text, *args, **kwargs) + elif isinstance(text, list): + return [self._encode_one(t, *args, **kwargs) for t in text] + else: + raise ValueError(f"Invalid input type: {type(text)}") + + def __call__(self, *args, **kwargs): + return self.encode(*args, **kwargs) + + def decode(self, ids): + return self.tokenizer.decode(ids, skip_special_tokens=False) + + def save(self, tokenizer_dir): + # save the tokenizer to disk + os.makedirs(tokenizer_dir, exist_ok=True) + tokenizer_path = os.path.join(tokenizer_dir, "tokenizer.json") + self.tokenizer.save(tokenizer_path) + print(f"Saved tokenizer to {tokenizer_path}") + + +# ----------------------------------------------------------------------------- +# Universal Tokenizer Wrapper that works with any HuggingFace model +# For example, GPT2TokenizerFast doesn't have a get_bos_token_id() method, +# so we need this wrapper to provide a unified interface. + + +class UniversalHuggingFaceTokenizer: + """ + Universal wrapper that provides a consistent interface for any HuggingFace tokenizer. + + This wrapper automatically detects special tokens (BOS, PAD, EOS) and provides + utility methods that work across different tokenizer implementations. + """ + + def __init__(self, tokenizer): + """ + Initialize the wrapper with a HuggingFace tokenizer. 
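+
+        Special token IDs (BOS, PAD, EOS) are detected once here via
+        _detect_special_tokens() and cached on the instance.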
+ + Args: + tokenizer: A HuggingFace tokenizer instance (e.g., GPT2TokenizerFast) + """ + self.tokenizer = tokenizer + self._pad_token_id = None + self._bos_token_id = None + self._eos_token_id = None + self._detect_special_tokens() + + def _detect_special_tokens(self): + """ + Auto-detect special token IDs from the tokenizer. + + Detection strategy (in order of priority): + 1. Try direct attributes on the tokenizer (bos_token_id, pad_token_id, eos_token_id) + 2. For missing tokens, use EOS as BOS/PAD for models like GPT-2 + 3. Try token_to_id() method with common token names + 4. Final fallbacks: 0 for pad, pad for bos + """ + # Strategy 1: Direct attributes (works for most HuggingFace tokenizers) + if ( + hasattr(self.tokenizer, "bos_token_id") + and self.tokenizer.bos_token_id is not None + ): + self._bos_token_id = self.tokenizer.bos_token_id + + if ( + hasattr(self.tokenizer, "pad_token_id") + and self.tokenizer.pad_token_id is not None + ): + self._pad_token_id = self.tokenizer.pad_token_id + + if ( + hasattr(self.tokenizer, "eos_token_id") + and self.tokenizer.eos_token_id is not None + ): + self._eos_token_id = self.tokenizer.eos_token_id + # For GPT-2 and similar models, BOS is often the same as EOS + if self._bos_token_id is None: + self._bos_token_id = self._eos_token_id + # Use EOS as pad if no pad token exists + if self._pad_token_id is None: + self._pad_token_id = self._eos_token_id + + # Strategy 2: Try token_to_id method for tokenizers with nested structure + if hasattr(self.tokenizer, "tokenizer"): + tokenizer_obj = self.tokenizer.tokenizer + + if self._pad_token_id is None: + pad_candidates = ["", "[PAD]", "<|pad|>", "", "<|endoftext|>"] + self._pad_token_id = self._try_token_candidates( + tokenizer_obj, pad_candidates + ) + + if self._bos_token_id is None: + bos_candidates = ["", "[CLS]", "<|startoftext|>", "<|endoftext|>"] + self._bos_token_id = self._try_token_candidates( + tokenizer_obj, bos_candidates + ) + + # Strategy 3: Final fallbacks + if self._pad_token_id is None: + self._pad_token_id = 0 # Most models default to 0 + + if self._bos_token_id is None: + self._bos_token_id = self._pad_token_id + + def _try_token_candidates(self, tokenizer_obj, candidates): + """ + Try to find a token ID from a list of candidate token strings. + + Args: + tokenizer_obj: The tokenizer object with token_to_id method + candidates: List of token strings to try + + Returns: + Token ID if found, None otherwise + """ + if not hasattr(tokenizer_obj, "token_to_id"): + return None + + for candidate in candidates: + token_id = tokenizer_obj.token_to_id(candidate) + if token_id is not None: + return token_id + return None + + def get_bos_token_id(self): + """Get the beginning-of-sequence token ID.""" + return self._bos_token_id + + def get_pad_token_id(self): + """Get the padding token ID.""" + return self._pad_token_id + + def get_eos_token_id(self): + """Get the end-of-sequence token ID.""" + return self._eos_token_id + + def __call__(self, prompts, prepend=None): + """ + Tokenize prompts with optional prepended token. 
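+
+        This matches how core_helpers/core.py calls the tokenizer:
+        tokenizer(prompts, prepend=tokenizer.get_bos_token_id()).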
+ + Args: + prompts: Single string or list of strings to tokenize + prepend: Optional token ID to prepend to each sequence + + Returns: + List of token IDs, or list of lists if multiple prompts + """ + if isinstance(prompts, str): + prompts = [prompts] + + result = [] + for prompt in prompts: + tokens = self.tokenizer.encode(prompt) + if prepend is not None: + tokens = [prepend] + tokens + result.append(tokens) + + return result[0] if len(result) == 1 else result + + def __getattr__(self, name): + """Delegate all other attributes to the wrapped tokenizer.""" + return getattr(self.tokenizer, name) diff --git a/plato/benchmarks/registry.py b/plato/benchmarks/registry.py new file mode 100644 index 000000000..1325e8bff --- /dev/null +++ b/plato/benchmarks/registry.py @@ -0,0 +1,31 @@ +""" +Registry for benchmarks. + +Enables runtime benchmark selection via configuration. +""" + +from plato.benchmarks import core +from plato.benchmarks.base import Benchmark as BenchmarkBase + +registered_benchmarks: dict[str, type[BenchmarkBase]] = { + "core": core.Benchmark, +} + + +def get(type: str) -> BenchmarkBase: + """Get an instance of the benchmark.""" + if type in registered_benchmarks: + benchmark_cls = registered_benchmarks[type] + registered_benchmark = benchmark_cls() + else: + available = list(registered_benchmarks.keys()) + raise ValueError( + f"No such benchmark: {type}. Available benchmarks: {available}" + ) + + return registered_benchmark + + +def list_benchmarks(): + """List all available benchmark types.""" + return list(registered_benchmarks.keys()) diff --git a/plato/config.py b/plato/config.py index ec7d840a0..9d0f68b15 100644 --- a/plato/config.py +++ b/plato/config.py @@ -153,6 +153,7 @@ class Config: clients: Any server: Any data: Any + benchmark: Any trainer: Any algorithm: Any results: Any @@ -342,6 +343,20 @@ def __new__(cls): Config.params["base_path"], "data" ) + # User specific benchmark + if hasattr(config, "benchmark"): + Config.benchmark = config.benchmark + + # Directory of benchmark dataset + if hasattr(Config().benchmark, "data_path"): + Config.params["benchmark_path"] = os.path.join( + Config.params["base_path"], Config().benchmark.data_path + ) + else: + Config.params["benchmark_path"] = os.path.join( + Config.params["base_path"], "benchmark" + ) + # Pretrained models if hasattr(Config().server, "model_path"): Config.params["model_path"] = os.path.join( @@ -402,6 +417,10 @@ def __new__(cls): if hasattr(config, "parameters"): Config.parameters = config.parameters + # Benchmark configuration (for model evaluation) + if hasattr(config, "benchmark"): + Config.benchmark = config.benchmark + return cls._instance def __getattr__(self, name: str) -> Any: diff --git a/plato/servers/fedavg.py b/plato/servers/fedavg.py index ee636cb49..30b262b5a 100644 --- a/plato/servers/fedavg.py +++ b/plato/servers/fedavg.py @@ -7,6 +7,7 @@ import os from plato.algorithms import registry as algorithms_registry +from plato.benchmarks import registry as benchmarks_registry from plato.config import Config from plato.datasources import registry as datasources_registry from plato.processors import registry as processor_registry @@ -50,6 +51,8 @@ def __init__( self.testset_sampler = None self.total_samples = 0 + self.benchmark = None + self.total_clients = Config().clients.total_clients self.clients_per_round = Config().clients.per_round @@ -252,6 +255,25 @@ async def _process_reports(self): trainer = self.require_trainer() self.accuracy = trainer.test(self.testset, self.testset_sampler) + # 
Evaluating the global model on the specified benchmark + if hasattr(Config().config, "benchmark") and hasattr( + Config().benchmark, "type" + ): + benchmark_type = Config().benchmark.type + if self.benchmark is None: + self.benchmark = benchmarks_registry.get(benchmark_type) + logging.info( + "[%s] Started model evaluation on benchmark %s.", self, benchmark_type + ) + trainer = self.require_trainer() + self.benchmark_result = trainer.eval(self.benchmark, self.testset_sampler) + logging.info( + "[%s] Model evaluation result on benchmark %s:\n%s.", + self, + benchmark_type, + self.benchmark.get_formatted_result(self.benchmark_result), + ) + if hasattr(Config().trainer, "target_perplexity"): logging.info( fonts.colourize( diff --git a/plato/trainers/base.py b/plato/trainers/base.py index 306ca0c4f..9b67a86b7 100644 --- a/plato/trainers/base.py +++ b/plato/trainers/base.py @@ -3,6 +3,7 @@ """ import os +import json from abc import ABC, abstractmethod from typing import Any, Optional @@ -77,6 +78,39 @@ def load_accuracy(filename=None): return accuracy + @staticmethod + def save_benchmark_result(benchmark_result, filename=None): + """Saving the benchmark result to a file.""" + model_path = Config().params["model_path"] + model_name = Config().trainer.model_name + + if not os.path.exists(model_path): + os.makedirs(model_path) + + if filename is not None: + benchmark_result_path = f"{model_path}/{filename}" + else: + benchmark_result_path = f"{model_path}/{model_name}.eval" + + with open(benchmark_result_path, "w", encoding="utf-8") as file: + json.dump(benchmark_result, file) + + @staticmethod + def load_benchmark_result(filename=None): + """Loading the benchmark result from a file.""" + model_path = Config().params["model_path"] + model_name = Config().trainer.model_name + + if filename is not None: + benchmark_result_path = f"{model_path}/{filename}" + else: + benchmark_result_path = f"{model_path}/{model_name}.eval" + + with open(benchmark_result_path, encoding="utf-8") as file: + benchmark_result = json.load(file) + + return benchmark_result + def pause_training(self): """Remove files of running trainers.""" if hasattr(Config().trainer, "max_concurrency"): diff --git a/plato/trainers/composable.py b/plato/trainers/composable.py index a1da4a5eb..164c111bb 100644 --- a/plato/trainers/composable.py +++ b/plato/trainers/composable.py @@ -793,6 +793,85 @@ def test_model(self, config, testset, sampler=None, **kwargs): return accuracy + def eval_process(self, config, benchmark, sampler=None, **kwargs): + """The evaluating loop, run in a separate process.""" + self.eval_model(config, benchmark, sampler, **kwargs) + + model_name = Config().trainer.model_name + filename = f"{model_name}_{self.client_id}_{config['run_id']}.eval" + self.save_benchmark_result(self.benchmark_result, filename) + + def eval(self, benchmark, sampler=None, **kwargs) -> dict[str, Any]: + """ + Evaluate the model using the provided benchmark. 
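+
+        When max_concurrency is set in the trainer configuration, evaluation runs
+        in a separate spawned process and the result is reloaded from the saved
+        .eval file; otherwise eval_model() is called in-process.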
+ + Args: + benchmark: benchmark instance (from benchmarks.registry.get()) + sampler: The sampler for the test dataset + **kwargs: Additional keyword arguments + + Returns: + Accuracy on benchmark + """ + config = Config().trainer._asdict() + config["run_id"] = Config().params["run_id"] + + if "max_concurrency" in config: + model = self._require_model() + model.cpu() + + if mp.get_start_method(allow_none=True) != "spawn": + mp.set_start_method("spawn", force=True) + + eval_proc = mp.Process( + target=self.eval_process, + args=(config, benchmark, sampler), + kwargs=kwargs, + ) + eval_proc.start() + eval_proc.join() + + model_name = Config().trainer.model_name + filename = f"{model_name}_{self.client_id}_{Config().params['run_id']}.eval" + + try: + benchmark_result = self.load_benchmark_result(filename) + except OSError as error: + raise ValueError( + f"Evaluating on client {self.client_id} failed." + ) from error + + self.pause_training() + return benchmark_result + else: + return self.eval_model(config, benchmark, sampler, **kwargs) + + def eval_model(self, config, benchmark, sampler=None, **kwargs): + """ + Evaluate the model using benchmark. + + Args: + config: Evaluation configuration dictionary + benchmark: Benchmark instance (from benchmarks.registry.get()) + sampler: Optional data sampler (usually None for benchmarks) + **kwargs: Additional keyword arguments + + Returns: + Benchmark results dictionary containing: + - 'results': per-task accuracies + - 'centered_results': normalized scores + - 'core_metric': overall CORE score + """ + + model = self._require_model() + result = self.testing_strategy.eval_model( + model, config, benchmark, sampler, self.context + ) + + self.benchmark_result = result + + return result + def obtain_model_update(self, config, trainset, sampler): """ Obtain model updates from training. diff --git a/plato/trainers/split_learning.py b/plato/trainers/split_learning.py index 8a8b81536..dcacb08da 100644 --- a/plato/trainers/split_learning.py +++ b/plato/trainers/split_learning.py @@ -214,6 +214,11 @@ def test_model(self, model, config, testset, sampler, context): accuracy = correct / total return accuracy + def eval_model(self, model, config, benchmark, sampler, context): + raise NotImplementedError( + "eval_model is not implemented yet for SplitLearningTestingStrategy." + ) + # pylint:disable=too-many-instance-attributes class Trainer(ComposableTrainer): diff --git a/plato/trainers/strategies/base.py b/plato/trainers/strategies/base.py index a4c3159b6..30032f252 100644 --- a/plato/trainers/strategies/base.py +++ b/plato/trainers/strategies/base.py @@ -554,3 +554,36 @@ def test_model( setting eval mode, and computing the metric. """ pass + + @abstractmethod + def eval_model( + self, + model: nn.Module, + config: dict[str, Any], + benchmark, + sampler, + context: TrainingContext, + ) -> dict[str, Any]: + """ + Evaluate the model on benchmark and return results. + + Args: + model: The model to test + config: Testing configuration dictionary + benchmark: Benchmark instance for evaluation + sampler: Optional data sampler for test set + context: Training context with device, client_id, etc. + + Returns: + Benchmark results dictionary containing evaluation metrics. + For CORE benchmark, this includes: + - 'results': per-task accuracies + - 'centered_results': normalized scores + - 'core_metric': overall CORE score + + Note: + This method should handle moving model to device, + setting eval mode, and computing the benchmark metrics. 
+ The specific return format depends on the benchmark type. + """ + pass diff --git a/plato/trainers/strategies/testing.py b/plato/trainers/strategies/testing.py index 80e6c424f..b170f35a0 100644 --- a/plato/trainers/strategies/testing.py +++ b/plato/trainers/strategies/testing.py @@ -7,6 +7,7 @@ import logging import os +from typing import Any import torch @@ -97,3 +98,27 @@ def test_model(self, model, config, testset, sampler, context): ) return accuracy + + def eval_model(self, model, config, benchmark, sampler, context) -> dict[str, Any]: + """ + Evaluate the model on benchmark and return results. + + Args: + model: The model to test + config: Testing configuration dictionary + benchmark: Benchmark instance for evaluation + sampler: Optional data sampler for test set + context: Training context with device, client_id, etc. + + Returns: + Benchmark results dictionary + + Note: + DefaultTestingStrategy does not implement benchmark evaluation. + Use a specialized testing strategy (e.g., LLMSplitLearningTestingStrategy) + for benchmark support. + """ + raise NotImplementedError( + "DefaultTestingStrategy does not support benchmark evaluation. " + "Please implement a custom TestingStrategy with eval_model() for your use case." + )
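A minimal usage sketch (not part of the patch), assuming a Plato configuration with a
[benchmark] section has already been loaded so that Config() and
Config.params["benchmark_path"] are initialized, and using a stock HuggingFace GPT-2
model purely for illustration:

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    from plato.benchmarks import registry as benchmarks_registry

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = AutoModelForCausalLM.from_pretrained("gpt2").to(device).eval()
    tokenizer = AutoTokenizer.from_pretrained("gpt2")

    # The registry constructs plato.benchmarks.core.Benchmark, which downloads
    # the CORE eval bundle into Config.params["benchmark_path"] on first use.
    benchmark = benchmarks_registry.get("core")

    # These attributes are read by Benchmark.evaluate(), mirroring how the
    # LLM split-learning testing strategy wires them up.
    benchmark.model = model
    benchmark.tokenizer = tokenizer
    benchmark.device = device

    result = benchmark.evaluate()
    # result == {"results": ..., "centered_results": ..., "core_metric": ...}
    print(benchmark.get_formatted_result(result))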