Merged
4 changes: 2 additions & 2 deletions README.md
@@ -90,7 +90,7 @@ If you don't have GPU available locally, you can set up [Modal](https://modal.co

## 🚀 Usage
### Run on a single problem
It is easier to get started with a single problem. This will fetch the problem, generate a sample, and evaluate the sample.
It is easier to get started with a single problem. This will fetch the problem, generate a sample, and evaluate the sample.

```
# for example, run level 2 problem 40 from huggingface
@@ -106,7 +106,7 @@ python3 scripts/generate_and_eval_single_sample.py dataset_src="huggingface" lev
* **`precision`** - You can specify the tensor precision with `precision=fp32`. Currently all of our reported results are `fp32`, but we have added support for `fp16` & `bf16`.
* **`backend`** - We also support GPU programming languages beyond `cuda`; simply specify, for example, `backend=triton`. For now we support the DSLs `cuda`, `triton`, `cute`, and `tilelang`.

Check the config fields for comprehensive set of options.
Check the config fields for a comprehensive set of options. Note that by default we provide the model with a one-shot example along with a minimal set of information; you can check out other prompt settings or construct your own in `src/prompt_constructor_toml.py`.
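
For example, a hypothetical invocation exercising the new prompt settings might look like the sketch below (the flag names mirror the config fields `prompt_option`, `include_hardware_info`, and `hardware_gpu_name`; the GPU name is only an illustration):

```
# a minimal sketch, assuming the key=value config flags described above
python3 scripts/generate_and_eval_single_sample.py dataset_src="huggingface" level=2 problem_id=40 \
    backend=triton prompt_option=few_shot \
    include_hardware_info=True hardware_gpu_name="L40S"
```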

### Run on all problems

102 changes: 73 additions & 29 deletions scripts/generate_and_eval_single_sample.py
@@ -9,8 +9,7 @@

from src.dataset import construct_kernelbench_dataset
from src.eval import eval_kernel_against_ref
from src.prompt_constructor import prompt_generate_custom_cuda_from_prompt_template
from src.prompt_constructor_multilang import get_prompt_for_backend
from src.prompt_constructor_toml import get_prompt_for_backend, get_custom_prompt
from src.utils import (
create_inference_server_from_presets,
extract_first_code,
@@ -22,6 +21,9 @@
"""
Generate and evaluate a single sample
Easiest way to get started: test a single problem for experimentation or debugging

Example usage:
python3 scripts/generate_and_eval_single_sample.py dataset_src=huggingface level=1 problem_id=1 eval_mode=local server_type=google model_name=gemini/gemini-2.5-flash max_tokens=8192 temperature=0.0
"""

REPO_TOP_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
@@ -72,6 +74,12 @@ def __init__(self):

self.backend = "cuda"

# Prompt construction
self.prompt_option = "one_shot" # choices: zero_shot, one_shot, few_shot
self.include_hardware_info = False
self.hardware_gpu_name = None
self.custom_prompt_key = None

def verbose_logging(self):
self.log = True
self.log_prompt = True
@@ -86,6 +94,7 @@ def __repr__(self):
def main(config: EvalConfig):
"""
Keep it simple: Generate and evaluate a single sample
Note: the code logic will be shortened to keep this as simple as possible
"""
from src.utils import SERVER_PRESETS

@@ -129,6 +138,7 @@ def main(config: EvalConfig):
config.problem_id <= num_problems
), f"Problem ID {config.problem_id} out of range for Level {config.level}"

# TODO: refactor dataset fetching logic to be as clean as possible.
# 1. Fetch Problem
if config.dataset_src == "huggingface":

@@ -169,24 +179,70 @@
budget_tokens=config.budget_tokens,
)

# Prompt Construction (Note: could be shortened in future PR)
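# A custom_prompt_key value of "" or "none" (case-insensitive) is treated as unset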
custom_prompt_key = getattr(config, "custom_prompt_key", None)
if isinstance(custom_prompt_key, str):
trimmed = custom_prompt_key.strip()
if trimmed.lower() in {"", "none"}:
custom_prompt_key = None
else:
custom_prompt_key = trimmed
config.custom_prompt_key = custom_prompt_key

# Use appropriate prompt constructor based on backend
if config.backend == "cuda":
custom_prompt = prompt_generate_custom_cuda_from_prompt_template(ref_arch_src)
elif config.backend in ["triton", "tilelang", "cute"]:
custom_prompt = get_prompt_for_backend(ref_arch_src, config.backend)
else:
prompt_option = str(config.prompt_option).lower()
valid_prompt_options = {"zero_shot", "one_shot", "few_shot"}
include_hardware = config.include_hardware_info
if isinstance(include_hardware, str):
include_hardware = include_hardware.lower() in ["true", "1", "yes"]
config.include_hardware_info = include_hardware

supported_backends = {"cuda", "triton", "tilelang", "cute"}
backend = config.backend.lower()
if backend not in supported_backends:
raise ValueError(
f"Unsupported backend: {config.backend}. Must be 'cuda', 'triton', 'tilelang', or 'cute'."
f"Unsupported backend: {config.backend}. Must be one of {sorted(supported_backends)}."
)

if backend == "tilelang":
config.precision = "fp16" # tilelang only operates with fp16
config.hardware_gpu_name = config.hardware_gpu_name or getattr(config, "gpu", None)

if not custom_prompt_key:
if prompt_option not in valid_prompt_options:
raise ValueError(
f"Invalid prompt_option '{config.prompt_option}'. "
f"Must be one of {sorted(valid_prompt_options)}."
)
if include_hardware and not config.hardware_gpu_name:
raise ValueError(
"include_hardware_info is True but hardware_gpu_name is not provided."
)

if custom_prompt_key:
custom_prompt = get_custom_prompt(
custom_prompt_key,
ref_arch_src=ref_arch_src,
backend=backend,
option=prompt_option,
precision=config.precision,
include_hardware=include_hardware,
gpu_name=config.hardware_gpu_name,
)
else:
custom_prompt = get_prompt_for_backend(
ref_arch_src,
backend,
option=prompt_option,
precision=config.precision,
include_hardware=include_hardware,
gpu_name=config.hardware_gpu_name,
)

os.makedirs(config.logdir, exist_ok=True)

if config.log_prompt:
with open(
os.path.join(
config.logdir,
f"prompt_level_{config.level}_problem_{config.problem_id}.txt",
),
"w",
) as f:
with open(os.path.join(config.logdir, f"prompt_level_{config.level}_problem_{config.problem_id}.txt"), "w") as f:
f.write(custom_prompt)

# Query server with constructed prompt
@@ -200,13 +256,7 @@

# this should be optional
if config.log:
with open(
os.path.join(
config.logdir,
f"generated_kernel_level_{config.level}_problem_{config.problem_id}.py",
),
"w",
) as f:
with open(os.path.join(config.logdir, f"generated_kernel_level_{config.level}_problem_{config.problem_id}.py"), "w") as f:
f.write(custom_kernel)

# 3. Evaluate Kernel
@@ -228,13 +278,7 @@
)

if config.log:
with open(
os.path.join(
config.logdir,
f"eval_result_level_{config.level}_problem_{config.problem_id}.txt",
),
"a",
) as f:
with open(os.path.join(config.logdir, f"eval_result_level_{config.level}_problem_{config.problem_id}.txt"), "a",) as f:
f.write(f"Problem Name: {problem_name}\n")
f.write(str(kernel_exec_result))

70 changes: 62 additions & 8 deletions scripts/generate_and_eval_single_sample_modal.py
@@ -15,8 +15,7 @@

#from src.dataset import construct_kernelbench_dataset
from src.eval import eval_kernel_against_ref
from src.prompt_constructor import prompt_generate_custom_cuda_from_prompt_template
from src.prompt_constructor_multilang import get_prompt_for_backend
from src.prompt_constructor_toml import get_prompt_for_backend, get_custom_prompt
from src.utils import extract_first_code, query_server, set_gpu_arch, read_file, create_inference_server_from_presets

app = modal.App("eval_single_sample")
@@ -76,6 +75,11 @@ def __init__(self):
self.log_eval_result = False

self.backend = "cuda"
# Prompt generation settings
self.prompt_option = "one_shot" # zero_shot, one_shot, few_shot
self.include_hardware_info = False
self.hardware_gpu_name = None
self.custom_prompt_key = None

def verbose_logging(self):
self.log = True
@@ -194,14 +198,64 @@ def main(config: EvalConfig):
budget_tokens=config.budget_tokens)


custom_prompt_key = getattr(config, "custom_prompt_key", None)
if isinstance(custom_prompt_key, str):
trimmed = custom_prompt_key.strip()
if trimmed.lower() in {"", "none"}:
custom_prompt_key = None
else:
custom_prompt_key = trimmed
config.custom_prompt_key = custom_prompt_key

# Check that the user provided a valid argument for how many examples to include as context for the model
prompt_option = str(config.prompt_option).lower()
valid_prompt_options = {"zero_shot", "one_shot", "few_shot"}
include_hardware = config.include_hardware_info
if isinstance(include_hardware, str):
include_hardware = include_hardware.lower() in ["true", "1", "yes"]
config.include_hardware_info = include_hardware

supported_backends = {"cuda", "triton", "tilelang", "cute"}
backend = config.backend.lower()
if backend not in supported_backends:
raise ValueError(
f"Unsupported backend: {config.backend}. Must be one of {sorted(supported_backends)}."
)

# Use appropriate prompt constructor based on backend
if config.backend == "cuda":
custom_prompt = prompt_generate_custom_cuda_from_prompt_template(ref_arch_src)
elif config.backend in ["triton", "tilelang", "cute"]:
custom_prompt = get_prompt_for_backend(ref_arch_src, config.backend)
# tilelang only supports fp16 or bf16
if backend == "tilelang":
config.precision = "fp16"
config.hardware_gpu_name = config.hardware_gpu_name or getattr(config, "gpu", None)

if not custom_prompt_key:
if prompt_option not in valid_prompt_options:
raise ValueError(
f"Invalid prompt_option '{config.prompt_option}'. Must be one of {sorted(valid_prompt_options)}."
)
if include_hardware and not config.hardware_gpu_name:
raise ValueError(
"include_hardware_info is True but hardware_gpu_name is not provided."
)

if custom_prompt_key:
custom_prompt = get_custom_prompt(
custom_prompt_key,
ref_arch_src=ref_arch_src,
backend=backend,
option=prompt_option,
precision=config.precision,
include_hardware=include_hardware,
gpu_name=config.hardware_gpu_name,
)
else:
raise ValueError(f"Unsupported backend: {config.backend}. Must be 'cuda', 'triton', 'tilelang', or 'cute'.")
custom_prompt = get_prompt_for_backend(
ref_arch_src,
backend,
option=prompt_option,
precision=config.precision,
include_hardware=include_hardware,
gpu_name=config.hardware_gpu_name,
)

if config.log_prompt:
with open(os.path.join(config.logdir, f"prompt_level_{config.level}_problem_{config.problem_id}.txt"), "w") as f:
77 changes: 62 additions & 15 deletions scripts/generate_samples.py
@@ -10,8 +10,7 @@

from src.dataset import construct_kernelbench_dataset
from src.eval import eval_kernel_against_ref
from src.prompt_constructor import prompt_generate_custom_cuda_from_prompt_template
from src.prompt_constructor_multilang import get_prompt_for_backend
from src.prompt_constructor_toml import get_prompt_for_backend, get_custom_prompt
from src.utils import (
create_inference_server_from_presets,
extract_first_code,
@@ -80,6 +79,10 @@ def __init__(self):
self.backend = "cuda"

self.precision = "fp32"
self.prompt_option = "one_shot" # zero_shot, one_shot, few_shot
self.include_hardware_info = False
self.hardware_gpu_name = None
self.custom_prompt_key = None

def greedy(self):
# For greedy decoding, especially baseline eval
@@ -126,30 +129,38 @@ def generate_sample_single(
problem_number == work.problem_id
), f"Problem number in filename ({problem_number}) does not match config problem_id ({config.problem_id})"

# Construct Prompt
if config.backend == "cuda":
custom_cuda_prompt = prompt_generate_custom_cuda_from_prompt_template(
ref_arch_src
if config.custom_prompt_key:
custom_prompt = get_custom_prompt(
config.custom_prompt_key,
ref_arch_src=ref_arch_src,
backend=config.backend,
option=config.prompt_option,
precision=config.precision,
include_hardware=config.include_hardware_info,
gpu_name=config.hardware_gpu_name,
)
elif config.backend in ["triton", "cute", "tilelang"]:
custom_cuda_prompt = get_prompt_for_backend(ref_arch_src, config.backend)
else:
raise ValueError(
f"Unsupported backend: {config.backend}. Must be 'cuda', 'triton', 'cute', or 'tilelang'."
custom_prompt = get_prompt_for_backend(
ref_arch_src,
config.backend,
option=config.prompt_option,
precision=config.precision,
include_hardware=config.include_hardware_info,
gpu_name=config.hardware_gpu_name,
)
if config.log_prompt:
prompt_path = os.path.join(
run_dir,
f"level_{config.level}_problem_{work.problem_id}_sample_{work.sample_id}_prompt.txt",
)
with open(prompt_path, "w") as f:
f.write(custom_cuda_prompt)
f.write(custom_prompt)

# Query server with constructed prompt
custom_cuda = inference_server(custom_cuda_prompt)
custom_cuda = extract_first_code(custom_cuda, ["python", "cpp"])
custom_kernel = inference_server(custom_prompt)
custom_kernel = extract_first_code(custom_kernel, ["python", "cpp"])
# check LLM is able to generate custom CUDA code
assert custom_cuda is not None, "Custom CUDA code generation failed"
assert custom_kernel is not None, "Custom CUDA code generation failed"

if config.verbose:
print(
Expand All @@ -162,7 +173,7 @@ def generate_sample_single(
f"level_{config.level}_problem_{work.problem_id}_sample_{work.sample_id}_kernel.py",
)
with open(kernel_path, "w") as f:
f.write(custom_cuda)
f.write(custom_kernel)

return True

@@ -214,6 +225,42 @@ def main(config: GenerationConfig):
if isinstance(config.is_reasoning_model, str):
config.is_reasoning_model = config.is_reasoning_model.lower() in ['true', '1', 'yes']

custom_prompt_key = getattr(config, "custom_prompt_key", None)
if isinstance(custom_prompt_key, str):
trimmed = custom_prompt_key.strip()
if trimmed.lower() in {"", "none"}:
custom_prompt_key = None
else:
custom_prompt_key = trimmed
config.custom_prompt_key = custom_prompt_key

include_hardware = config.include_hardware_info
if isinstance(include_hardware, str):
include_hardware = include_hardware.lower() in ["true", "1", "yes"]
config.include_hardware_info = include_hardware

supported_backends = {"cuda", "triton", "cute", "tilelang"}
backend = config.backend.lower()
if backend not in supported_backends:
raise ValueError(
f"Unsupported backend: {config.backend}. Must be one of {sorted(supported_backends)}."
)
config.backend = backend
if backend == "tilelang":
config.precision = "fp16"

config.prompt_option = str(config.prompt_option).lower()
valid_prompt_options = {"zero_shot", "one_shot", "few_shot"}
if not config.custom_prompt_key:
if config.prompt_option not in valid_prompt_options:
raise ValueError(
f"Invalid prompt_option '{config.prompt_option}'. Must be one of {sorted(valid_prompt_options)}."
)
if include_hardware and not config.hardware_gpu_name:
raise ValueError(
"include_hardware_info is True but hardware_gpu_name is not provided."
)

print(f"Starting Batch Generation with config: {config}")

# Dataset Configurations