8 changes: 6 additions & 2 deletions setup.py
@@ -710,13 +710,17 @@ def get_extensions():
         print("Building mxfp8_cuda extension")
         ext_modules.append(
             CUDAExtension(
-                name="torchao.prototype.mxfp8_cuda",
+                name="torchao._C_mxfp8",
                 sources=mxfp8_sources,
                 include_dirs=[
                     mxfp8_extension_dir,  # For mxfp8_quantize.cuh, mxfp8_extension.cpp, and mxfp8_cuda.cu
                 ],
                 extra_compile_args={
-                    "cxx": ["-std=c++17", "-O3"],
+                    "cxx": [
+                        f"-DPy_LIMITED_API={min_supported_cpython_hexcode}",
+                        "-std=c++17",
+                        "-O3",
+                    ],
                     "nvcc": nvcc_args
                     + [
                         "-gencode=arch=compute_100,code=sm_100",
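The new `-DPy_LIMITED_API=...` define opts the C++ translation unit into CPython's stable ABI, so one built wheel loads on every CPython at or above the pinned version; this is also why the PyBind11 bindings are replaced with `TORCH_LIBRARY` further down. As a hedged illustration only (the actual `min_supported_cpython_hexcode` is defined elsewhere in setup.py and is not shown in this diff), the macro takes a `PY_VERSION_HEX`-style value:

```python
# Illustration, assuming the hexcode is built from a (major, minor) version
# tuple; 0x030A0000 would pin the stable-ABI baseline to CPython 3.10.
def cpython_hexcode(major: int, minor: int) -> str:
    """Build a PY_VERSION_HEX-style code of the form 0xMMmm0000."""
    return f"0x{major:02X}{minor:02X}0000"

assert cpython_hexcode(3, 10) == "0x030A0000"
```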
8 changes: 3 additions & 5 deletions test/prototype/moe_training/test_kernels.py
@@ -21,6 +21,7 @@
     triton_fp8_per_group_rowwise_scales,
 )
 from torchao.prototype.moe_training.kernels.mxfp8 import (
+    mxfp8_quantize_cuda_3d,
     torch_to_blocked_2d_K_groups,
     torch_to_blocked_2d_M_groups,
     torch_to_blocked_per_group_3d,
@@ -317,8 +318,6 @@ def test_triton_mx_block_rearrange_2d_K_groups(
 @pytest.mark.parametrize("input_dtype", (torch.bfloat16,))
 @pytest.mark.parametrize("scaling_mode", (ScaleCalculationMode.FLOOR,))
 def test_cuda_mx_dim1_3d_numerics(E, N, K, input_dtype, scaling_mode):
-    from torchao.prototype import mxfp8_cuda
-
     scaling_mode_str = (
         "floor" if scaling_mode == ScaleCalculationMode.FLOOR else "rceil"
     )
@@ -344,9 +343,8 @@ def test_cuda_mx_dim1_3d_numerics(E, N, K, input_dtype, scaling_mode):
     y_d1_ref = y_d1_ref.transpose(-2, -1)
     s_d1_ref = s_d1_ref.transpose(-2, -1)
 
-    # CUDA implementation (should work with any stride pattern)
-    y_d1, s_d1 = mxfp8_cuda.quantize_3d(
-        x, scale_dim_n=block_size, scaling_mode=scaling_mode_str
+    y_d1, s_d1 = mxfp8_quantize_cuda_3d(
+        x, block_size=block_size, scaling_mode=scaling_mode_str
     )
     # Check scales
     torch.testing.assert_close(s_d1, s_d1_ref, rtol=0, atol=0)
42 changes: 3 additions & 39 deletions test/prototype/mx_formats/test_kernels.py
@@ -33,6 +33,7 @@
     f32_to_f6_e2m3_unpacked,
     f32_to_f6_e3m2_unpacked,
     get_bits,
+    mxfp8_quantize_cuda,
     pack_uint4,
     triton_mxfp8_dequant_dim0,
     triton_to_mxfp8_dim0,
@@ -541,8 +542,6 @@ def test_rearrange(shape):
     "scaling_mode", (ScaleCalculationMode.FLOOR, ScaleCalculationMode.RCEIL)
 )
 def test_cuda_mx_dim1_numerics(M, K, input_dtype, scaling_mode):
-    from torchao.prototype import mxfp8_cuda
-
     scaling_mode_str = (
         "floor" if scaling_mode == ScaleCalculationMode.FLOOR else "rceil"
     )
@@ -561,13 +560,11 @@ def test_cuda_mx_dim1_numerics(M, K, input_dtype, scaling_mode):
         scaling_mode=scaling_mode,
     )
 
-    _, y_d1, _, s_d1 = mxfp8_cuda.quantize(
+    _, y_d1, _, s_d1 = mxfp8_quantize_cuda(
         x,
         rowwise=False,
         colwise=True,
         scaling_mode=scaling_mode_str,
-        scale_dim_x=1,
-        scale_dim_y=block_size,
     )
 
     # check scales
@@ -587,48 +584,15 @@ def test_cuda_mx_dim1_numerics(M, K, input_dtype, scaling_mode):
     reason="CUDA version >= 12.8 required for MXFP8 CUDA kernels",
 )
 def test_cuda_mx_dim0_not_supported():
-    from torchao.prototype import mxfp8_cuda
-
     M, K = 64, 64
-    block_size = 32
     x = (
         torch.arange(0, M * K, dtype=torch.bfloat16, device="cuda")
         .reshape(M, K)
         .contiguous()
     )
     with pytest.raises(RuntimeError):
-        _, y_d1, _, s_d1 = mxfp8_cuda.quantize(
+        _, y_d1, _, s_d1 = mxfp8_quantize_cuda(
             x,
             rowwise=True,
             colwise=False,
-            scale_dim_x=block_size,
-            scale_dim_y=1,
         )
-
-
-@pytest.mark.skipif(
-    not is_sm_at_least_100(),
-    reason="MXFP8 requires CUDA capability 10.0 or greater",
-)
-@pytest.mark.skipif(
-    not is_cuda_version_at_least(12, 8),
-    reason="CUDA version >= 12.8 required for MXFP8 CUDA kernels",
-)
-def test_cuda_mx_dim1_invalid_block_size():
-    from torchao.prototype import mxfp8_cuda
-
-    M, K = 64, 64
-    x = (
-        torch.arange(0, M * K, dtype=torch.bfloat16, device="cuda")
-        .reshape(M, K)
-        .contiguous()
-    )
-    invalid_block_size = 4
-    with pytest.raises(RuntimeError):
-        _, y_d1, _, s_d1 = mxfp8_cuda.quantize(
-            x,
-            rowwise=False,
-            colwise=True,
-            scale_dim_x=1,
-            scale_dim_y=invalid_block_size,
-        )

Contributor Author (on the deleted `test_cuda_mx_dim1_invalid_block_size`): Deleting this test since a block_size of 32 is now hard-coded in the Python wrapper for the kernel; we always use 32 for mxfp8.
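For context on that comment, here is a minimal sketch of what such a wrapper could look like, assuming the op is exposed as `torch.ops.torchao.mxfp8_quantize` per the registration at the bottom of this diff; the real torchao wrapper may differ in detail:

```python
# Hedged sketch, not the actual torchao wrapper: block_size is pinned to 32
# because MXFP8 always scales over 32-element blocks, which is what makes the
# deleted invalid-block-size test moot.
import torch

def mxfp8_quantize_cuda(
    x: torch.Tensor,
    rowwise: bool = False,
    colwise: bool = True,
    scaling_mode: str = "floor",
):
    BLOCK_SIZE = 32  # hard-coded MXFP8 scale-block size
    # scale_dim_x governs rowwise scaling and scale_dim_y colwise scaling;
    # the unused direction gets a scale dim of 1.
    return torch.ops.torchao.mxfp8_quantize(
        x,
        rowwise,
        colwise,
        BLOCK_SIZE if rowwise else 1,
        BLOCK_SIZE if colwise else 1,
        "e4m3",
        scaling_mode,
    )
```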
100 changes: 46 additions & 54 deletions torchao/csrc/cuda/mx_kernels/mxfp8_extension.cpp
@@ -1,32 +1,33 @@
-// PyBind wrapping for the mxfp8 extension
+// MXFP8 extension using TORCH_LIBRARY (CPython ABI agnostic)
+#include <torch/library.h>
+#include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>
 #include <cstdint>
 #include <string>
-#include <torch/extension.h>
 
 namespace mxfp8 {
 
 // Forward declarations
-void mxfp8_quantize_cuda(const torch::Tensor &input,
-                         torch::Tensor &output_rowwise,
-                         torch::Tensor &output_columnwise,
-                         torch::Tensor &scales_rowwise,
-                         torch::Tensor &scales_colwise,
+void mxfp8_quantize_cuda(const at::Tensor &input,
+                         at::Tensor &output_rowwise,
+                         at::Tensor &output_columnwise,
+                         at::Tensor &scales_rowwise,
+                         at::Tensor &scales_colwise,
                          int64_t scale_dim_x,
                          int64_t scale_dim_y,
                          const std::string &fp8_format,
                         const std::string &scaling_mode);
 
-void mxfp8_quantize_3d_cuda(const torch::Tensor &input,
-                            torch::Tensor &output_colwise,
-                            torch::Tensor &scales_colwise,
+void mxfp8_quantize_3d_cuda(const at::Tensor &input,
+                            at::Tensor &output_colwise,
+                            at::Tensor &scales_colwise,
                             int64_t scale_dim_n,
                             const std::string &fp8_format,
                             const std::string &scaling_mode);
 
 // Helper for tensor validation
-void check_cuda_tensor(const torch::Tensor &t, const char *name) {
+void check_cuda_tensor(const at::Tensor &t, const char *name) {
   TORCH_CHECK(t.is_cuda(), name, " must be a CUDA tensor");
   TORCH_CHECK(t.is_contiguous(), name, " must be contiguous");
 }
@@ -46,8 +47,8 @@ void validate_scale_dimensions(int64_t scale_dim_x, int64_t scale_dim_y) {
 }
 
 // Main quantization function
-std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
-mxfp8_quantize(torch::Tensor input, bool rowwise, bool colwise,
+std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
+mxfp8_quantize(const at::Tensor& input, bool rowwise, bool colwise,
                int64_t scale_dim_x, int64_t scale_dim_y,
                const std::string &fp8_format,
                const std::string &scaling_mode) {
@@ -57,9 +58,9 @@ mxfp8_quantize(torch::Tensor input, bool rowwise, bool colwise,
   TORCH_CHECK(input.is_cuda(), "input must be a CUDA tensor");
   TORCH_CHECK(input.is_contiguous(), "input must be contiguous");
   TORCH_CHECK(input.dim() == 2, "input must be 2D");
-  TORCH_CHECK(input.scalar_type() == torch::kFloat32 ||
-                  input.scalar_type() == torch::kFloat16 ||
-                  input.scalar_type() == torch::kBFloat16,
+  TORCH_CHECK(input.scalar_type() == at::kFloat ||
+                  input.scalar_type() == at::kHalf ||
+                  input.scalar_type() == at::kBFloat16,
               "Input must be float32, float16, or bfloat16");
   TORCH_CHECK(rowwise || colwise,
               "At least one of rowwise or colwise must be true");
@@ -75,40 +76,40 @@ mxfp8_quantize(torch::Tensor input, bool rowwise, bool colwise,
   c10::cuda::CUDAGuard device_guard(input.device());
 
   // Create tensor options
-  const auto options_fp8 = torch::TensorOptions()
-                               .dtype(torch::kFloat8_e4m3fn) // FP8 stored as uint8
+  const auto options_fp8 = at::TensorOptions()
+                               .dtype(at::kFloat8_e4m3fn)
                                .device(input.device());
 
-  const auto options_scale = torch::TensorOptions()
-                                 .dtype(torch::kFloat8_e8m0fnu) // E8M0 stored as uint8
+  const auto options_scale = at::TensorOptions()
+                                 .dtype(at::kFloat8_e8m0fnu)
                                  .device(input.device());
 
   // Allocate output tensors
-  torch::Tensor output_rowwise, output_colwise;
-  torch::Tensor scales_rowwise, scales_colwise;
+  at::Tensor output_rowwise, output_colwise;
+  at::Tensor scales_rowwise, scales_colwise;
 
   if (rowwise) {
     const int64_t num_col_blocks = (cols + scale_dim_x - 1) / scale_dim_x;
-    output_rowwise = torch::empty({rows, cols}, options_fp8);
-    scales_rowwise = torch::empty({rows, num_col_blocks}, options_scale);
+    output_rowwise = at::empty({rows, cols}, options_fp8);
+    scales_rowwise = at::empty({rows, num_col_blocks}, options_scale);
   } else {
-    output_rowwise = torch::empty({0}, options_fp8);
-    scales_rowwise = torch::empty({0}, options_scale);
+    output_rowwise = at::empty({0}, options_fp8);
+    scales_rowwise = at::empty({0}, options_scale);
   }
 
   if (colwise) {
     const int64_t num_row_blocks = (rows + scale_dim_y - 1) / scale_dim_y;
-    output_colwise = torch::empty_strided({rows, cols}, {1, rows}, options_fp8);
+    output_colwise = at::empty_strided({rows, cols}, {1, rows}, options_fp8);
     // Need scales_colwise to be this shape so the 'col' dim stride is 1,
     // for colwise scaling, we can avoid uncoalesced writes to global memory.
     // This is because each of the 32 threads in a warp will be computing
     // a scale for a different column of 32 input data values, then each writing
     // that scale to global memory - so the stride along this `col` dim should be 1
     // so writes can be coalesced into a single transaction.
-    scales_colwise = torch::empty_strided({cols, num_row_blocks}, {1, cols}, options_scale);
+    scales_colwise = at::empty_strided({cols, num_row_blocks}, {1, cols}, options_scale);
   } else {
-    output_colwise = torch::empty({0}, options_fp8);
-    scales_colwise = torch::empty({0}, options_scale);
+    output_colwise = at::empty({0}, options_fp8);
+    scales_colwise = at::empty({0}, options_scale);
   }
 
   // Call CUDA kernels
@@ -124,8 +125,8 @@ mxfp8_quantize(torch::Tensor input, bool rowwise, bool colwise,
 }
 
 // 3D tensor quantization function
-std::tuple<torch::Tensor, torch::Tensor>
-mxfp8_quantize_3d(torch::Tensor input, int64_t scale_dim_n,
+std::tuple<at::Tensor, at::Tensor>
+mxfp8_quantize_3d(const at::Tensor& input, int64_t scale_dim_n,
                   const std::string &fp8_format,
                   const std::string &scaling_mode) {
 
@@ -134,9 +135,9 @@ mxfp8_quantize_3d(torch::Tensor input, int64_t scale_dim_n,
   TORCH_CHECK(input.is_contiguous(), "input must be contiguous");
   // Note: We don't check contiguous for 3D as it may have column major strides
   TORCH_CHECK(input.dim() == 3, "input must be 3D");
-  TORCH_CHECK(input.scalar_type() == torch::kFloat32 ||
-                  input.scalar_type() == torch::kFloat16 ||
-                  input.scalar_type() == torch::kBFloat16,
+  TORCH_CHECK(input.scalar_type() == at::kFloat ||
+                  input.scalar_type() == at::kHalf ||
+                  input.scalar_type() == at::kBFloat16,
               "Input must be float32, float16, or bfloat16");
   TORCH_CHECK(scale_dim_n == 32, "scale_dim_n must be 32 for now");
 
@@ -154,21 +155,21 @@ mxfp8_quantize_3d(torch::Tensor input, int64_t scale_dim_n,
   c10::cuda::CUDAGuard device_guard(input.device());
 
   // Create tensor options
-  const auto options_fp8 = torch::TensorOptions()
-                               .dtype(torch::kFloat8_e4m3fn)
+  const auto options_fp8 = at::TensorOptions()
+                               .dtype(at::kFloat8_e4m3fn)
                                .device(input.device());
 
-  const auto options_scale = torch::TensorOptions()
-                                 .dtype(torch::kFloat8_e8m0fnu)
+  const auto options_scale = at::TensorOptions()
+                                 .dtype(at::kFloat8_e8m0fnu)
                                  .device(input.device());
 
   // Create output tensor with column major layout (required for downstream ops)
-  torch::Tensor output_colwise = torch::empty_strided(
+  at::Tensor output_colwise = at::empty_strided(
       {E, N, K}, {N * K, 1, N}, options_fp8);
 
   // Create scales tensor with shape (E, num_n_blocks, K)
   const int64_t num_n_blocks = (N + scale_dim_n - 1) / scale_dim_n;
-  torch::Tensor scales_colwise = torch::empty({E, num_n_blocks, K}, options_scale);
+  at::Tensor scales_colwise = at::empty({E, num_n_blocks, K}, options_scale);
 
   // Call CUDA kernel
   mxfp8_quantize_3d_cuda(input, output_colwise, scales_colwise,
@@ -179,17 +180,8 @@ mxfp8_quantize_3d(torch::Tensor input, int64_t scale_dim_n,
 
 } // namespace mxfp8
 
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.doc() = "MXFP8 Quantization PyTorch Extension";
-
-  m.def("quantize", &mxfp8::mxfp8_quantize, "MXFP8 quantization",
-        py::arg("input"), py::arg("rowwise") = true, py::arg("colwise") = false,
-        py::arg("scale_dim_x") = 32, py::arg("scale_dim_y") = 32,
-        py::arg("fp8_format") = "e4m3",
-        py::arg("scaling_mode") = "floor");
-
-  m.def("quantize_3d", &mxfp8::mxfp8_quantize_3d, "MXFP8 3D quantization",
-        py::arg("input"), py::arg("scale_dim_n") = 32,
-        py::arg("fp8_format") = "e4m3",
-        py::arg("scaling_mode") = "floor");
-}
+// Register CUDA implementations
+TORCH_LIBRARY_IMPL(torchao, CUDA, m) {
+  m.impl("mxfp8_quantize", &mxfp8::mxfp8_quantize);
+  m.impl("mxfp8_quantize_3d", &mxfp8::mxfp8_quantize_3d);
+}
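Note that `TORCH_LIBRARY_IMPL` only binds these functions as the CUDA backend for ops whose schemas are declared elsewhere, and that declaration is not part of this diff. Purely to illustrate the contract the CUDA impls fulfill, a hedged Python-side sketch of equivalent schema declarations, inferred from the C++ signatures above:

```python
# Assumed schemas for illustration only; torchao's real declarations likely
# live in a C++ TORCH_LIBRARY/TORCH_LIBRARY_FRAGMENT block elsewhere, and
# re-defining an existing op at runtime would raise.
import torch

torch.library.define(
    "torchao::mxfp8_quantize",
    "(Tensor input, bool rowwise, bool colwise, int scale_dim_x, "
    "int scale_dim_y, str fp8_format, str scaling_mode) "
    "-> (Tensor, Tensor, Tensor, Tensor)",
)
torch.library.define(
    "torchao::mxfp8_quantize_3d",
    "(Tensor input, int scale_dim_n, str fp8_format, str scaling_mode) "
    "-> (Tensor, Tensor)",
)
```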