From d2d825670644859198ce197c874b717eaab449ea Mon Sep 17 00:00:00 2001 From: Junyi Chen Date: Mon, 15 Dec 2025 14:33:58 +0800 Subject: [PATCH 1/2] add fallback logic of triton autotune --- lightllm/common/triton_utils/autotuner.py | 31 +++++++++++++++++++- lightllm/server/api_cli.py | 5 ++++ lightllm/server/core/objs/start_args_type.py | 1 + 3 files changed, 36 insertions(+), 1 deletion(-) diff --git a/lightllm/common/triton_utils/autotuner.py b/lightllm/common/triton_utils/autotuner.py index a919f7b28..5f5c70735 100644 --- a/lightllm/common/triton_utils/autotuner.py +++ b/lightllm/common/triton_utils/autotuner.py @@ -12,7 +12,7 @@ from lightllm.utils.device_utils import get_current_device_name from lightllm.utils.log_utils import init_logger from typing import Callable, Optional, Union, List -from lightllm.utils.envs_utils import get_triton_autotune_level +from lightllm.utils.envs_utils import get_env_start_args, get_triton_autotune_level from lightllm.common.kernel_config import KernelConfigs from lightllm.utils.dist_utils import get_global_world_size, get_global_rank, get_current_rank_in_node @@ -218,6 +218,35 @@ def _try_load_cache(self, static_key): logger.info(f"Loading cached configs for {self.kernel_name} - {static_key}") with open(cache_file, "rb") as f: self.cached_configs[static_key] = orjson.loads(f.read()) + elif get_env_start_args().enable_kernel_config_fallback: + # list the all triton versions dir + possilble_triton_versions = os.listdir(os.path.join(Path(__file__).parent, "autotune_kernel_configs")) + # get the current triton version + current_triton_version = get_triton_version() + # try sort by the distance between current triton version and possilble triton versions + possilble_triton_versions = sorted( + possilble_triton_versions, + key=lambda x: abs( + int(x.replace("triton_", "").replace(".", "")) + - int(current_triton_version.replace("triton_", "").replace(".", "")) + ), + ) + for triton_version in possilble_triton_versions: + fallback_cache_file = os.path.join( + Path(__file__).parent, + "autotune_kernel_configs", + triton_version, + get_current_device_name(), + self.kernel_name, + KernelConfigs.get_config_file_name(static_key), + ) + if os.path.exists(fallback_cache_file): + logger.warning( + f"Fallback loading cached configs for {self.kernel_name} - {static_key} " + f"from triton version {triton_version}" + ) + with open(fallback_cache_file, "rb") as f: + self.cached_configs[static_key] = orjson.loads(f.read()) return True def kernel_warmup(self, static_key, *args, **kwargs): diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index bf0e89887..3df809956 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -332,6 +332,11 @@ def make_argument_parser() -> argparse.ArgumentParser: action="store_true", help="""inference backend will use the fa3 attention kernel for prefill and decode""", ) + parser.add_argument( + "--enable_kernel_config_fallback", + action="store_true", + help="""Whether to enable kernel config fallback when triton version is not compatible.""", + ) parser.add_argument( "--cache_capacity", type=int, default=200, help="cache server capacity for multimodal resources" ) diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index 71cafd6c4..e9e493086 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -131,3 +131,4 @@ class StartArgs: # kernel setting enable_fa3: bool = field(default=False) + enable_kernel_config_fallback: bool = field(default=False) From 0872088a197377c28a76a39f862038a7c780c28a Mon Sep 17 00:00:00 2001 From: Junyi Chen Date: Mon, 15 Dec 2025 14:59:01 +0800 Subject: [PATCH 2/2] improve logic --- lightllm/common/triton_utils/autotuner.py | 80 +++++++++++++++++------ 1 file changed, 61 insertions(+), 19 deletions(-) diff --git a/lightllm/common/triton_utils/autotuner.py b/lightllm/common/triton_utils/autotuner.py index 5f5c70735..0a4e8be20 100644 --- a/lightllm/common/triton_utils/autotuner.py +++ b/lightllm/common/triton_utils/autotuner.py @@ -1,3 +1,4 @@ +import re import triton import orjson import os @@ -11,7 +12,7 @@ from frozendict import frozendict from lightllm.utils.device_utils import get_current_device_name from lightllm.utils.log_utils import init_logger -from typing import Callable, Optional, Union, List +from typing import Callable, Optional, Tuple, Union, List from lightllm.utils.envs_utils import get_env_start_args, get_triton_autotune_level from lightllm.common.kernel_config import KernelConfigs from lightllm.utils.dist_utils import get_global_world_size, get_global_rank, get_current_rank_in_node @@ -219,19 +220,52 @@ def _try_load_cache(self, static_key): with open(cache_file, "rb") as f: self.cached_configs[static_key] = orjson.loads(f.read()) elif get_env_start_args().enable_kernel_config_fallback: - # list the all triton versions dir - possilble_triton_versions = os.listdir(os.path.join(Path(__file__).parent, "autotune_kernel_configs")) - # get the current triton version + + def parse_triton_version_tag(tag: str) -> Optional[Tuple[int, int, int]]: + """ + Parse "triton_X.Y.Z" or "triton_X.Y" to (X, Y, Z), Z defaults to 0. + Returns None if invalid. + """ + match = re.match(r"^triton_(\d+)\.(\d+)(?:\.(\d+))?$", tag) + if not match: + return None + x, y, z = match.groups() + return (int(x), int(y), int(z) if z is not None else 0) + + def version_distance(v1: Tuple[int, int, int], v2: Tuple[int, int, int]) -> int: + """ + Compute weighted distance: major * 1e6 + minor * 1e3 + patch + Ensures lexicographic ordering. + """ + return abs((v1[0] - v2[0]) * 1_000_000 + (v1[1] - v2[1]) * 1_000 + (v1[2] - v2[2])) + current_triton_version = get_triton_version() - # try sort by the distance between current triton version and possilble triton versions - possilble_triton_versions = sorted( - possilble_triton_versions, - key=lambda x: abs( - int(x.replace("triton_", "").replace(".", "")) - - int(current_triton_version.replace("triton_", "").replace(".", "")) - ), - ) - for triton_version in possilble_triton_versions: + current_parsed = parse_triton_version_tag(current_triton_version) + if current_parsed is None: + logger.error("Unable to parse current Triton version. Triton may not be installed properly.") + possible_dirs = [ + d + for d in os.listdir(os.path.join(Path(__file__).parent, "autotune_kernel_configs")) + if d.startswith("triton_") + ] + possible_dirs.sort() + else: + config_dir = os.path.join(Path(__file__).parent, "autotune_kernel_configs") + possible_dirs = [] + for d in os.listdir(config_dir): + if not d.startswith("triton_"): + continue + parsed = parse_triton_version_tag(d) + if parsed is not None: + dist = version_distance(parsed, current_parsed) + possible_dirs.append((dist, d, parsed)) + else: + logger.debug(f"Skipping invalid version directory: {d}") + possible_dirs.sort(key=lambda x: x[0]) + possible_dirs = [d for _, d, _ in possible_dirs] + + loaded = False + for triton_version in possible_dirs: fallback_cache_file = os.path.join( Path(__file__).parent, "autotune_kernel_configs", @@ -241,12 +275,20 @@ def _try_load_cache(self, static_key): KernelConfigs.get_config_file_name(static_key), ) if os.path.exists(fallback_cache_file): - logger.warning( - f"Fallback loading cached configs for {self.kernel_name} - {static_key} " - f"from triton version {triton_version}" - ) - with open(fallback_cache_file, "rb") as f: - self.cached_configs[static_key] = orjson.loads(f.read()) + try: + logger.warning( + f"Fallback loading cached configs for {self.kernel_name} - {static_key} " + f"from triton version {triton_version} (current: {current_triton_version})" + ) + with open(fallback_cache_file, "rb") as f: + self.cached_configs[static_key] = orjson.loads(f.read()) + loaded = True + break + except Exception as e: + logger.error(f"Failed to load fallback config from {fallback_cache_file}: {e}") + + if not loaded: + logger.info(f"No fallback config found for {self.kernel_name} - {static_key}") return True def kernel_warmup(self, static_key, *args, **kwargs):