diff --git a/README.md b/README.md index ba897602..60a2ba7a 100644 --- a/README.md +++ b/README.md @@ -90,7 +90,7 @@ If you don't have GPU available locally, you can set up [Modal](https://modal.co ## 🚀 Usage ### Run on a single problem -It is easier to get started with a single problem. This will fetch the problem, generate a sample, and evaluate the sample. +It is easier to get started with a single problem. This will fetch the problem, generate a sample, and evaluate the sample. ``` # for example, run level 2 problem 40 from huggingface @@ -106,7 +106,7 @@ python3 scripts/generate_and_eval_single_sample.py dataset_src="huggingface" lev * **`precision`** - You can specify the precision of tensor by `precision=fp32`. Currently all of our reported results are `fp32` but we added support for `fp16` & `bf16`. * **`backend`** - We are also supporting other GPU programming languages beyond `cuda`. Simply specify `backend=triton`. For now we support DSLs: `cuda`, `triton`, `cute`, `tilelang`. -Check the config fields for comprehensive set of options. +Check the config fields for comprehensive set of options. Note we provide the model with a one-shot example by default along with the minimum set of info; you can check out other prompt settings or construct your own in `src/prompt_constructor_toml.py`. ### Run on all problems diff --git a/scripts/generate_and_eval_single_sample.py b/scripts/generate_and_eval_single_sample.py index 18fb3c55..2b2d5301 100644 --- a/scripts/generate_and_eval_single_sample.py +++ b/scripts/generate_and_eval_single_sample.py @@ -9,8 +9,7 @@ from src.dataset import construct_kernelbench_dataset from src.eval import eval_kernel_against_ref -from src.prompt_constructor import prompt_generate_custom_cuda_from_prompt_template -from src.prompt_constructor_multilang import get_prompt_for_backend +from src.prompt_constructor_toml import get_prompt_for_backend, get_custom_prompt from src.utils import ( create_inference_server_from_presets, extract_first_code, @@ -22,6 +21,9 @@ """ Generate and evaluate a single sample Easiest way to get started, to test a single problem for experimentation or debugging + +Example usage: +python3 scripts/generate_and_eval_single_sample.py dataset_src=huggingface level=1 problem_id=1 eval_mode=local server_type=google model_name=gemini/gemini-2.5-flash max_tokens=8192 temperature=0.0 """ REPO_TOP_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) @@ -72,6 +74,12 @@ def __init__(self): self.backend = "cuda" + # Prompt construction + self.prompt_option = "one_shot" # choices: zero_shot, one_shot, few_shot + self.include_hardware_info = False + self.hardware_gpu_name = None + self.custom_prompt_key = None + def verbose_logging(self): self.log = True self.log_prompt = True @@ -86,6 +94,7 @@ def __repr__(self): def main(config: EvalConfig): """ Keep it simple: Generate and evaluate a single sample + Note: will shorten code logic to make this as simple as possible """ from src.utils import SERVER_PRESETS @@ -129,6 +138,7 @@ def main(config: EvalConfig): config.problem_id <= num_problems ), f"Problem ID {config.problem_id} out of range for Level {config.level}" + # TODO: refactor dataset fetching logic to be as clean as posisble. # 1. Fetch Problem if config.dataset_src == "huggingface": @@ -169,24 +179,70 @@ def main(config: EvalConfig): budget_tokens=config.budget_tokens, ) + # Prompt Construction (Note: could be shortened in future PR) + custom_prompt_key = getattr(config, "custom_prompt_key", None) + if isinstance(custom_prompt_key, str): + trimmed = custom_prompt_key.strip() + if trimmed.lower() in {"", "none"}: + custom_prompt_key = None + else: + custom_prompt_key = trimmed + config.custom_prompt_key = custom_prompt_key + # Use appropriate prompt constructor based on backend - if config.backend == "cuda": - custom_prompt = prompt_generate_custom_cuda_from_prompt_template(ref_arch_src) - elif config.backend in ["triton", "tilelang", "cute"]: - custom_prompt = get_prompt_for_backend(ref_arch_src, config.backend) - else: + prompt_option = str(config.prompt_option).lower() + valid_prompt_options = {"zero_shot", "one_shot", "few_shot"} + include_hardware = config.include_hardware_info + if isinstance(include_hardware, str): + include_hardware = include_hardware.lower() in ["true", "1", "yes"] + config.include_hardware_info = include_hardware + + supported_backends = {"cuda", "triton", "tilelang", "cute"} + backend = config.backend.lower() + if backend not in supported_backends: raise ValueError( - f"Unsupported backend: {config.backend}. Must be 'cuda', 'triton', 'tilelang', or 'cute'." + f"Unsupported backend: {config.backend}. Must be one of {sorted(supported_backends)}." ) + if backend == "tilelang": + config.precision = "fp16" # tilelang only operates with fp16 + config.hardware_gpu_name = config.hardware_gpu_name or getattr(config, "gpu", None) + + if not custom_prompt_key: + if prompt_option not in valid_prompt_options: + raise ValueError( + f"Invalid prompt_option '{config.prompt_option}'. " + f"Must be one of {sorted(valid_prompt_options)}." + ) + if include_hardware and not config.hardware_gpu_name: + raise ValueError( + "include_hardware_info is True but hardware_gpu_name is not provided." + ) + + if custom_prompt_key: + custom_prompt = get_custom_prompt( + custom_prompt_key, + ref_arch_src=ref_arch_src, + backend=backend, + option=prompt_option, + precision=config.precision, + include_hardware=include_hardware, + gpu_name=config.hardware_gpu_name, + ) + else: + custom_prompt = get_prompt_for_backend( + ref_arch_src, + backend, + option=prompt_option, + precision=config.precision, + include_hardware=include_hardware, + gpu_name=config.hardware_gpu_name, + ) + + os.makedirs(config.logdir, exist_ok=True) + if config.log_prompt: - with open( - os.path.join( - config.logdir, - f"prompt_level_{config.level}_problem_{config.problem_id}.txt", - ), - "w", - ) as f: + with open(os.path.join(config.logdir, f"prompt_level_{config.level}_problem_{config.problem_id}.txt"), "w") as f: f.write(custom_prompt) # Query server with constructed prompt @@ -200,13 +256,7 @@ def main(config: EvalConfig): # this should be optional if config.log: - with open( - os.path.join( - config.logdir, - f"generated_kernel_level_{config.level}_problem_{config.problem_id}.py", - ), - "w", - ) as f: + with open(os.path.join(config.logdir, f"generated_kernel_level_{config.level}_problem_{config.problem_id}.py"), "w") as f: f.write(custom_kernel) # 3. Evaluate Kernel @@ -228,13 +278,7 @@ def main(config: EvalConfig): ) if config.log: - with open( - os.path.join( - config.logdir, - f"eval_result_level_{config.level}_problem_{config.problem_id}.txt", - ), - "a", - ) as f: + with open(os.path.join(config.logdir, f"eval_result_level_{config.level}_problem_{config.problem_id}.txt"), "a",) as f: f.write(f"Problem Name: {problem_name}\n") f.write(str(kernel_exec_result)) diff --git a/scripts/generate_and_eval_single_sample_modal.py b/scripts/generate_and_eval_single_sample_modal.py index 6962f515..7628e0bf 100644 --- a/scripts/generate_and_eval_single_sample_modal.py +++ b/scripts/generate_and_eval_single_sample_modal.py @@ -15,8 +15,7 @@ #from src.dataset import construct_kernelbench_dataset from src.eval import eval_kernel_against_ref -from src.prompt_constructor import prompt_generate_custom_cuda_from_prompt_template -from src.prompt_constructor_multilang import get_prompt_for_backend +from src.prompt_constructor_toml import get_prompt_for_backend, get_custom_prompt from src.utils import extract_first_code, query_server, set_gpu_arch, read_file, create_inference_server_from_presets app = modal.App("eval_single_sample") @@ -76,6 +75,11 @@ def __init__(self): self.log_eval_result = False self.backend = "cuda" + # Prompt generation settings + self.prompt_option = "one_shot" # zero_shot, one_shot, few_shot + self.include_hardware_info = False + self.hardware_gpu_name = None + self.custom_prompt_key = None def verbose_logging(self): self.log = True @@ -194,14 +198,64 @@ def main(config: EvalConfig): budget_tokens=config.budget_tokens) + custom_prompt_key = getattr(config, "custom_prompt_key", None) + if isinstance(custom_prompt_key, str): + trimmed = custom_prompt_key.strip() + if trimmed.lower() in {"", "none"}: + custom_prompt_key = None + else: + custom_prompt_key = trimmed + config.custom_prompt_key = custom_prompt_key + + # Checks if user has inputted a valid argument for how many examples they want to give as context to the model + prompt_option = str(config.prompt_option).lower() + valid_prompt_options = {"zero_shot", "one_shot", "few_shot"} + include_hardware = config.include_hardware_info + if isinstance(include_hardware, str): + include_hardware = include_hardware.lower() in ["true", "1", "yes"] + config.include_hardware_info = include_hardware + + supported_backends = {"cuda", "triton", "tilelang", "cute"} + backend = config.backend.lower() + if backend not in supported_backends: + raise ValueError( + f"Unsupported backend: {config.backend}. Must be one of {sorted(supported_backends)}." + ) - # Use appropriate prompt constructor based on backend - if config.backend == "cuda": - custom_prompt = prompt_generate_custom_cuda_from_prompt_template(ref_arch_src) - elif config.backend in ["triton", "tilelang", "cute"]: - custom_prompt = get_prompt_for_backend(ref_arch_src, config.backend) + #tilelang only supports fp16 or bf16 + if backend == "tilelang": + config.precision = "fp16" + config.hardware_gpu_name = config.hardware_gpu_name or getattr(config, "gpu", None) + + if not custom_prompt_key: + if prompt_option not in valid_prompt_options: + raise ValueError( + f"Invalid prompt_option '{config.prompt_option}'. Must be one of {sorted(valid_prompt_options)}." + ) + if include_hardware and not config.hardware_gpu_name: + raise ValueError( + "include_hardware_info is True but hardware_gpu_name is not provided." + ) + + if custom_prompt_key: + custom_prompt = get_custom_prompt( + custom_prompt_key, + ref_arch_src=ref_arch_src, + backend=backend, + option=prompt_option, + precision=config.precision, + include_hardware=include_hardware, + gpu_name=config.hardware_gpu_name, + ) else: - raise ValueError(f"Unsupported backend: {config.backend}. Must be 'cuda', 'triton', 'tilelang', or 'cute'.") + custom_prompt = get_prompt_for_backend( + ref_arch_src, + backend, + option=prompt_option, + precision=config.precision, + include_hardware=include_hardware, + gpu_name=config.hardware_gpu_name, + ) if config.log_prompt: with open(os.path.join(config.logdir, f"prompt_level_{config.level}_problem_{config.problem_id}.txt"), "w") as f: diff --git a/scripts/generate_samples.py b/scripts/generate_samples.py index 5b476445..e47c6e87 100644 --- a/scripts/generate_samples.py +++ b/scripts/generate_samples.py @@ -10,8 +10,7 @@ from src.dataset import construct_kernelbench_dataset from src.eval import eval_kernel_against_ref -from src.prompt_constructor import prompt_generate_custom_cuda_from_prompt_template -from src.prompt_constructor_multilang import get_prompt_for_backend +from src.prompt_constructor_toml import get_prompt_for_backend, get_custom_prompt from src.utils import ( create_inference_server_from_presets, extract_first_code, @@ -80,6 +79,10 @@ def __init__(self): self.backend = "cuda" self.precision = "fp32" + self.prompt_option = "one_shot" # zero_shot, one_shot, few_shot + self.include_hardware_info = False + self.hardware_gpu_name = None + self.custom_prompt_key = None def greedy(self): # For greedy decoding, epsecially baseline eval @@ -126,16 +129,24 @@ def generate_sample_single( problem_number == work.problem_id ), f"Problem number in filename ({problem_number}) does not match config problem_id ({config.problem_id})" - # Construct Prompt - if config.backend == "cuda": - custom_cuda_prompt = prompt_generate_custom_cuda_from_prompt_template( - ref_arch_src + if config.custom_prompt_key: + custom_prompt = get_custom_prompt( + config.custom_prompt_key, + ref_arch_src=ref_arch_src, + backend=config.backend, + option=config.prompt_option, + precision=config.precision, + include_hardware=config.include_hardware_info, + gpu_name=config.hardware_gpu_name, ) - elif config.backend in ["triton", "cute", "tilelang"]: - custom_cuda_prompt = get_prompt_for_backend(ref_arch_src, config.backend) else: - raise ValueError( - f"Unsupported backend: {config.backend}. Must be 'cuda', 'triton', 'cute', or 'tilelang'." + custom_prompt = get_prompt_for_backend( + ref_arch_src, + config.backend, + option=config.prompt_option, + precision=config.precision, + include_hardware=config.include_hardware_info, + gpu_name=config.hardware_gpu_name, ) if config.log_prompt: prompt_path = os.path.join( @@ -143,13 +154,13 @@ def generate_sample_single( f"level_{config.level}_problem_{work.problem_id}_sample_{work.sample_id}_prompt.txt", ) with open(prompt_path, "w") as f: - f.write(custom_cuda_prompt) + f.write(custom_prompt) # Query server with constructed prompt - custom_cuda = inference_server(custom_cuda_prompt) - custom_cuda = extract_first_code(custom_cuda, ["python", "cpp"]) + custom_kernel = inference_server(custom_prompt) + custom_kernel = extract_first_code(custom_kernel, ["python", "cpp"]) # check LLM is able to generate custom CUDA code - assert custom_cuda is not None, "Custom CUDA code generation failed" + assert custom_kernel is not None, "Custom CUDA code generation failed" if config.verbose: print( @@ -162,7 +173,7 @@ def generate_sample_single( f"level_{config.level}_problem_{work.problem_id}_sample_{work.sample_id}_kernel.py", ) with open(kernel_path, "w") as f: - f.write(custom_cuda) + f.write(custom_kernel) return True @@ -214,6 +225,42 @@ def main(config: GenerationConfig): if isinstance(config.is_reasoning_model, str): config.is_reasoning_model = config.is_reasoning_model.lower() in ['true', '1', 'yes'] + custom_prompt_key = getattr(config, "custom_prompt_key", None) + if isinstance(custom_prompt_key, str): + trimmed = custom_prompt_key.strip() + if trimmed.lower() in {"", "none"}: + custom_prompt_key = None + else: + custom_prompt_key = trimmed + config.custom_prompt_key = custom_prompt_key + + include_hardware = config.include_hardware_info + if isinstance(include_hardware, str): + include_hardware = include_hardware.lower() in ["true", "1", "yes"] + config.include_hardware_info = include_hardware + + supported_backends = {"cuda", "triton", "cute", "tilelang"} + backend = config.backend.lower() + if backend not in supported_backends: + raise ValueError( + f"Unsupported backend: {config.backend}. Must be one of {sorted(supported_backends)}." + ) + config.backend = backend + if backend == "tilelang": + config.precision = "fp16" + + config.prompt_option = str(config.prompt_option).lower() + valid_prompt_options = {"zero_shot", "one_shot", "few_shot"} + if not config.custom_prompt_key: + if config.prompt_option not in valid_prompt_options: + raise ValueError( + f"Invalid prompt_option '{config.prompt_option}'. Must be one of {sorted(valid_prompt_options)}." + ) + if include_hardware and not config.hardware_gpu_name: + raise ValueError( + "include_hardware_info is True but hardware_gpu_name is not provided." + ) + print(f"Starting Batch Generation with config: {config}") # Dataset Configurations diff --git a/scripts/verify_generation.py b/scripts/verify_generation.py index c284d3b5..61c150c7 100644 --- a/scripts/verify_generation.py +++ b/scripts/verify_generation.py @@ -1,8 +1,7 @@ import sys, os import src.utils as utils import time -from src.prompt_constructor import prompt_generate_custom_cuda_from_prompt_template - +from src.prompt_constructor_toml import get_prompt_for_backend, get_custom_prompt """ For testing infernece and quickly iterate on prompts Uses functions in prompt_constructor @@ -25,14 +24,21 @@ def inference_with_prompt(arch_path, inference_server: callable = None, log_to_l with open("./scratch/model.py", "w") as f: f.write(arch) - custom_cuda_prompt = prompt_generate_custom_cuda_from_prompt_template(arch) + custom_backend_prompt = get_prompt_for_backend( + ref_arch_src=arch, + backend="cuda", + option="one_shot", + precision="fp16", + include_hardware=False, + gpu_name="H100" + ) if log_to_local: with open(f"./scratch/prompt.py", "w") as f: - f.write(custom_cuda_prompt) + f.write(custom_backend_prompt) # query LLM - custom_cuda = inference_server(custom_cuda_prompt) + custom_cuda = inference_server(custom_backend_prompt) custom_cuda = utils.extract_first_code(custom_cuda, ["python", "cpp"]) # check LLM is able to generate custom CUDA code @@ -62,13 +68,12 @@ def sanity_check_inference(inference_server: callable): if __name__ == "__main__": - inference_server = utils.create_inference_server_from_presets(server_type="together", - model_name="meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo", + inference_provider_preset = "deepseek" + inference_server = utils.create_inference_server_from_presets(server_type=inference_provider_preset, greedy_sample=True, verbose=True, time_generation=True) - # sanity_check_inference(inference_server) if len(sys.argv) > 1: diff --git a/src/prompt_constructor.py b/src/prompt_constructor.py deleted file mode 100644 index 36cde19f..00000000 --- a/src/prompt_constructor.py +++ /dev/null @@ -1,520 +0,0 @@ -import os -from .utils import read_file - - -""" -Construct Prompt - -Design principles: -- To evaluate base model performance on KernelBench, we use the simplest prompt possible to guide model output to generated desired output format. -- However, we do not do extensive prompt engineering or few-shot example in the LLM to steer behaviour. -""" - -REPO_TOP_PATH = os.path.abspath( - os.path.join( - os.path.dirname(__file__), - "..", - ) -) -KERNEL_BENCH_PATH = os.path.join(REPO_TOP_PATH, "KernelBench") - - -def get_arch_definition_from_file(arch_path): - arch_src = read_file(arch_path) - return get_arch_definition(arch_src) - - -def get_arch_definition(arch_src): - """ - Construct torch definition from original torch nn.Module definition - """ - prompt = f"Here is a pytorch defintion of a neural network architecture in the file model.py: ```{arch_src}```\n" - return prompt - - -############################################ -# CUDA Prompt -############################################ -PROBLEM_STATEMENT = """You write custom CUDA kernels to replace the pytorch operators in the given architecture to get speedups. \n - You have complete freedom to choose the set of operators you want to replace. You may make the decision to replace some operators with custom CUDA kernels and leave others unchanged. You may replace multiple operators with custom implementations, consider operator fusion opportunities (combining multiple operators into a single kernel, for example, combining matmul+relu), or algorithmic changes (such as online softmax). You are only limited by your imagination.\n -""" -PROBLEM_INSTRUCTION = """ -Optimize the architecture named Model with custom CUDA operators! Name your optimized output architecture ModelNew. Output the new code in codeblocks. Please generate real code, NOT pseudocode, make sure the code compiles and is fully functional. Just output the new model code, no other text, and NO testing code! \n -""" - - -def prompt_generate_custom_cuda( - arc_src: str, example_arch_src: str, example_new_arch_src: str -) -> str: - prompt = PROBLEM_STATEMENT - - if example_arch_src != "" and example_new_arch_src != "": - prompt += f""" - Here's an example to show you the syntax of inline embedding custom CUDA operators in torch: The example given architecture is: \n - ``` \n - {example_arch_src} - ``` \n - The example new arch with custom CUDA kernels looks like this: - ``` - {example_new_arch_src} - ``` \n - """ - - prompt += f""" - You are given the following architecture: \n - ``` - {arc_src} - ``` - """ - prompt += PROBLEM_INSTRUCTION - return prompt - - -PROBLEM_STATEMENT_CLEANED = """You write custom CUDA kernels to replace the pytorch operators in the given architecture to get speedups.\n\nYou have complete freedom to choose the set of operators you want to replace. You may make the decision to replace some operators with custom CUDA kernels and leave others unchanged. You may replace multiple operators with custom implementations, consider operator fusion opportunities (combining multiple operators into a single kernel, for example, combining matmul+relu), or algorithmic changes (such as online softmax). You are only limited by your imagination.\n -""" -PROBLEM_INSTRUCTION_CLEANED = """ -Optimize the architecture named Model with custom CUDA operators! Name your optimized output architecture ModelNew. Output the new code in codeblocks. Please generate real code, NOT pseudocode, make sure the code compiles and is fully functional. Just output the new model code, no other text, and NO testing code! \n -""" - -def prompt_generate_custom_cuda_fewshot_and_template(ref_arch_src: str, shots: list) -> str: - """ - Generate a prompt with specified few-shot examples following a template - - shots: list of few-shot examples to include in the prompt - Avaliable few shot options to start with: - - ex_add: pointwise addition - - ex_fuse_gelu: fused gelu - - ex_mnist2: fused convolutions and relus (DEPRECATED) - - ex_tiled_matmul: tiled matrix multiplication - - ex_flash_attn: simple flash attention - """ - prompt = PROBLEM_STATEMENT_CLEANED - - # k = 1 - example_add = read_file( - os.path.join(REPO_TOP_PATH, "src/prompts/few_shot/model_ex_add.py") - ) - example_add_new = read_file( - os.path.join(REPO_TOP_PATH, "src/prompts/few_shot/model_new_ex_add.py") - ) - example_add_desc = "This given architecture is for a pointwise addition: " - - # k = 2 - example_fuse_gelu = read_file( - os.path.join(REPO_TOP_PATH, "src/prompts/few_shot/model_ex_fuse_gelu.py") - ) - example_fuse_gelu_new = read_file( - os.path.join(REPO_TOP_PATH, "src/prompts/few_shot/model_new_ex_fuse_gelu.py") - ) - example_fuse_gelu_desc = "This given architecture is for a fused gelu: " - - # k = 3 (DEPRECATED) - example_mnist2 = read_file( - os.path.join(REPO_TOP_PATH, "src/prompts/few_shot/model_ex_mnist2.py") - ) - example_mnist2_new = read_file( - os.path.join(REPO_TOP_PATH, "src/prompts/few_shot/model_new_ex_mnist2.py") - ) - exmaple_mnist2_desc = "This given architecture is for a model with fused convolutions and relus: " - - # k = 4 - example_tiled_matmul = read_file( - os.path.join(REPO_TOP_PATH, "src/prompts/few_shot/model_ex_tiled_matmul.py") - ) - example_tiled_matmul_new = read_file( - os.path.join(REPO_TOP_PATH, "src/prompts/few_shot/model_new_ex_tiled_matmul.py") - ) - example_tiled_matmul_desc = "This given architecture is for a model with tiled matrix multiplication: " - - # k = 5 - example_flash_attn = read_file( - os.path.join(REPO_TOP_PATH, "src/prompts/few_shot/model_ex_flash_attn.py") - ) - example_flash_attn_new = read_file( - os.path.join(REPO_TOP_PATH, "src/prompts/few_shot/model_new_ex_flash_attn.py") - ) - example_flash_attn_desc = "This given architecture is for a model with simple io-aware implementation of attention, also known as flash attention: " - - examples = [] - for s in shots: - if s not in ["ex_add", "ex_fuse_gelu", "ex_mnist2", "ex_tiled_matmul", "ex_flash_attn"]: - raise ValueError(f"Invalid shot: {s}") - elif s == "ex_add": - examples.append((example_add, example_add_new, example_add_desc)) - elif s == "ex_fuse_gelu": - examples.append((example_fuse_gelu, example_fuse_gelu_new, example_fuse_gelu_desc)) - elif s == "ex_mnist2": # DEPRECATED - raise ValueError("ex_mnist2 is deprecated") - examples.append((example_mnist2, example_mnist2_new, exmaple_mnist2_desc)) - elif s == "ex_tiled_matmul": - examples.append((example_tiled_matmul, example_tiled_matmul_new, example_tiled_matmul_desc)) - elif s == "ex_flash_attn": - examples.append((example_flash_attn, example_flash_attn_new, example_flash_attn_desc)) - - - for i, tup in enumerate(examples): - base, kernel, desc = tup - - prompt += f""" -Example {i+1}:\n\n -Here is an example architecture:\n\n -``` -{base} -```\n -{PROBLEM_INSTRUCTION_CLEANED} \n -Here is an optimized verison with custom CUDA kernels: \n -``` -{kernel} -```\n\n -""" - -# should we put task here? - prompt += f""" -Task:\n\n -Here is an example architecture:\n\n -``` -{ref_arch_src} -```\n -""" - prompt += PROBLEM_INSTRUCTION_CLEANED - return prompt - -def prompt_generate_ex_with_CoT_template(ref_arch_src: str, cot_example: str) -> str: - """ - Generate a prompt with a CoT example following a template - Avaliable CoT examples: - - ex_fuse_gelu: fused gelu - - ex_mnist2: fused convolutions and relus - - ex_tiled_matmul: tiled matrix multiplication - """ - - # I updated this to allow CoT. Also explicilty state think step by step. - PROBLEM_INSTRUCTION_COT = """ -Optimize the architecture named Model with custom CUDA operators! Name your optimized output architecture ModelNew. Output the new code in codeblocks. Please generate real code, NOT pseudocode, make sure the code compiles and is fully functional. Do not output testing code. -In the end, make sure the final code block contains code for output architecture ModelNew with cuda code.\n -Let's think step by step.\n -""" - - prompt = PROBLEM_STATEMENT_CLEANED - - assert cot_example in ["ex_fuse_gelu", "ex_mnist2", "ex_tiled_matmul"] - - # k = 2 - example_fuse_gelu = read_file( - os.path.join(REPO_TOP_PATH, "src/prompts/few_shot/model_ex_fuse_gelu.py") - ) - example_fuse_gelu_cot = read_file( - os.path.join(REPO_TOP_PATH, "src/prompts/cot/model_cot_fuse_gelu.py") - ) - example_fuse_gelu_new = read_file( - os.path.join(REPO_TOP_PATH, "src/prompts/few_shot/model_new_ex_fuse_gelu.py") - ) - example_fuse_gelu_desc = "This given architecture is for a fused gelu: " - - # k = 3 - example_mnist2 = read_file( - os.path.join(REPO_TOP_PATH, "src/prompts/few_shot/model_ex_mnist2.py") - ) - example_mnist2_cot = read_file( - os.path.join(REPO_TOP_PATH, "src/prompts/cot/model_cot_mnist2.py") - ) - example_mnist2_new = read_file( - os.path.join(REPO_TOP_PATH, "src/prompts/few_shot/model_new_ex_mnist2.py") - ) - exmaple_mnist2_desc = "This given architecture is for a model with fused convolutions and relus: " - - # k = 4 - example_tiled_matmul = read_file( - os.path.join(REPO_TOP_PATH, "src/prompts/few_shot/model_ex_tiled_matmul.py") - ) - example_tiled_matmul_cot = read_file( - os.path.join(REPO_TOP_PATH, "src/prompts/cot/model_cot_tiled_matmul.py") - ) - example_tiled_matmul_new = read_file( - os.path.join(REPO_TOP_PATH, "src/prompts/few_shot/model_new_ex_tiled_matmul.py") - ) - example_tiled_matmul_desc = "This given architecture is for a model with tiled matrix multiplication: " - - match cot_example: - case "ex_fuse_gelu": - base = example_fuse_gelu - cot = example_fuse_gelu_cot - kernel = example_fuse_gelu_new - desc = example_fuse_gelu_desc - case "ex_mnist2": - base = example_mnist2 - cot = example_mnist2_cot - kernel = example_mnist2_new - desc = exmaple_mnist2_desc - case "ex_tiled_matmul": - base = example_tiled_matmul - cot = example_tiled_matmul_cot - kernel = example_tiled_matmul_new - desc = example_tiled_matmul_desc - case _: - raise ValueError(f"Invalid CoT example: {cot_example} not found in CoT examples") - - # construct example with - # NOTE: we only do one example with CoT for now - # 1. ref_src problem -> 2. Instruction -> 3. CoT -> 4. Solution - prompt += f""" -Here is an example architecture:\n\n -``` -{base} -```\n -{PROBLEM_INSTRUCTION_COT} \n -{cot} \n -``` -{kernel} -```\n\n -""" - -# show task to solve - prompt += f""" -Task:\n\n -Here is an example architecture:\n\n -``` -{ref_arch_src} -```\n -""" - prompt += PROBLEM_INSTRUCTION_COT - - return prompt - - - -def prompt_generate_custom_cuda_from_file_one_example(ref_arch_src, example_ind=1): - """ - Deprecated: use prompt_generate_custom_cuda_from_prompt_template instead - Keep this around for background compatibility - NOTE: Anne to clean this up - Check example_ind for prompt templates - """ - # arch = get_arch_definition_from_file(arch_path) - arch = ref_arch_src - # These are strictly defined for now - - example_arch_path = os.path.join( - REPO_TOP_PATH, f"src/prompts/model_ex_{example_ind}.py" - ) - example_new_arch_path = os.path.join( - REPO_TOP_PATH, f"src/prompts/model_new_ex_{example_ind}.py" - ) - - if not os.path.exists(example_arch_path): - raise FileNotFoundError( - f"Example architecture file not found: {example_arch_path}" - ) - if not os.path.exists(example_new_arch_path): - raise FileNotFoundError( - f"Example new architecture file not found: {example_new_arch_path}" - ) - - example_arch = read_file(example_arch_path) - example_new_arch = read_file(example_new_arch_path) - - return prompt_generate_custom_cuda(arch, example_arch, example_new_arch) - - -def prompt_generate_custom_cuda_from_prompt_template(ref_arch_src: str) -> str: - """ - Using prompt example (an element-wise addition) for prompt templates - The most basic form of example just to show LLM the task and the expected output format - """ - arch = ref_arch_src - # These are strictly defined for now - - # path to prompt template, show an example of Model (torch specifications) and ModelNew (torch + custom CUDA kernels) - example_arch_path = os.path.join( - REPO_TOP_PATH, f"src/prompts/model_ex_add.py" - ) - example_new_arch_path = os.path.join( - REPO_TOP_PATH, f"src/prompts/model_new_ex_add.py" - ) - - if not os.path.exists(example_arch_path): - raise FileNotFoundError( - f"Example architecture file not found: {example_arch_path}" - ) - if not os.path.exists(example_new_arch_path): - raise FileNotFoundError( - f"Example new architecture file not found: {example_new_arch_path}" - ) - - example_arch = read_file(example_arch_path) - example_new_arch = read_file(example_new_arch_path) - - return prompt_generate_custom_cuda(arch, example_arch, example_new_arch) - - -def prompt_generate_prompt_with_hardware_info_from_template(ref_arch_src: str, gpu_name: str) -> str: - """ - Similar to prompt_generate_custom_cuda_from_prompt_template, - but with hardware information for the given GPU - """ - - arch = ref_arch_src - # These are strictly defined for now - - # path to prompt template, show an example of Model (torch specifications) and ModelNew (torch + custom CUDA kernels) - example_arch_path = os.path.join( - REPO_TOP_PATH, f"src/prompts/model_ex_add.py" - ) - example_new_arch_path = os.path.join( - REPO_TOP_PATH, f"src/prompts/model_new_ex_add.py" - ) - - gpu_spec_file_path = os.path.join(REPO_TOP_PATH, f"src/prompts/hardware/gpu_specs.py") - - example_arch = read_file(example_arch_path) - example_new_arch = read_file(example_new_arch_path) - gpu_spec_info = read_file(gpu_spec_file_path) - - return prompt_generate_prompt_with_hardware_info( - ref_arch_src=arch, - gpu_name=gpu_name, - example_arch_src=example_arch, - example_new_arch_src=example_new_arch, - gpu_spec_info_src=gpu_spec_info - ) - - - -def prompt_generate_prompt_with_hardware_info(ref_arch_src: str, - gpu_name: str, - example_arch_src: str, - example_new_arch_src: str, - gpu_spec_info_src: str) -> str: - """ - Generate a prompt with hardware information for the given GPU - gpu_spec_info_src: str of the gpu spec src file - """ - - # Create a dictionary to store the local namespace - local_dict = {} - - # Execute the GPU spec file in the local namespace - exec(gpu_spec_info_src, {}, local_dict) - - # Get the required variables from the local namespace - GPU_SPEC_INFO = local_dict.get('GPU_SPEC_INFO') - GPU_DEFINITIONS = local_dict.get('GPU_DEFINITIONS') - GPU_BEST_PRACTICES = local_dict.get('GPU_BEST_PRACTICES') - - if not GPU_SPEC_INFO or not GPU_DEFINITIONS or not GPU_BEST_PRACTICES: - raise ValueError("GPU_SPEC_INFO or GPU_DEFINITIONS or GPU_BEST_PRACTICES not found in gpu_spec_info_src") - - assert gpu_name in GPU_SPEC_INFO, f"GPU name {gpu_name} not found in GPU_SPEC_INFO" - - prompt = PROBLEM_STATEMENT - - if example_arch_src != "" and example_new_arch_src != "": - prompt += f""" - Here's an example to show you the syntax of inline embedding custom CUDA operators in torch: The example given architecture is: \n - ``` \n - {example_arch_src} - ``` \n - The example new arch with custom CUDA kernels looks like this: - ``` - {example_new_arch_src} - ``` \n - """ - - curr_gpu_spec_info = GPU_SPEC_INFO[gpu_name] - - gpu_architecture = curr_gpu_spec_info.get("GPU Architecture") - prompt += f""" - Here is some information about the underlying hardware that you should keep in mind. \n\n -The GPU that will run the kernel is NVIDIA {gpu_name}, {gpu_architecture} architecture.\n\n""" - - for key, value in curr_gpu_spec_info.items(): - if key == "GPU Architecture": - continue - prompt += f"""- We have {value} of {key}.\n""" - - - prompt += f"""\n\n -Here are some concepts about the GPU architecture that could be helpful: \n\n""" - for key, value in GPU_DEFINITIONS.items(): - prompt += f"""- {key}: {value}\n""" - - prompt += f"""\n\n -Here are some best practices for writing CUDA kernels on GPU: \n\n""" - for best_practice in GPU_BEST_PRACTICES: - prompt += f"""- {best_practice}\n""" - - - prompt += f""" - You are given the following architecture: \n - ``` - {ref_arch_src} - ``` - """ - - - prompt += PROBLEM_INSTRUCTION - return prompt - - - return Nonoe - - - - - -def prompt_fix_compile(ref_arch_src, custom_cuda, metadata): - prompt = PROBLEM_STATEMENT - prompt += f""" - With the following architecture: - ``` - {ref_arch_src} - ``` - You generated the following solution and it failed to compile: - ``` - {custom_cuda} - ``` - Here's the metadata of the compilation error: - ``` - {metadata} - ``` - - Please fix the compilation error in the new model code. Please output the corrected code in codeblocks. - """ - return prompt - - -def prompt_fix_correctness(ref_arch_src, custom_cuda, metadata): - prompt = PROBLEM_STATEMENT - prompt += f""" - With the following architecture: - ``` - {ref_arch_src} - ``` - You generated the following solution and it failed correctness: - ``` - {custom_cuda} - ``` - Here's the metadata of the correctness error: - ``` - {metadata} - ``` - Please consider how your custom CUDA kernels are implemented, how it is different from the reference implementation, and fix the correctness error in the new model code. Please output the corrected code in codeblocks. - """ - return prompt - -def main(): - gpu_name = "L40S" - - - ref_arch_src = read_file(os.path.join(KERNEL_BENCH_PATH, f"level1/19_ReLU.py")) - assert len(ref_arch_src) > 0, "ref_arch_src is empty" - prompt = prompt_generate_prompt_with_hardware_info_from_template(ref_arch_src, gpu_name) - print(prompt) - # Write prompt to temp file - temp_file_path = os.path.join(REPO_TOP_PATH, "scratch", "prompt_draft.txt") - os.makedirs(os.path.dirname(temp_file_path), exist_ok=True) - with open(temp_file_path, "w") as f: - f.write(prompt) - -if __name__ == "__main__": - main() diff --git a/src/prompt_constructor_multilang.py b/src/prompt_constructor_multilang.py deleted file mode 100644 index 8a520d10..00000000 --- a/src/prompt_constructor_multilang.py +++ /dev/null @@ -1,551 +0,0 @@ -import os -from .utils import read_file - -""" -Multi-Language Prompt Constructor - -Supports: Triton, CuTe (TileLang currently disabled/commented out) - -Design principles: -- To evaluate base model performance on KernelBench, we use the simplest prompt possible to guide model output to generated desired output format. -- However, we do not do extensive prompt engineering or few-shot examples in the LLM to steer behaviour. -""" - -REPO_TOP_PATH = os.path.abspath( - os.path.join( - os.path.dirname(__file__), - "..", - ) -) -KERNEL_BENCH_PATH = os.path.join(REPO_TOP_PATH, "KernelBench") - - -def get_arch_definition_from_file(arch_path): - arch_src = read_file(arch_path) - return get_arch_definition(arch_src) - - -def get_arch_definition(arch_src): - """ - Construct torch definition from original torch nn.Module definition - """ - prompt = f"Here is a pytorch defintion of a neural network architecture in the file model.py: ```{arch_src}```\n" - return prompt - - -################################################################################ -# Triton Backend -################################################################################ - -TRITON_PROBLEM_STATEMENT = """You write custom Triton kernels to replace the pytorch operators in the given architecture to get speedups. \n - You have complete freedom to choose the set of operators you want to replace. You may make the decision to replace some operators with custom Triton kernels and leave others unchanged. You may replace multiple operators with custom implementations, consider operator fusion opportunities (combining multiple operators into a single kernel, for example, combining matmul+relu), or algorithmic changes (such as online softmax). You are only limited by your imagination.\n -""" - -TRITON_PROBLEM_INSTRUCTION = """ -Optimize the architecture named Model with custom Triton kernels! Name your optimized output architecture ModelNew. Output the new code in codeblocks. Please generate real code, NOT pseudocode, make sure the code compiles and is fully functional. Just output the new model code, no other text, and NO testing code! \n -""" - -TRITON_PROBLEM_STATEMENT_CLEANED = """You write custom Triton kernels to replace the pytorch operators in the given architecture to get speedups.\n\nYou have complete freedom to choose the set of operators you want to replace. You may make the decision to replace some operators with custom Triton kernels and leave others unchanged. You may replace multiple operators with custom implementations, consider operator fusion opportunities (combining multiple operators into a single kernel, for example, combining matmul+relu), or algorithmic changes (such as online softmax). You are only limited by your imagination.\n -""" - -TRITON_PROBLEM_INSTRUCTION_CLEANED = """ -Optimize the architecture named Model with custom Triton kernels! Name your optimized output architecture ModelNew. Output the new code in codeblocks. Please generate real code, NOT pseudocode, make sure the code compiles and is fully functional. Just output the new model code, no other text, and NO testing code! \n -""" - - -def prompt_generate_custom_triton( - arc_src: str, example_arch_src: str, example_new_arch_src: str -) -> str: - prompt = TRITON_PROBLEM_STATEMENT - - assert ( - "@triton.jit" in example_new_arch_src - ), "Example new arch must contain Triton kernel" - - if example_arch_src != "" and example_new_arch_src != "": - prompt += f""" - Here's an example to show you the syntax of inline embedding custom Triton kernels in torch: The example given architecture is: \n - ``` \n - {example_arch_src} - ``` \n - The example new arch with custom Triton kernels looks like this: \n - ``` - {example_new_arch_src} - ``` \n - """ - - prompt += f""" - You are given the following architecture: \n - ``` - {arc_src} - ``` - """ - prompt += TRITON_PROBLEM_INSTRUCTION - return prompt - - -def prompt_generate_custom_triton_fewshot_and_template( - ref_arch_src: str, shots: list -) -> str: - raise NotImplementedError("This function has not been implemented yet") - - -def prompt_generate_ex_with_CoT_template_triton(ref_arch_src: str, cot_example: str) -> str: - raise NotImplementedError("This function has not been implemented yet") - - -def prompt_generate_custom_triton_from_prompt_template(ref_arch_src: str) -> str: - """ - Using prompt example (an element-wise addition) for prompt templates - The most basic form of example just to show LLM the task and the expected output format - """ - arch = ref_arch_src - - # path to prompt template, show an example of Model (torch specifications) and ModelNew (torch + custom Triton kernels) - example_arch_path = os.path.join(REPO_TOP_PATH, f"src/prompts/model_ex_add.py") - example_new_arch_path = os.path.join( - REPO_TOP_PATH, f"src/prompts/model_new_ex_add_triton.py" - ) - - if not os.path.exists(example_arch_path): - raise FileNotFoundError( - f"Example architecture file not found: {example_arch_path}" - ) - if not os.path.exists(example_new_arch_path): - raise FileNotFoundError( - f"Example new architecture file not found: {example_new_arch_path}" - ) - - example_arch = read_file(example_arch_path) - example_new_arch = read_file(example_new_arch_path) - - return prompt_generate_custom_triton(arch, example_arch, example_new_arch) - - -def prompt_generate_prompt_with_hardware_info_from_template_triton( - ref_arch_src: str, gpu_name: str -) -> str: - """ - Similar to prompt_generate_custom_triton_from_prompt_template, - but with hardware information for the given GPU - """ - arch = ref_arch_src - - example_arch_path = os.path.join(REPO_TOP_PATH, f"src/prompts/model_ex_add.py") - example_new_arch_path = os.path.join( - REPO_TOP_PATH, f"src/prompts/model_new_ex_add_triton.py" - ) - gpu_spec_file_path = os.path.join( - REPO_TOP_PATH, f"src/prompts/hardware/gpu_specs.py" - ) - - example_arch = read_file(example_arch_path) - example_new_arch = read_file(example_new_arch_path) - gpu_spec_info = read_file(gpu_spec_file_path) - - return prompt_generate_prompt_with_hardware_info_triton( - ref_arch_src=arch, - gpu_name=gpu_name, - example_arch_src=example_arch, - example_new_arch_src=example_new_arch, - gpu_spec_info_src=gpu_spec_info, - ) - - -def prompt_generate_prompt_with_hardware_info_triton( - ref_arch_src: str, - gpu_name: str, - example_arch_src: str, - example_new_arch_src: str, - gpu_spec_info_src: str, -) -> str: - """ - Generate a prompt with hardware information for the given GPU - gpu_spec_info_src: str of the gpu spec src file - """ - local_dict = {} - exec(gpu_spec_info_src, {}, local_dict) - - GPU_SPEC_INFO = local_dict.get("GPU_SPEC_INFO") - GPU_DEFINITIONS = local_dict.get("GPU_DEFINITIONS") - GPU_BEST_PRACTICES = local_dict.get("GPU_BEST_PRACTICES") - - if not GPU_SPEC_INFO or not GPU_DEFINITIONS or not GPU_BEST_PRACTICES: - raise ValueError( - "GPU_SPEC_INFO or GPU_DEFINITIONS or GPU_BEST_PRACTICES not found in gpu_spec_info_src" - ) - - assert gpu_name in GPU_SPEC_INFO, f"GPU name {gpu_name} not found in GPU_SPEC_INFO" - - prompt = TRITON_PROBLEM_STATEMENT - - if example_arch_src != "" and example_new_arch_src != "": - prompt += f""" - Here's an example to show you the syntax of inline embedding custom Triton kernels in torch: The example given architecture is: \n - ``` \n - {example_arch_src} - ``` \n - The example new arch with custom Triton kernels looks like this: - ``` - {example_new_arch_src} - ``` \n - """ - - curr_gpu_spec_info = GPU_SPEC_INFO[gpu_name] - gpu_architecture = curr_gpu_spec_info.get("GPU Architecture") - prompt += f""" - Here is some information about the underlying hardware that you should keep in mind. \n\n -The GPU that will run the kernel is NVIDIA {gpu_name}, {gpu_architecture} architecture.\n\n""" - - for key, value in curr_gpu_spec_info.items(): - if key == "GPU Architecture": - continue - prompt += f"""- We have {value} of {key}.\n""" - - prompt += f"""\n\n -Here are some concepts about the GPU architecture that could be helpful: \n\n""" - for key, value in GPU_DEFINITIONS.items(): - prompt += f"""- {key}: {value}\n""" - - prompt += f"""\n\n -Here are some best practices for writing Triton kernels on GPU: \n\n""" - for best_practice in GPU_BEST_PRACTICES: - prompt += f"""- {best_practice}\n""" - - prompt += f""" - You are given the following architecture: \n - ``` - {ref_arch_src} - ``` - """ - - prompt += TRITON_PROBLEM_INSTRUCTION - return prompt - - -def prompt_fix_compile_triton(ref_arch_src, custom_kernel, metadata): - prompt = TRITON_PROBLEM_STATEMENT - prompt += f""" - With the following architecture: - ``` - {ref_arch_src} - ``` - You generated the following solution and it failed to compile: - ``` - {custom_kernel} - ``` - Here's the metadata of the compilation error: - ``` - {metadata} - ``` - - Please fix the compilation error in the new model code. Please output the corrected code in codeblocks. - """ - return prompt - - -def prompt_fix_correctness_triton(ref_arch_src, custom_kernel, metadata): - prompt = TRITON_PROBLEM_STATEMENT - prompt += f""" - With the following architecture: - ``` - {ref_arch_src} - ``` - You generated the following solution and it failed correctness: - ``` - {custom_kernel} - ``` - Here's the metadata of the correctness error: - ``` - {metadata} - ``` - Please consider how your custom Triton kernels are implemented, how it is different from the reference implementation, and fix the correctness error in the new model code. Please output the corrected code in codeblocks. - """ - return prompt - - -################################################################################ -# TileLang Backend -################################################################################ - -TILELANG_PROBLEM_STATEMENT = """You write custom TileLang kernels to replace the pytorch operators in the given architecture to get speedups. \n - You have complete freedom to choose the set of operators you want to replace. You may make the decision to replace some operators with custom TileLang kernels and leave others unchanged. You may replace multiple operators with custom implementations, consider operator fusion opportunities (combining multiple operators into a single kernel, for example, combining matmul+relu), or algorithmic changes (such as online softmax). You are only limited by your imagination.\n -""" - -TILELANG_PROBLEM_INSTRUCTION = """ -Optimize the architecture named Model with custom TileLang kernels! Name your optimized output architecture ModelNew. Output the new code in codeblocks. Please generate real code, NOT pseudocode, make sure the code compiles and is fully functional. Just output the new model code, no other text, and NO testing code! \n -""" - -TILELANG_PROBLEM_STATEMENT_CLEANED = """You write custom TileLang kernels to replace the pytorch operators in the given architecture to get speedups.\n\nYou have complete freedom to choose the set of operators you want to replace. You may make the decision to replace some operators with custom TileLang kernels and leave others unchanged. You may replace multiple operators with custom implementations, consider operator fusion opportunities (combining multiple operators into a single kernel, for example, combining matmul+relu), or algorithmic changes (such as online softmax). You are only limited by your imagination.\n -""" - -TILELANG_PROBLEM_INSTRUCTION_CLEANED = """ -Optimize the architecture named Model with custom TileLang kernels! Name your optimized output architecture ModelNew. Output the new code in codeblocks. Please generate real code, NOT pseudocode, make sure the code compiles and is fully functional. Just output the new model code, no other text, and NO testing code! \n -""" - - -def prompt_generate_custom_tilelang( - arc_src: str, example_arch_src: str, example_new_arch_src: str -) -> str: - prompt = TILELANG_PROBLEM_STATEMENT - - if example_arch_src != "" and example_new_arch_src != "": - prompt += f""" - Here's an example to show you the syntax of inline embedding custom TileLang kernels in torch: The example given architecture is: \n - ``` \n - {example_arch_src} - ``` \n - The example new arch with custom TileLang kernels looks like this: \n - ``` - {example_new_arch_src} - ``` \n - """ - - prompt += f""" - You are given the following architecture: \n - ``` - {arc_src} - ``` - """ - prompt += TILELANG_PROBLEM_INSTRUCTION - return prompt - - -def prompt_generate_custom_tilelang_from_prompt_template(ref_arch_src: str) -> str: - """ - Using prompt example for TileLang - """ - arch = ref_arch_src - - example_arch_path = os.path.join(REPO_TOP_PATH, f"src/prompts/model_ex_add.py") - example_new_arch_path = os.path.join( - REPO_TOP_PATH, f"src/prompts/model_new_ex_add_tilelang.py" - ) - - if not os.path.exists(example_arch_path): - raise FileNotFoundError( - f"Example architecture file not found: {example_arch_path}" - ) - if not os.path.exists(example_new_arch_path): - # For now, use a basic template without examples if file doesn't exist - return prompt_generate_custom_tilelang(arch, "", "") - - example_arch = read_file(example_arch_path) - example_new_arch = read_file(example_new_arch_path) - - return prompt_generate_custom_tilelang(arch, example_arch, example_new_arch) - - -def prompt_fix_compile_tilelang(ref_arch_src, custom_kernel, metadata): - prompt = TILELANG_PROBLEM_STATEMENT - prompt += f""" - With the following architecture: - ``` - {ref_arch_src} - ``` - You generated the following solution and it failed to compile: - ``` - {custom_kernel} - ``` - Here's the metadata of the compilation error: - ``` - {metadata} - ``` - - Please fix the compilation error in the new model code. Please output the corrected code in codeblocks. - """ - return prompt - - -def prompt_fix_correctness_tilelang(ref_arch_src, custom_kernel, metadata): - prompt = TILELANG_PROBLEM_STATEMENT - prompt += f""" - With the following architecture: - ``` - {ref_arch_src} - ``` - You generated the following solution and it failed correctness: - ``` - {custom_kernel} - ``` - Here's the metadata of the correctness error: - ``` - {metadata} - ``` - Please consider how your custom TileLang kernels are implemented, how it is different from the reference implementation, and fix the correctness error in the new model code. Please output the corrected code in codeblocks. - """ - return prompt - - -################################################################################ -# CuTe Backend -################################################################################ - -CUTE_PROBLEM_STATEMENT = """You write custom CuTe (CUTLASS) kernels to replace the pytorch operators in the given architecture to get speedups. \n - You have complete freedom to choose the set of operators you want to replace. You may make the decision to replace some operators with custom CuTe kernels and leave others unchanged. You may replace multiple operators with custom implementations, consider operator fusion opportunities (combining multiple operators into a single kernel, for example, combining matmul+relu), or algorithmic changes (such as online softmax). You are only limited by your imagination.\n -""" - -CUTE_PROBLEM_INSTRUCTION = """ -Optimize the architecture named Model with custom CuTe (CUTLASS) kernels! Name your optimized output architecture ModelNew. Output the new code in codeblocks. Please generate real code, NOT pseudocode, make sure the code compiles and is fully functional. Just output the new model code, no other text, and NO testing code! \n -""" - -CUTE_PROBLEM_STATEMENT_CLEANED = """You write custom CuTe (CUTLASS) kernels to replace the pytorch operators in the given architecture to get speedups.\n\nYou have complete freedom to choose the set of operators you want to replace. You may make the decision to replace some operators with custom CuTe kernels and leave others unchanged. You may replace multiple operators with custom implementations, consider operator fusion opportunities (combining multiple operators into a single kernel, for example, combining matmul+relu), or algorithmic changes (such as online softmax). You are only limited by your imagination.\n -""" - -CUTE_PROBLEM_INSTRUCTION_CLEANED = """ -Optimize the architecture named Model with custom CuTe (CUTLASS) kernels! Name your optimized output architecture ModelNew. Output the new code in codeblocks. Please generate real code, NOT pseudocode, make sure the code compiles and is fully functional. Just output the new model code, no other text, and NO testing code! \n -""" - - -def prompt_generate_custom_cute( - arc_src: str, example_arch_src: str, example_new_arch_src: str -) -> str: - prompt = CUTE_PROBLEM_STATEMENT - - if example_arch_src != "" and example_new_arch_src != "": - prompt += f""" - Here's an example to show you the syntax of inline embedding custom CuTe (CUTLASS) kernels in torch: The example given architecture is: \n - ``` \n - {example_arch_src} - ``` \n - The example new arch with custom CuTe kernels looks like this: \n - ``` - {example_new_arch_src} - ``` \n - """ - - prompt += f""" - You are given the following architecture: \n - ``` - {arc_src} - ``` - """ - prompt += CUTE_PROBLEM_INSTRUCTION - return prompt - - -def prompt_generate_custom_cute_from_prompt_template(ref_arch_src: str) -> str: - """ - Using prompt example for CuTe - Note: You'll need to create a CuTe example file - """ - arch = ref_arch_src - - # TODO: Create model_new_ex_add_cute.py example file - example_arch_path = os.path.join(REPO_TOP_PATH, f"src/prompts/model_ex_add.py") - example_new_arch_path = os.path.join( - REPO_TOP_PATH, f"src/prompts/model_new_ex_add_cute.py" - ) - - if not os.path.exists(example_arch_path): - raise FileNotFoundError( - f"Example architecture file not found: {example_arch_path}" - ) - if not os.path.exists(example_new_arch_path): - # For now, use a basic template without examples if file doesn't exist - return prompt_generate_custom_cute(arch, "", "") - - example_arch = read_file(example_arch_path) - example_new_arch = read_file(example_new_arch_path) - - return prompt_generate_custom_cute(arch, example_arch, example_new_arch) - - -def prompt_fix_compile_cute(ref_arch_src, custom_kernel, metadata): - prompt = CUTE_PROBLEM_STATEMENT - prompt += f""" - With the following architecture: - ``` - {ref_arch_src} - ``` - You generated the following solution and it failed to compile: - ``` - {custom_kernel} - ``` - Here's the metadata of the compilation error: - ``` - {metadata} - ``` - - Please fix the compilation error in the new model code. Please output the corrected code in codeblocks. - """ - return prompt - - -def prompt_fix_correctness_cute(ref_arch_src, custom_kernel, metadata): - prompt = CUTE_PROBLEM_STATEMENT - prompt += f""" - With the following architecture: - ``` - {ref_arch_src} - ``` - You generated the following solution and it failed correctness: - ``` - {custom_kernel} - ``` - Here's the metadata of the correctness error: - ``` - {metadata} - ``` - Please consider how your custom CuTe kernels are implemented, how it is different from the reference implementation, and fix the correctness error in the new model code. Please output the corrected code in codeblocks. - """ - return prompt - - -################################################################################ -# Unified API -################################################################################ - -def get_prompt_for_backend(ref_arch_src: str, backend: str = "triton") -> str: - """ - Unified API to get prompt for any supported backend - - Args: - ref_arch_src: Reference architecture source code - backend: One of 'triton', 'tilelang', 'cute' - - Returns: - Prompt string for the specified backend - """ - backend_lower = backend.lower() - - if backend_lower == "triton": - return prompt_generate_custom_triton_from_prompt_template(ref_arch_src) - elif backend_lower == "tilelang": - return prompt_generate_custom_tilelang_from_prompt_template(ref_arch_src) - elif backend_lower == "cute": - return prompt_generate_custom_cute_from_prompt_template(ref_arch_src) - else: - raise ValueError( - f"Unsupported backend: {backend}. Must be one of: 'triton', 'tilelang', 'cute'" - ) - - -################################################################################ -# Main (for testing) -################################################################################ - -def main(): - gpu_name = "L40S" - backend = "triton" # Change this to test different backends - - ref_arch_src = read_file(os.path.join(KERNEL_BENCH_PATH, f"level1/19_ReLU.py")) - assert len(ref_arch_src) > 0, "ref_arch_src is empty" - - prompt = get_prompt_for_backend(ref_arch_src, backend) - print(f"\n{'='*80}\n{backend.upper()} PROMPT:\n{'='*80}\n") - print(prompt) - - # Write prompt to temp file - temp_file_path = os.path.join(REPO_TOP_PATH, "scratch", f"prompt_{backend}_draft.txt") - os.makedirs(os.path.dirname(temp_file_path), exist_ok=True) - with open(temp_file_path, "w") as f: - f.write(prompt) - print(f"\nPrompt written to: {temp_file_path}") - - -if __name__ == "__main__": - main() - - - diff --git a/src/prompt_constructor_toml.py b/src/prompt_constructor_toml.py new file mode 100644 index 00000000..fc074494 --- /dev/null +++ b/src/prompt_constructor_toml.py @@ -0,0 +1,478 @@ +# src/prompt_constructor_toml.py | toml based prompt constructor +import os +import runpy +import tomli +from dataclasses import dataclass +from typing import Any, Dict, List, Optional + +from src.utils import read_file + +""" +TOML-based prompt constructor for managing prompt templates and configurations. +This module provides a way to load and compose prompt templates from a TOML configuration file. + +You can easily check some of the prompt templates we have provided and create your own. +""" + +REPO_TOP_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +PROMPTS_TOML = os.path.join(REPO_TOP_PATH, "src/prompts/prompts.toml") + +assert os.path.exists(PROMPTS_TOML), f"Prompts.toml not found at {PROMPTS_TOML}" +GPU_SPECS_PY = "src/prompts/hardware/gpu_specs.py" +HARDWARE_COMPONENT_KEYS = [ + "hardware_header", + "hardware_specs", + "hardware_definitions", + "hardware_best_practices", +] + +def _abs_path(rel: str) -> str: + """Convert relative path to absolute path from repo root.""" + if os.path.isabs(rel): + return rel + return os.path.join(REPO_TOP_PATH, rel) + +@dataclass +class PromptConfig: + """ + Configuration wrapper for prompts.toml data. + + This class holds the parsed TOML file data and provides ways to navigate + the nested structure and compose prompt templates. + + The TOML file has a structure like: + [backends.cuda] + [options.few_shot] + [templates.common.arch_block] + + This class makes it easy to look up values in that hierarchy. + """ + data: Dict[str, Any] # The raw parsed TOML data as nested dictionaries + + @classmethod + def from_toml(cls, path: str) -> "PromptConfig": + """ + Load and parse a TOML configuration file. + + Args: + path: Filesystem path to the prompts.toml file + + Returns: + PromptConfig instance with parsed data + """ + with open(path, "rb") as f: + data = tomli.load(f) + return cls(data) + + def compose_blocks(self, keys: List[str]) -> str: + """ + Look up and concatenate multiple template blocks using dotted key paths. + + This method navigates the nested TOML structure using dotted notation + (e.g., "templates.common.arch_block") to find template strings, then + concatenates them together with newlines. + + Args: + keys: List of dotted key paths (e.g., ["templates.common.arch_block"]) + Each key is split on "." and used to traverse the nested dict. + + Returns: + Concatenated string of all template blocks, each separated by newlines + """ + text_parts = [] + for key in keys: + # Navigate through the nested dictionary structure + node: Any = self.data + for part in key.split("."): + if part not in node: + raise KeyError(f"compose key not found: {key}") + node = node[part] + + # Ensure we found a string template, not another dict/list + if not isinstance(node, str): + raise TypeError(f"compose key must resolve to string: {key}") + + text_parts.append(node.strip() + "\n") + + return "\n".join(text_parts).strip() + "\n" + +def _gpu_context_from_gpu_specs(py_path: str, gpu_name: str) -> Dict[str, str]: + """ + Load GPU_* dicts from the GPU specs file (no exec of raw strings; use runpy). + Expected globals: + - GPU_SPEC_INFO: dict[str, dict] + - GPU_DEFINITIONS: dict[str, str] + - GPU_BEST_PRACTICES: list[str] OR {"list": [...]} for compatibility + """ + mod = runpy.run_path(py_path) + spec_info = mod.get("GPU_SPEC_INFO", {}) + definitions = mod.get("GPU_DEFINITIONS", {}) + best = mod.get("GPU_BEST_PRACTICES", []) + + if not spec_info or not definitions or best is None: + raise ValueError("GPU_SPEC_INFO / GPU_DEFINITIONS / GPU_BEST_PRACTICES missing in gpu specs .py") + + if isinstance(best, dict) and "list" in best: + best = best["list"] + + if gpu_name not in spec_info: + raise KeyError(f"GPU name {gpu_name} not found in GPU_SPEC_INFO") + + curr = spec_info[gpu_name] + gpu_architecture = curr.get("GPU Architecture", "Unknown") + specs_bullets = "\n".join([f"- We have {v} of {k}." for k, v in curr.items() if k != "GPU Architecture"]) + defs_bullets = "\n".join([f"- {k}: {v}" for k, v in definitions.items()]) + best_bullets = "\n".join([f"- {x}" for x in (best or [])]) + + return { + "gpu_name": gpu_name, + "gpu_architecture": gpu_architecture, + "gpu_specs_bullets": specs_bullets, + "gpu_definitions_bullets": defs_bullets, + "gpu_best_practices_bullets": best_bullets, + } + +def render_prompt_by_option( + *, + prompts_toml: str, + backend: str, + option: str, + context: Dict[str, str], + gpu_specs_py: Optional[str] = None, + gpu_name: Optional[str] = None, + precision: Optional[str] = None, + include_hardware: bool = False, + components_override: Optional[List[str]] = None, +) -> str: + """ + Render a prompt using backends.X and options.Y structure from TOML. + + Args: + prompts_toml: Path to the prompts.toml file + backend: The kernel backend (triton, cuda, cute, tilelang) + option: The prompt option (zero_shot, one_shot, few_shot) + - zero_shot: No examples (model learns from description only) + - one_shot: Single example + - few_shot: Multiple examples if available for backend, otherwise falls back to one_shot + context: Variables to fill in the prompt template + gpu_specs_py: Optional path to GPU specs Python file (required if hardware info is included) + gpu_name: Optional GPU name (required if hardware info is included) + precision: Optional precision string (fp32, fp16, bf16) - defaults to fp32 if not provided + include_hardware: Whether to inject hardware guidance blocks after the examples section + components_override: When provided, users can arrange prompt components from the toml + file in any order they want. + Components must exist under templates.common or be hardware_* entries. + + Returns: + The rendered prompt string + """ + cfg = PromptConfig.from_toml(prompts_toml) + + # Get backend-specific content + try: + backend_data = cfg.data["backends"][backend] + except KeyError: + raise KeyError(f"Unknown backend: {backend}") + + # Get option configuration + try: + option_data = cfg.data["options"][option] + except KeyError: + raise KeyError(f"Unknown option: {option}") + + component_sequence = list(components_override or option_data["components"]) + if include_hardware: + if components_override is None: + insert_idx = component_sequence.index("arch_block") if "arch_block" in component_sequence else len(component_sequence) + component_sequence[insert_idx:insert_idx] = HARDWARE_COMPONENT_KEYS + else: + # Custom sequences must explicitly have hardware blocks present in their prompt if they + # have set they are including hardware info. + if not any(component in HARDWARE_COMPONENT_KEYS for component in component_sequence): + raise ValueError( + "components_override must contain at least one hardware_* entry when include_hardware=True" + ) + + # Get shared templates + shared = cfg.data.get("shared", {}) + backend_display = backend_data.get("backend_display", backend.upper()) + + # Fill in shared templates with backend-specific terms + problem_statement = shared.get("problem_statement", "").format(backend_display=backend_display) + instruction = shared.get("instruction", "").format(backend_display=backend_display) + + # Add backend-specific content to context + context = { + **context, + "backend": backend.upper() if backend in ["cuda", "cute"] else backend.capitalize(), + "backend_display": backend_display, + "problem_statement": problem_statement, + "instruction": instruction, + } + + # Load precision details if provided + if precision: + try: + precision_data = cfg.data["precision"][precision] + context["precision_display"] = precision_data.get("precision_display", precision.upper()) + except KeyError: + raise KeyError(f"Unknown precision: {precision}. Must be one of: fp32, fp16, bf16") + else: + # Default to fp32 if not specified + default_precision = cfg.data.get("meta", {}).get("default_precision", "fp32") + precision_data = cfg.data["precision"].get(default_precision, {}) + context["precision_display"] = precision_data.get("precision_display", "FP32 (32-bit floating point)") + + # Load example files if requested. Supports loading one shot or few shot examples. + requires_example = option_data.get("requires_example") + if requires_example: + example_entry_template = cfg.compose_blocks(["templates.common.example_entry_template"]).strip() + intro_one_shot = cfg.compose_blocks(["templates.common.example_intro_one_shot"]).strip() + intro_few_shot = cfg.compose_blocks(["templates.common.example_intro_few_shot"]).strip() + intro_one_shot = intro_one_shot.format( + backend_display=backend_display + ) + intro_few_shot = intro_few_shot.format( + backend_display=backend_display + ) + + def render_example_entry(input_code: str, output_code: str, example_label: str) -> str: + return example_entry_template.format( + example_label=example_label, + input_code=input_code, + output_code=output_code, + backend_display=backend_display, + ) + + examples_entries: List[str] = [] + examples_intro = intro_one_shot + + if requires_example == "few_shot": + # Try to load few-shot examples if available + few_shot_examples = backend_data.get("few_shot_examples") + + if few_shot_examples and len(few_shot_examples) > 0: + # Use multiple examples (true few-shot) + examples_intro = intro_few_shot + for i, (input_path, output_path) in enumerate(few_shot_examples, 1): + input_code = read_file(_abs_path(input_path)) + output_code = read_file(_abs_path(output_path)) + examples_entries.append( + render_example_entry(input_code, output_code, f"Example {i}:") + ) + else: + # Fall back to one-shot + ex_arch_path = _abs_path( + backend_data.get("few_shot_example_arch") or shared.get("few_shot_example_arch") + ) + ex_new_path = _abs_path(backend_data["one_shot_new_arch"]) + input_code = read_file(ex_arch_path) + output_code = read_file(ex_new_path) + examples_entries.append( + render_example_entry(input_code, output_code, "Example:") + ) + + elif requires_example == "one_shot": + # Always use one-shot + ex_arch_path = _abs_path( + backend_data.get("few_shot_example_arch") or shared.get("few_shot_example_arch") + ) + ex_new_path = _abs_path(backend_data["one_shot_new_arch"]) + input_code = read_file(ex_arch_path) + output_code = read_file(ex_new_path) + examples_entries.append( + render_example_entry(input_code, output_code, "Example:") + ) + + if not examples_entries: + raise ValueError(f"No example entries could be constructed for option '{option}'.") + + context["examples_intro"] = examples_intro + context["examples_entries"] = "\n\n".join(examples_entries).strip() + + # Load GPU details if requested + if option_data.get("requires_gpu") or include_hardware: + if not (gpu_specs_py and gpu_name): + raise ValueError( + f"Hardware info requested for option '{option}'; provide gpu_specs_py and gpu_name" + ) + context = {**context, **_gpu_context_from_gpu_specs(_abs_path(gpu_specs_py), gpu_name)} + + # Builds the prompt from the components in the toml file. + prompt_parts = [] + for component in component_sequence: + if component == "problem_statement": + # Use the already-formatted problem_statement from context + prompt_parts.append(context["problem_statement"]) + elif component == "instruction": + # Use the already-formatted instruction from context + prompt_parts.append(context["instruction"]) + elif component.startswith("hardware_"): + # Hardware components from templates.hardware + template_key = f"templates.hardware.{component}" + prompt_parts.append(cfg.compose_blocks([template_key])) + else: + # Other components from templates.common + template_key = f"templates.common.{component}" + prompt_parts.append(cfg.compose_blocks([template_key])) + + prompt_text = "\n".join(prompt_parts).strip() + "\n" + + try: + return prompt_text.format(**context).strip() + "\n" + except KeyError as e: + raise KeyError(f"Missing placeholder in context: {e.args[0]}. Available: {list(context.keys())}") from e + +# ------------------------------------------------------------------------- +# High-level convenience functions +# ------------------------------------------------------------------------- + +def get_prompt_for_backend( + ref_arch_src: str, + backend: str = "triton", + option: str = "one_shot", + precision: Optional[str] = None, + include_hardware: bool = False, + gpu_name: Optional[str] = None, +) -> str: + """ + Generate a prompt for a specific backend and option. + + Args: + ref_arch_src: The reference architecture source code + backend: The kernel backend (triton, cuda, cute, tilelang) + option: The prompt option (zero_shot, one_shot, few_shot) + precision: Optional precision (fp32, fp16, bf16) - defaults to fp32 if not provided + include_hardware: When True, append hardware guidance blocks (requires gpu_name) + gpu_name: GPU identifier used when include_hardware is True (e.g., "A100") + """ + return render_prompt_by_option( + prompts_toml=PROMPTS_TOML, + backend=backend.lower(), + option=option.lower(), + context={"ref_arch_src": ref_arch_src}, + precision=precision, + include_hardware=include_hardware, + gpu_specs_py=GPU_SPECS_PY if include_hardware else None, + gpu_name=gpu_name, + ) + + +def get_custom_prompt( + custom_key: str, + *, + ref_arch_src: str, + backend: str, + option: str, + precision: Optional[str] = None, + include_hardware: bool = False, + gpu_name: Optional[str] = None, + prompts_toml: str = PROMPTS_TOML, +) -> str: + """ + Render a prompt defined under [custom_prompts.] in prompts.toml. + Must still provide backend/option/precision settings just like + get_prompt_for_backend. + """ + if not ref_arch_src: + raise ValueError(f"Custom prompt '{custom_key}' requires ref_arch_src.") + cfg = PromptConfig.from_toml(prompts_toml) + try: + custom_cfg: Dict[str, Any] = cfg.data["custom_prompts"][custom_key] + except KeyError as exc: + raise KeyError(f"Unknown custom prompt: {custom_key}") from exc + + components_override = custom_cfg.get("components") + + return render_prompt_by_option( + prompts_toml=prompts_toml, + backend=backend.lower(), + option=option.lower(), + context={"ref_arch_src": ref_arch_src}, + precision=precision, + include_hardware=include_hardware, + gpu_specs_py=GPU_SPECS_PY if include_hardware else None, + gpu_name=gpu_name, + components_override=components_override, + ) + +__all__ = [ + "get_prompt_for_backend", + "get_custom_prompt", + "get_prompt_with_hardware", + "render_prompt_by_option", + "PromptConfig", +] + + +def log_prompt(prompt: str, dir_path: str, file_name: str): + os.makedirs(dir_path, exist_ok=True) + with open(os.path.join(dir_path, file_name), "w") as f: + f.write(prompt) + +def test_prompt(): + """ + Demonstrate baseline, few-shot, DSL, hardware-aware, and custom prompt + generation. Customize the reference architecture or custom_prompt_key + if you want to try different inputs. + """ + REPO_TOP_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) + ref_arch_src = read_file(os.path.join(REPO_TOP_PATH, "KernelBench", "level1", "1_Square_matrix_multiplication_.py")) + assert len(ref_arch_src) > 0, "ref_arch_src is empty" + + scratch_dir = os.path.join(REPO_TOP_PATH, "scratch") + # baseline prompt + baseline_prompt = get_prompt_for_backend( + ref_arch_src=ref_arch_src, + backend="cuda", + option="one_shot", + precision="fp32", + # GPU platform agnostic for baseline + ) + log_prompt(baseline_prompt, os.path.join(scratch_dir), "baseline_prompt.txt") + + # few shot prompt + few_shot_prompt = get_prompt_for_backend( + ref_arch_src=ref_arch_src, + backend="cuda", + option="few_shot", + precision="fp32", + ) + log_prompt(few_shot_prompt, os.path.join(scratch_dir), "few_shot_prompt.txt") + + # DSL prompt + dsl_prompt = get_prompt_for_backend( + ref_arch_src=ref_arch_src, + backend="triton", + option="one_shot", + precision="fp32", + ) + log_prompt(dsl_prompt, os.path.join(scratch_dir), "dsl_prompt.txt") + + # hardware prompt + hardware_prompt = get_prompt_for_backend( + ref_arch_src=ref_arch_src, + backend="cute", + option="one_shot", + precision="fp32", + include_hardware=True, + gpu_name="L40S", + ) + log_prompt(hardware_prompt, os.path.join(scratch_dir), "hardware_prompt.txt") + + # custom prompt defined in prompts.toml + custom_prompt = get_custom_prompt( + # the key is whatever you name the prompt in the custom_prompts section of the toml file + custom_key="custom", + + ref_arch_src=ref_arch_src, + backend="triton", + option="one_shot", + precision="fp32", + include_hardware=True, + gpu_name="L40S", + ) + log_prompt(custom_prompt, os.path.join(scratch_dir), "custom_prompt.txt") + +if __name__ == "__main__": + test_prompt() \ No newline at end of file diff --git a/src/prompts/prompts.toml b/src/prompts/prompts.toml new file mode 100644 index 00000000..bcf4e4ed --- /dev/null +++ b/src/prompts/prompts.toml @@ -0,0 +1,214 @@ +[meta] +version = "1.0" +default_backend = "cuda" +default_precision = "fp32" + +# ------------------------------------------------------------------------- +# Shared Templates: Used by all backends with placeholders +# ------------------------------------------------------------------------- +[shared] +problem_statement = """ +You write custom {backend_display} to replace the pytorch operators in the given architecture to get speedups. + +You have complete freedom to choose the set of operators you want to replace. You may make the decision to replace some operators with custom {backend_display} and leave others unchanged. You may replace multiple operators with custom implementations, consider operator fusion opportunities (combining multiple operators into a single kernel, for example, combining matmul+relu), or algorithmic changes (such as online softmax). You are only limited by your imagination. +""" + +instruction = """ +Optimize the architecture named Model with custom {backend_display}! Name your optimized output architecture ModelNew. Output the new code in codeblocks. Please generate real code, NOT pseudocode, make sure the code compiles and is fully functional. Just output the new model code, no other text, and NO testing code! +""" + +# Shared example architecture (same for all backends) +few_shot_example_arch = "src/prompts/model_ex_add.py" + +# ------------------------------------------------------------------------- +# Backends: Backend-specific configuration (minimal, just what varies) +# ------------------------------------------------------------------------- +[backends.cuda] +backend_display = "CUDA operators" +# One-shot example (baseline, always available) +one_shot_new_arch = "src/prompts/model_new_ex_add.py" +# Few-shot examples (optional, multiple example pairs) +few_shot_examples = [ + ["src/prompts/few_shot/model_ex_add.py", "src/prompts/few_shot/model_new_ex_add.py"], + ["src/prompts/few_shot/model_ex_fuse_gelu.py", "src/prompts/few_shot/model_new_ex_fuse_gelu.py"], + ["src/prompts/few_shot/model_ex_flash_attn.py", "src/prompts/few_shot/model_new_ex_flash_attn.py"], +] + +[backends.triton] +backend_display = "Triton kernels" +one_shot_new_arch = "src/prompts/model_new_ex_add_triton.py" +# No few_shot_examples - will use one-shot when few_shot option is selected + +[backends.cute] +backend_display = "CuTe (CUTLASS) kernels" +one_shot_new_arch = "src/prompts/model_new_ex_add_cute.py" +# No few_shot_examples - will use one-shot when few_shot option is selected + +[backends.tilelang] +backend_display = "TileLang kernels" +one_shot_new_arch = "src/prompts/model_new_ex_add_tilelang.py" +# No few_shot_examples - will use one-shot when few_shot option is selected + +# ------------------------------------------------------------------------- +# Precision: Precision-specific configuration +# ------------------------------------------------------------------------- +[precision.fp32] +precision_display = "FP32 (32-bit floating point)" + +[precision.fp16] +precision_display = "FP16 (16-bit floating point)" + +[precision.bf16] +precision_display = "BF16 (bfloat16)" + +# ------------------------------------------------------------------------- +# Templates: Reusable text blocks with placeholders +# ------------------------------------------------------------------------- +[templates.common] + +# --- Architecture Presentation --- +# Used to present the reference architecture/PyTorch kernel that needs optimization +arch_block = """ +You are given the following architecture: + + +{ref_arch_src} + +""" + +# ------------------------------------------------------------------------- +# Examples Block +# ------------------------------------------------------------------------- +# Shows example(s) of input architecture and optimized versions +# Dynamically formatted by Python code to handle single or multiple examples + +examples_block = """ +{examples_intro} + +{examples_entries} +""" + +# Different introductions for code examples depending on if its one shot or few shot + +example_intro_one_shot = """ +Here's an example to show you the syntax of inline embedding custom {backend_display} in PyTorch: +""" +example_intro_few_shot = """ +Here are examples showing how to embed custom {backend_display} in PyTorch: +""" + + +# Will inject an input example and output example according to the backend. + +example_entry_template = """ +{example_label} + +Input architecture: + +{input_code} + +Optimized with {backend_display}: + +{output_code} +""" + + +# ------------------------------------------------------------------------- +# Precision Information +# ------------------------------------------------------------------------- +# Specifies the target precision for optimization + +precision_note = """ +Note: The kernels should be optimized for {precision_display} precision. +""" + +# ------------------------------------------------------------------------- +# Custom Templates: Optional user-defined building blocks +# ------------------------------------------------------------------------- +# Add any custom template blocks here and reference them from components lists. + +# Example: + +custom_problem_statement = """ +Custom prompt intro goes here. You can reference {backend_display} or any +other placeholder supported in the shared context. +""" + + +# ------------------------------------------------------------------------- +# Hardware Templates: GPU-specific information blocks +# ------------------------------------------------------------------------- +[templates.hardware] +hardware_header = """ +Here is some information about the underlying hardware that you should keep in mind. +""" + +hardware_specs = """ +The GPU that will run the kernel is NVIDIA {gpu_name}, {gpu_architecture} architecture. + +{gpu_specs_bullets} +""" + +hardware_definitions = """ +Here are some concepts about the GPU architecture that could be helpful: + +{gpu_definitions_bullets} +""" + +hardware_best_practices = """ +Here are some best practices for writing kernels on GPU: + +{gpu_best_practices_bullets} +""" + +# ------------------------------------------------------------------------- +# Options: Different prompt construction modes +# ------------------------------------------------------------------------- + +[options.zero_shot] +# Zero-shot: No examples provided—the model must infer everything from the description +components = ["problem_statement", "arch_block", "precision_note", "instruction"] + +[options.one_shot] +# One-shot: Includes a single example to demonstrate the task +# This is the default KernelBench will use for model baseline performance +components = ["problem_statement", "examples_block", "arch_block", "precision_note", "instruction"] +requires_example = "one_shot" + +[options.few_shot] +# Few-shot: Multiple examples (falls back to one-shot if backend lacks few-shot entries) +components = ["problem_statement", "examples_block", "arch_block", "precision_note", "instruction"] +requires_example = "few_shot" + + +# ------------------------------------------------------------------------- +# Custom Prompts: user-defined prompt compositions +# ------------------------------------------------------------------------- + + +[custom_prompts.custom] +# Use this name with the CLI: pass custom_prompt_key=custom to +# generate_samples.py, generate_and_eval_single_sample.py, or the modal variant +# to load this block structure instead of the standard backend/option combo. +# If you want to add another prompt (e.g., [custom_prompts.custom2]), call it with +# custom_prompt_key=custom2 instead. + +# Define prompt composition here (ordering/extra sections). +# Backend, precision, hardware info, etc. must still be set via CLI flags +# Backend and precision in particular are required for evaluating your kernels. +# Hardware_info information must also be defined if you use any of the hardware +# templates. + +# Order the components for the prompt in whatever way you want and use any +# created templates you want +components = [ + "custom_problem_statement", + "problem_statement", + "hardware_header", + "hardware_specs", + "hardware_best_practices", + "arch_block", + "precision_note", + "examples_block", + "instruction", +] \ No newline at end of file