From d32b7ab9ed3de67ae94b2d83a9f7af720829fe2d Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Fri, 12 Dec 2025 04:17:16 -0800 Subject: [PATCH 01/10] avoiding shape copy, torch dynamo and torch autograd overheads Signed-off-by: Varun Thumbe --- transformer_engine/pytorch/csrc/common.cpp | 62 +++++++- transformer_engine/pytorch/csrc/common.h | 19 ++- .../pytorch/csrc/extensions/bias.cpp | 9 +- .../pytorch/csrc/extensions/gemm.cpp | 8 +- .../pytorch/csrc/extensions/transpose.cpp | 6 +- transformer_engine/pytorch/csrc/quantizer.cpp | 34 +++-- .../pytorch/csrc/type_converters.cpp | 4 +- transformer_engine/pytorch/module/linear.py | 138 ++++++++---------- 8 files changed, 173 insertions(+), 107 deletions(-) diff --git a/transformer_engine/pytorch/csrc/common.cpp b/transformer_engine/pytorch/csrc/common.cpp index e054424dd4d..f7a8540197f 100644 --- a/transformer_engine/pytorch/csrc/common.cpp +++ b/transformer_engine/pytorch/csrc/common.cpp @@ -26,12 +26,8 @@ std::vector convert_shape_back_from_fp4(const std::vector& shape return ret; } -std::vector getTensorShape(const at::Tensor& t) { - std::vector shape; - for (auto s : t.sizes()) { - shape.push_back(s); - } - return shape; +NVTEShape getTensorShape(const at::Tensor& t) { + return convertTorchShape(t.sizes()); } NVTEShape convertTorchShape(const c10::IntArrayRef torch_shape) { @@ -178,6 +174,38 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor( return ret; } +transformer_engine::TensorWrapper makeTransformerEngineTensor( + void* data_ptr, const NVTEShape& shape, const transformer_engine::DType type, + void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, const NVTEShape& scale_inv_shape, + NVTEScalingMode scaling_mode) { + TensorWrapper ret(scaling_mode); + ret.set_rowwise_data(data_ptr, type, shape); + const size_t meta_shape_data[1] = {1}; + const NVTEShape meta_shape = nvte_make_shape(meta_shape_data, 1); + ret.set_amax(amax_ptr, DType::kFloat32, meta_shape); + ret.set_scale(scale_ptr, DType::kFloat32, meta_shape); + auto scale_inv_dtype = + (scaling_mode == NVTE_MXFP8_1D_SCALING) ? DType::kFloat8E8M0 : DType::kFloat32; + ret.set_rowwise_scale_inv(scale_inv_ptr, scale_inv_dtype, scale_inv_shape); + return ret; +} + +transformer_engine::TensorWrapper makeTransformerEngineTensor( + void* data_ptr, const std::vector& shape, const transformer_engine::DType type, + void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, const NVTEShape& scale_inv_shape, + NVTEScalingMode scaling_mode) { + TensorWrapper ret(scaling_mode); + ret.set_rowwise_data(data_ptr, type, shape); + const size_t meta_shape_data[1] = {1}; + const NVTEShape meta_shape = nvte_make_shape(meta_shape_data, 1); + ret.set_amax(amax_ptr, DType::kFloat32, meta_shape); + ret.set_scale(scale_ptr, DType::kFloat32, meta_shape); + auto scale_inv_dtype = + (scaling_mode == NVTE_MXFP8_1D_SCALING) ? 
DType::kFloat8E8M0 : DType::kFloat32; + ret.set_rowwise_scale_inv(scale_inv_ptr, scale_inv_dtype, scale_inv_shape); + return ret; +} + transformer_engine::TensorWrapper makeTransformerEngineTensor( void* data_ptr, void* columnwise_data_ptr, const std::vector& shape, const std::vector& columnwise_shape, const transformer_engine::DType type, @@ -199,6 +227,28 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor( return ret; } +transformer_engine::TensorWrapper makeTransformerEngineTensor( + void* data_ptr, void* columnwise_data_ptr, const NVTEShape& shape, + const NVTEShape& columnwise_shape, const transformer_engine::DType type, + void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, void* columnwise_scale_inv_ptr, + const NVTEShape& scale_inv_shape, const NVTEShape& columnwise_scale_inv_shape, + NVTEScalingMode scaling_mode) { + TensorWrapper ret(scaling_mode); + ret.set_rowwise_data(data_ptr, type, shape); + ret.set_columnwise_data(columnwise_data_ptr, type, columnwise_shape); + const size_t meta_shape_data[1] = {1}; + const NVTEShape meta_shape = nvte_make_shape(meta_shape_data, 1); + ret.set_amax(amax_ptr, DType::kFloat32, meta_shape); + ret.set_scale(scale_ptr, DType::kFloat32, meta_shape); + auto scale_inv_dtype = (scaling_mode == NVTE_MXFP8_1D_SCALING) ? DType::kFloat8E8M0 + : (scaling_mode == NVTE_NVFP4_1D_SCALING) ? DType::kFloat8E4M3 + : DType::kFloat32; + ret.set_rowwise_scale_inv(scale_inv_ptr, scale_inv_dtype, scale_inv_shape); + ret.set_columnwise_scale_inv(columnwise_scale_inv_ptr, scale_inv_dtype, + columnwise_scale_inv_shape); + return ret; +} + transformer_engine::TensorWrapper makeTransformerEngineTensor(at::Tensor tensor, at::Tensor amax, const at::Tensor scale, at::Tensor scale_inv, diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 978bee52dc1..883c2a24cad 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -339,7 +339,7 @@ class NVFP4Quantizer : public Quantizer { std::unique_ptr convert_quantizer(py::handle quantizer); -std::vector getTensorShape(const at::Tensor& t); +NVTEShape getTensorShape(const at::Tensor& t); transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid, const std::string& fp8_recipe); @@ -432,6 +432,16 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor( void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, std::vector scale_inv_shape = {1}, NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING); +transformer_engine::TensorWrapper makeTransformerEngineTensor( + void* data_ptr, const NVTEShape& shape, const transformer_engine::DType type, + void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, const NVTEShape& scale_inv_shape, + NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING); + +transformer_engine::TensorWrapper makeTransformerEngineTensor( + void* data_ptr, const std::vector& shape, const transformer_engine::DType type, + void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, const NVTEShape& scale_inv_shape, + NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING); + transformer_engine::TensorWrapper makeTransformerEngineTensor( void* data_ptr, void* columnwise_data_ptr, const std::vector& shape, const std::vector& columnwise_shape, const transformer_engine::DType type, @@ -440,6 +450,13 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor( const std::vector& columnwise_scale_inv_shape = {1}, NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING); 
+transformer_engine::TensorWrapper makeTransformerEngineTensor( + void* data_ptr, void* columnwise_data_ptr, const NVTEShape& shape, + const NVTEShape& columnwise_shape, const transformer_engine::DType type, + void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, void* columnwise_scale_inv_ptr, + const NVTEShape& scale_inv_shape, const NVTEShape& columnwise_scale_inv_shape, + NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING); + transformer_engine::TensorWrapper makeTransformerEngineTensor(void* data_ptr, const NVTEShape& shape, const transformer_engine::DType type); diff --git a/transformer_engine/pytorch/csrc/extensions/bias.cpp b/transformer_engine/pytorch/csrc/extensions/bias.cpp index b0435d27230..2eef7438068 100644 --- a/transformer_engine/pytorch/csrc/extensions/bias.cpp +++ b/transformer_engine/pytorch/csrc/extensions/bias.cpp @@ -26,7 +26,8 @@ std::vector bgrad_quantize(const at::Tensor &grad_output, py::handle // Grad output tensor auto grad_output_torch = grad_output.contiguous(); const TensorWrapper &grad_output_nvte = makeTransformerEngineTensor(grad_output_torch); - const auto shape = getTensorShape(grad_output_torch); + const auto shape_nvte = getTensorShape(grad_output_torch); + const auto shape = convertShape(shape_nvte); auto grad_output_dtype = GetTransformerEngineDType(grad_output_torch.scalar_type()); // Construct grad bias tensor @@ -116,11 +117,13 @@ std::vector dact_dbias( // Grad output and activation input tensors grad_output_torch = grad_output_torch.contiguous(); const TensorWrapper &grad_output_nvte = makeTransformerEngineTensor(grad_output_torch); - const auto output_shape = getTensorShape(grad_output_torch); + const auto output_shape_nvte = getTensorShape(grad_output_torch); + const auto output_shape = convertShape(output_shape_nvte); auto grad_output_dtype = GetTransformerEngineDType(grad_output_torch.scalar_type()); act_input_torch = act_input_torch.contiguous(); const TensorWrapper &act_input_nvte = makeTransformerEngineTensor(act_input_torch); - const auto input_shape = getTensorShape(act_input_torch); + const auto input_shape_nvte = getTensorShape(act_input_torch); + const auto input_shape = convertShape(input_shape_nvte); // Construct tensors auto quantizer_cpp = convert_quantizer(quantizer_py); diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp index 13e8bfb6e5f..f704864cb60 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp @@ -365,12 +365,16 @@ void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, DType A_type, NVTEScalingMode nvte_scaling_modeA = NVTE_DELAYED_TENSOR_SCALING; NVTEScalingMode nvte_scaling_modeB = NVTE_DELAYED_TENSOR_SCALING; + const size_t A_shape_data[2] = {static_cast(A.size(0)), static_cast(A.size(1))}; + const NVTEShape A_shape = nvte_make_shape(A_shape_data, 2); auto te_A = makeTransformerEngineTensor( - A.data_ptr(), {static_cast(A.size(0)), static_cast(A.size(1))}, A_type, + A.data_ptr(), A_shape, A_type, nullptr, nullptr, A_scale_inverse.data_ptr(), getTensorShape(A_scale_inverse), nvte_scaling_modeA); + const size_t B_shape_data[2] = {static_cast(B.size(0)), static_cast(B.size(1))}; + const NVTEShape B_shape = nvte_make_shape(B_shape_data, 2); auto te_B = makeTransformerEngineTensor( - B.data_ptr(), {static_cast(B.size(0)), static_cast(B.size(1))}, B_type, + B.data_ptr(), B_shape, B_type, nullptr, nullptr, B_scale_inverse.data_ptr(), getTensorShape(B_scale_inverse), 
nvte_scaling_modeB); // TODO: D_scale_inv cannot be nullptr when D_type is FP8. diff --git a/transformer_engine/pytorch/csrc/extensions/transpose.cpp b/transformer_engine/pytorch/csrc/extensions/transpose.cpp index 7dfdf995475..5ace996afcc 100644 --- a/transformer_engine/pytorch/csrc/extensions/transpose.cpp +++ b/transformer_engine/pytorch/csrc/extensions/transpose.cpp @@ -19,7 +19,8 @@ at::Tensor fp8_transpose(at::Tensor input, DType otype, std::optional transpose_shape_int64; if (shape.size() > 0) { transpose_shape_int64.push_back(shape.back()); @@ -60,7 +61,8 @@ at::Tensor swap_first_dims(at::Tensor tensor, std::optional out) { // Allocate output tensor if needed if (!out) { - auto in_shape = getTensorShape(input); + const auto in_shape_nvte = getTensorShape(input); + const auto in_shape = convertShape(in_shape_nvte); NVTE_CHECK(in_shape.size() >= 2, "Invalid input tensor dimensions (shape=", in_shape, ")"); std::vector out_shape_int64(in_shape.begin(), in_shape.end()); out_shape_int64[0] = static_cast(in_shape[1]); diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index d7e8912ac74..3b94d38ac16 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -209,7 +209,8 @@ std::pair Float8Quantizer::convert_and_update_tensor( // Tensor dimensions std::vector shape; if (has_transpose) { - const auto transpose_shape = getTensorShape(*transpose_tensor); + const auto transpose_shape_nvte = getTensorShape(*transpose_tensor); + const auto transpose_shape = convertShape(transpose_shape_nvte); if (transpose_shape.size() > 0) { for (size_t i = 1; i < transpose_shape.size(); ++i) { shape.push_back(transpose_shape[i]); @@ -217,12 +218,13 @@ std::pair Float8Quantizer::convert_and_update_tensor( shape.push_back(transpose_shape.front()); } if (has_data) { - auto expected_shape = getTensorShape(*data_tensor); + const auto expected_shape_nvte = getTensorShape(*data_tensor); + const auto expected_shape = convertShape(expected_shape_nvte); NVTE_CHECK(shape == expected_shape, "FP8 data (shape=", expected_shape, ") and transpose (shape=", transpose_shape, ") do not match"); } } else { // Already checked has_data == true - shape = getTensorShape(*data_tensor); + shape = convertShape(getTensorShape(*data_tensor)); } // Coerce data tensor @@ -430,7 +432,8 @@ std::pair Float8CurrentScalingQuantizer::convert_and_ // Tensor dimensions std::vector shape; if (has_transpose) { - const auto transpose_shape = getTensorShape(*transpose_tensor); + const auto transpose_shape_nvte = getTensorShape(*transpose_tensor); + const auto transpose_shape = convertShape(transpose_shape_nvte); if (transpose_shape.size() > 0) { for (size_t i = 1; i < transpose_shape.size(); ++i) { shape.push_back(transpose_shape[i]); @@ -438,12 +441,13 @@ std::pair Float8CurrentScalingQuantizer::convert_and_ shape.push_back(transpose_shape.front()); } if (has_data) { - auto expected_shape = getTensorShape(*data_tensor); + const auto expected_shape_nvte = getTensorShape(*data_tensor); + const auto expected_shape = convertShape(expected_shape_nvte); NVTE_CHECK(shape == expected_shape, "FP8 data (shape=", expected_shape, ") and transpose (shape=", transpose_shape, ") do not match"); } } else { // Already checked has_data == true - shape = getTensorShape(*data_tensor); + shape = convertShape(getTensorShape(*data_tensor)); } // Coerce data tensor in Python tensor @@ -680,9 +684,9 @@ std::pair Float8BlockQuantizer::convert_and_update_te return 
std::vector(); } if (all_gather_usage) { - return getTensorShape(*columnwise_data); + return convertShape(getTensorShape(*columnwise_data)); } - std::vector shape = getTensorShape(*columnwise_data); + std::vector shape = convertShape(getTensorShape(*columnwise_data)); std::vector shape_transposed(shape.size()); for (size_t i = 0; i + 1 < shape.size(); ++i) { shape_transposed[i] = shape[i + 1]; @@ -694,7 +698,7 @@ std::pair Float8BlockQuantizer::convert_and_update_te }; std::vector shape; if (rowwise_data) { - shape = getTensorShape(*rowwise_data); + shape = convertShape(getTensorShape(*rowwise_data)); if (columnwise_data) { auto expected_shape = get_columnwise_shape(all_gather_usage); NVTE_CHECK(shape == expected_shape, "BlockwiseFP8 row-wise data (shape=", shape, @@ -1004,14 +1008,14 @@ std::pair MXFP8Quantizer::convert_and_update_tensor( // Tensor dimensions std::vector shape; if (columnwise_data) { - shape = getTensorShape(*columnwise_data); + shape = convertShape(getTensorShape(*columnwise_data)); if (rowwise_data) { - auto expected_shape = getTensorShape(*rowwise_data); + const auto expected_shape = convertShape(getTensorShape(*rowwise_data)); NVTE_CHECK(shape == expected_shape, "MXFP8 row-wise data (shape=", expected_shape, ") and column-wise data (shape=", shape, ") do not match"); } } else { // Already checked columnwise_data_tensor == true - shape = getTensorShape(*rowwise_data); + shape = convertShape(getTensorShape(*rowwise_data)); } // Coerce row-wise data @@ -1320,14 +1324,14 @@ std::pair NVFP4Quantizer::convert_and_update_tensor( // Tensor dimensions, shape means original shape std::vector shape; if (columnwise_data) { - shape = convert_shape_back_from_fp4(getTensorShape(*columnwise_data), true); + shape = convert_shape_back_from_fp4(convertShape(getTensorShape(*columnwise_data)), true); if (rowwise_data) { - auto expected_shape = convert_shape_back_from_fp4(getTensorShape(*rowwise_data), false); + auto expected_shape = convert_shape_back_from_fp4(convertShape(getTensorShape(*rowwise_data)), false); NVTE_CHECK(shape == expected_shape, "NVFP4 row-wise data (shape=", expected_shape, ") and column-wise data (shape=", shape, ") do not match"); } } else { // Already checked columnwise_data_tensor == true - shape = convert_shape_back_from_fp4(getTensorShape(*rowwise_data), false); + shape = convert_shape_back_from_fp4(convertShape(getTensorShape(*rowwise_data)), false); } size_t flat_first_dim = 1; diff --git a/transformer_engine/pytorch/csrc/type_converters.cpp b/transformer_engine/pytorch/csrc/type_converters.cpp index 368e9dcdfa3..780a08da7f8 100644 --- a/transformer_engine/pytorch/csrc/type_converters.cpp +++ b/transformer_engine/pytorch/csrc/type_converters.cpp @@ -132,7 +132,7 @@ TensorWrapper NVTETensorFromNVFP4Tensor(py::handle tensor, Quantizer *quantizer) const auto &scale_inv = tensor.attr("_rowwise_scale_inv").cast(); const auto &amax_rowwise = tensor.attr("_amax_rowwise").cast(); ret.set_rowwise_data(data.data_ptr(), dtype, - convert_shape_back_from_fp4(getTensorShape(data), false)); + convert_shape_back_from_fp4(convertShape(getTensorShape(data)), false)); ret.set_rowwise_scale_inv(scale_inv.data_ptr(), DType::kFloat8E4M3, getTensorShape(scale_inv)); ret.set_amax(amax_rowwise.data_ptr(), DType::kFloat32, getTensorShape(amax_rowwise)); } @@ -143,7 +143,7 @@ TensorWrapper NVTETensorFromNVFP4Tensor(py::handle tensor, Quantizer *quantizer) const auto &scale_inv = tensor.attr("_columnwise_scale_inv").cast(); const auto &amax_columnwise = 
tensor.attr("_amax_columnwise").cast(); ret.set_columnwise_data(data.data_ptr(), DType::kFloat4E2M1, - convert_shape_back_from_fp4(getTensorShape(data), false)); + convert_shape_back_from_fp4(convertShape(getTensorShape(data)), false)); ret.set_columnwise_scale_inv(scale_inv.data_ptr(), DType::kFloat8E4M3, getTensorShape(scale_inv)); ret.set_columnwise_amax(amax_columnwise.data_ptr(), DType::kFloat32, diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index b65f7005eb3..7557f5c5396 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -96,40 +96,66 @@ def forward( ( is_first_microbatch, - fp8, - fp8_calibration, - wgrad_store, - input_quantizer, - weight_quantizer, - output_quantizer, - grad_input_quantizer, - grad_weight_quantizer, - grad_output_quantizer, - fuse_wgrad_accumulation, cpu_offloading, - tp_group, - tp_size, - sequence_parallel, - tensor_parallel, - activation_dtype, - parallel_mode, is_grad_enabled, - ub_overlap_rs_fprop, - ub_overlap_ag_dgrad, - ub_overlap_ag_fprop, - ub_overlap_rs_dgrad, - ub_bulk_dgrad, - ub_bulk_wgrad, - ub_name, - fp8_output, # pylint: disable=unused-variable - fsdp_group, + fp8_output, + fp8_grad, module, skip_fp8_weight_update, - symmetric_ar_type, - save_original_input, debug, ) = non_tensor_args + (fp8, + fp8_calibration, + wgrad_store, + fuse_wgrad_accumulation, + tp_group, + tp_size, + sequence_parallel, + tensor_parallel, + activation_dtype, + parallel_mode, + ub_overlap_rs_fprop, + ub_overlap_ag_dgrad, + ub_overlap_ag_fprop, + ub_overlap_rs_dgrad, + ub_bulk_dgrad, + ub_bulk_wgrad, + ub_name, + fsdp_group, + symmetric_ar_type, + save_original_input + ) = (module.fp8, + module.fp8_calibration, + module.wgrad_store, + module.fuse_wgrad_accumulation, + module.tp_group, + module.tp_size, + module.sequence_parallel, + module.tp_size > 1, + module.activation_dtype, + module.parallel_mode, + module.ub_overlap_rs_fprop, + module.ub_overlap_ag_dgrad, + module.ub_overlap_ag_fprop, + module.ub_overlap_rs_dgrad, + module.ub_bulk_dgrad, + module.ub_bulk_wgrad, + module.ub_name, + module.fsdp_group, + module.symmetric_ar_type, + module.save_original_input, + ) + quantizers = module._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) + + if debug: + quantizers = module._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) + if module.no_debug_features_active(quantizers): + debug = False + quantizers = module._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) + (input_quantizer, weight_quantizer, output_quantizer, grad_input_quantizer, grad_weight_quantizer, grad_output_quantizer) = quantizers + + # NVTX label for profiling nvtx_label = "transformer_engine._Linear.forward" if ub_name is not None: @@ -981,7 +1007,6 @@ def wgrad_gemm( None, ) - class Linear(TransformerEngineBaseModule): """Applies a linear transformation to the incoming data :math:`y = xA^T + b` @@ -1343,7 +1368,6 @@ def reset_parameters(self, defer_init=False): elif self.parallel_mode == "column": set_tensor_model_parallel_attributes(getattr(self, bias), True, 0, 1) - @no_torch_dynamo() def forward( self, inp: torch.Tensor, @@ -1401,28 +1425,7 @@ def forward( inp, allow_non_contiguous=isinstance(inp, QuantizedTensor), ) as inp: - weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() - - quantizers = ( - self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) - if not debug - else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled) - ) - if debug: - if 
self.no_debug_features_active(quantizers): - debug = False - quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) - - ( - input_quantizer, - weight_quantizer, - output_quantizer, - grad_input_quantizer, - grad_weight_quantizer, - grad_output_quantizer, - ) = quantizers - if is_grad_enabled: linear_fn = _Linear.apply autograd_ctx = [] @@ -1432,37 +1435,12 @@ def forward( non_tensor_args = ( is_first_microbatch, - self.fp8, - self.fp8_calibration, - self.wgrad_store, - input_quantizer, - weight_quantizer, - output_quantizer, - grad_input_quantizer, - grad_weight_quantizer, - grad_output_quantizer, - self.fuse_wgrad_accumulation, is_cpu_offload_enabled(), - self.tp_group, - self.tp_size, - self.sequence_parallel, - self.tp_size > 1, - self.activation_dtype, - self.parallel_mode, is_grad_enabled, - self.ub_overlap_rs_fprop, - self.ub_overlap_ag_dgrad, - self.ub_overlap_ag_fprop, - self.ub_overlap_rs_dgrad, - self.ub_bulk_dgrad, - self.ub_bulk_wgrad, - self.ub_name, fp8_output, - self.fsdp_group, + fp8_grad, self, skip_fp8_weight_update, - self.symmetric_ar_type, - self.save_original_input, debug, ) out = linear_fn( @@ -1687,3 +1665,11 @@ def _customize_quantizers_float8_blockwise_scaling(self, fwd: bool, recipe: Reci self.quantizers["scaling_bwd"][ tex.FP8BwdTensors.GRAD_OUTPUT1 ].all_gather_usage = True + + +# disable torch dynamo just once to reduce wrapped function overhead on each +# forward call of te Linear. +if torch.__version__ >= "2": + Linear.forward._torchdynamo_disable = True + Linear.forward._torchdynamo_disable_msg = None + From e7248151ecf7877e3217b6bdd1fcf3e4b59d28ae Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Fri, 12 Dec 2025 22:47:01 +0000 Subject: [PATCH 02/10] minor additional change Signed-off-by: Varun Thumbe --- transformer_engine/pytorch/csrc/common.cpp | 2 +- transformer_engine/pytorch/csrc/common.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/csrc/common.cpp b/transformer_engine/pytorch/csrc/common.cpp index f7a8540197f..3467223d2ac 100644 --- a/transformer_engine/pytorch/csrc/common.cpp +++ b/transformer_engine/pytorch/csrc/common.cpp @@ -30,7 +30,7 @@ NVTEShape getTensorShape(const at::Tensor& t) { return convertTorchShape(t.sizes()); } -NVTEShape convertTorchShape(const c10::IntArrayRef torch_shape) { +NVTEShape convertTorchShape(const c10::IntArrayRef& torch_shape) { NVTEShape ret; ret.ndim = torch_shape.size(); constexpr int max_dimensions = sizeof(ret.data) / sizeof(size_t); diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 883c2a24cad..22061de4773 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -496,7 +496,7 @@ std::vector convertShape(const NVTEShape& shape); size_t roundup(const size_t value, const size_t multiple); -NVTEShape convertTorchShape(const c10::IntArrayRef torch_shape); +NVTEShape convertTorchShape(const c10::IntArrayRef& torch_shape); std::vector convert_shape_back_from_fp4(const std::vector& shape, bool transpose); From b725f5b31d52adc800514898cf34f8c65851e6be Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 12 Dec 2025 23:02:13 +0000 Subject: [PATCH 03/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/csrc/common.cpp | 12 +-- transformer_engine/pytorch/csrc/common.h | 8 +- .../pytorch/csrc/extensions/gemm.cpp | 
14 ++- transformer_engine/pytorch/csrc/quantizer.cpp | 3 +- transformer_engine/pytorch/module/linear.py | 94 ++++++++++--------- 5 files changed, 68 insertions(+), 63 deletions(-) diff --git a/transformer_engine/pytorch/csrc/common.cpp b/transformer_engine/pytorch/csrc/common.cpp index 3467223d2ac..c7f0975216b 100644 --- a/transformer_engine/pytorch/csrc/common.cpp +++ b/transformer_engine/pytorch/csrc/common.cpp @@ -26,9 +26,7 @@ std::vector convert_shape_back_from_fp4(const std::vector& shape return ret; } -NVTEShape getTensorShape(const at::Tensor& t) { - return convertTorchShape(t.sizes()); -} +NVTEShape getTensorShape(const at::Tensor& t) { return convertTorchShape(t.sizes()); } NVTEShape convertTorchShape(const c10::IntArrayRef& torch_shape) { NVTEShape ret; @@ -175,8 +173,8 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor( } transformer_engine::TensorWrapper makeTransformerEngineTensor( - void* data_ptr, const NVTEShape& shape, const transformer_engine::DType type, - void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, const NVTEShape& scale_inv_shape, + void* data_ptr, const NVTEShape& shape, const transformer_engine::DType type, void* amax_ptr, + void* scale_ptr, void* scale_inv_ptr, const NVTEShape& scale_inv_shape, NVTEScalingMode scaling_mode) { TensorWrapper ret(scaling_mode); ret.set_rowwise_data(data_ptr, type, shape); @@ -229,8 +227,8 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor( transformer_engine::TensorWrapper makeTransformerEngineTensor( void* data_ptr, void* columnwise_data_ptr, const NVTEShape& shape, - const NVTEShape& columnwise_shape, const transformer_engine::DType type, - void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, void* columnwise_scale_inv_ptr, + const NVTEShape& columnwise_shape, const transformer_engine::DType type, void* amax_ptr, + void* scale_ptr, void* scale_inv_ptr, void* columnwise_scale_inv_ptr, const NVTEShape& scale_inv_shape, const NVTEShape& columnwise_scale_inv_shape, NVTEScalingMode scaling_mode) { TensorWrapper ret(scaling_mode); diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 22061de4773..e6c22880323 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -433,8 +433,8 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor( NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING); transformer_engine::TensorWrapper makeTransformerEngineTensor( - void* data_ptr, const NVTEShape& shape, const transformer_engine::DType type, - void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, const NVTEShape& scale_inv_shape, + void* data_ptr, const NVTEShape& shape, const transformer_engine::DType type, void* amax_ptr, + void* scale_ptr, void* scale_inv_ptr, const NVTEShape& scale_inv_shape, NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING); transformer_engine::TensorWrapper makeTransformerEngineTensor( @@ -452,8 +452,8 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor( transformer_engine::TensorWrapper makeTransformerEngineTensor( void* data_ptr, void* columnwise_data_ptr, const NVTEShape& shape, - const NVTEShape& columnwise_shape, const transformer_engine::DType type, - void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, void* columnwise_scale_inv_ptr, + const NVTEShape& columnwise_shape, const transformer_engine::DType type, void* amax_ptr, + void* scale_ptr, void* scale_inv_ptr, void* columnwise_scale_inv_ptr, const NVTEShape& scale_inv_shape, const NVTEShape& 
columnwise_scale_inv_shape, NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING); diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp index f704864cb60..35b523b5192 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp @@ -367,16 +367,14 @@ void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, DType A_type, const size_t A_shape_data[2] = {static_cast(A.size(0)), static_cast(A.size(1))}; const NVTEShape A_shape = nvte_make_shape(A_shape_data, 2); - auto te_A = makeTransformerEngineTensor( - A.data_ptr(), A_shape, A_type, - nullptr, nullptr, A_scale_inverse.data_ptr(), getTensorShape(A_scale_inverse), - nvte_scaling_modeA); + auto te_A = makeTransformerEngineTensor(A.data_ptr(), A_shape, A_type, nullptr, nullptr, + A_scale_inverse.data_ptr(), + getTensorShape(A_scale_inverse), nvte_scaling_modeA); const size_t B_shape_data[2] = {static_cast(B.size(0)), static_cast(B.size(1))}; const NVTEShape B_shape = nvte_make_shape(B_shape_data, 2); - auto te_B = makeTransformerEngineTensor( - B.data_ptr(), B_shape, B_type, - nullptr, nullptr, B_scale_inverse.data_ptr(), getTensorShape(B_scale_inverse), - nvte_scaling_modeB); + auto te_B = makeTransformerEngineTensor(B.data_ptr(), B_shape, B_type, nullptr, nullptr, + B_scale_inverse.data_ptr(), + getTensorShape(B_scale_inverse), nvte_scaling_modeB); // TODO: D_scale_inv cannot be nullptr when D_type is FP8. auto te_D = makeTransformerEngineTensor( D.data_ptr(), diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 3b94d38ac16..aa8416121d0 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -1326,7 +1326,8 @@ std::pair NVFP4Quantizer::convert_and_update_tensor( if (columnwise_data) { shape = convert_shape_back_from_fp4(convertShape(getTensorShape(*columnwise_data)), true); if (rowwise_data) { - auto expected_shape = convert_shape_back_from_fp4(convertShape(getTensorShape(*rowwise_data)), false); + auto expected_shape = + convert_shape_back_from_fp4(convertShape(getTensorShape(*rowwise_data)), false); NVTE_CHECK(shape == expected_shape, "NVFP4 row-wise data (shape=", expected_shape, ") and column-wise data (shape=", shape, ") do not match"); } diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 7557f5c5396..965367ac31b 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -105,46 +105,48 @@ def forward( debug, ) = non_tensor_args - (fp8, - fp8_calibration, - wgrad_store, - fuse_wgrad_accumulation, - tp_group, - tp_size, - sequence_parallel, - tensor_parallel, - activation_dtype, - parallel_mode, - ub_overlap_rs_fprop, - ub_overlap_ag_dgrad, - ub_overlap_ag_fprop, - ub_overlap_rs_dgrad, - ub_bulk_dgrad, - ub_bulk_wgrad, - ub_name, - fsdp_group, - symmetric_ar_type, - save_original_input - ) = (module.fp8, - module.fp8_calibration, - module.wgrad_store, - module.fuse_wgrad_accumulation, - module.tp_group, - module.tp_size, - module.sequence_parallel, - module.tp_size > 1, - module.activation_dtype, - module.parallel_mode, - module.ub_overlap_rs_fprop, - module.ub_overlap_ag_dgrad, - module.ub_overlap_ag_fprop, - module.ub_overlap_rs_dgrad, - module.ub_bulk_dgrad, - module.ub_bulk_wgrad, - module.ub_name, - module.fsdp_group, - module.symmetric_ar_type, - module.save_original_input, 
+ ( + fp8, + fp8_calibration, + wgrad_store, + fuse_wgrad_accumulation, + tp_group, + tp_size, + sequence_parallel, + tensor_parallel, + activation_dtype, + parallel_mode, + ub_overlap_rs_fprop, + ub_overlap_ag_dgrad, + ub_overlap_ag_fprop, + ub_overlap_rs_dgrad, + ub_bulk_dgrad, + ub_bulk_wgrad, + ub_name, + fsdp_group, + symmetric_ar_type, + save_original_input, + ) = ( + module.fp8, + module.fp8_calibration, + module.wgrad_store, + module.fuse_wgrad_accumulation, + module.tp_group, + module.tp_size, + module.sequence_parallel, + module.tp_size > 1, + module.activation_dtype, + module.parallel_mode, + module.ub_overlap_rs_fprop, + module.ub_overlap_ag_dgrad, + module.ub_overlap_ag_fprop, + module.ub_overlap_rs_dgrad, + module.ub_bulk_dgrad, + module.ub_bulk_wgrad, + module.ub_name, + module.fsdp_group, + module.symmetric_ar_type, + module.save_original_input, ) quantizers = module._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) @@ -153,8 +155,14 @@ def forward( if module.no_debug_features_active(quantizers): debug = False quantizers = module._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) - (input_quantizer, weight_quantizer, output_quantizer, grad_input_quantizer, grad_weight_quantizer, grad_output_quantizer) = quantizers - + ( + input_quantizer, + weight_quantizer, + output_quantizer, + grad_input_quantizer, + grad_weight_quantizer, + grad_output_quantizer, + ) = quantizers # NVTX label for profiling nvtx_label = "transformer_engine._Linear.forward" @@ -1007,6 +1015,7 @@ def wgrad_gemm( None, ) + class Linear(TransformerEngineBaseModule): """Applies a linear transformation to the incoming data :math:`y = xA^T + b` @@ -1672,4 +1681,3 @@ def _customize_quantizers_float8_blockwise_scaling(self, fwd: bool, recipe: Reci if torch.__version__ >= "2": Linear.forward._torchdynamo_disable = True Linear.forward._torchdynamo_disable_msg = None - From 7b031d011331324b5f872f9e4038a9ace4a9f86c Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Tue, 23 Dec 2025 12:39:17 +0000 Subject: [PATCH 04/10] changes done to remove the additional nvte_make_shape calls Signed-off-by: Varun Thumbe --- transformer_engine/pytorch/csrc/common.cpp | 31 ++++-- transformer_engine/pytorch/csrc/common.h | 9 ++ .../pytorch/csrc/extensions/attention.cpp | 4 +- .../pytorch/csrc/extensions/bias.cpp | 9 +- .../pytorch/csrc/extensions/cast.cpp | 8 +- .../pytorch/csrc/extensions/gemm.cpp | 105 ++++++++++++------ .../pytorch/csrc/extensions/padding.cpp | 2 +- transformer_engine/pytorch/csrc/quantizer.cpp | 54 +++++---- .../pytorch/csrc/type_converters.cpp | 4 +- transformer_engine/pytorch/csrc/util.cpp | 22 ++-- 10 files changed, 155 insertions(+), 93 deletions(-) diff --git a/transformer_engine/pytorch/csrc/common.cpp b/transformer_engine/pytorch/csrc/common.cpp index c7f0975216b..b6a3853f6fd 100644 --- a/transformer_engine/pytorch/csrc/common.cpp +++ b/transformer_engine/pytorch/csrc/common.cpp @@ -26,7 +26,17 @@ std::vector convert_shape_back_from_fp4(const std::vector& shape return ret; } -NVTEShape getTensorShape(const at::Tensor& t) { return convertTorchShape(t.sizes()); } +NVTEShape getTensorShape(const at::Tensor& t) { + return convertTorchShape(t.sizes()); +} + +std::vector getTensorShapeVector(const at::Tensor& t) { + std::vector shape; + for (auto s : t.sizes()) { + shape.push_back(s); + } + return shape; +} NVTEShape convertTorchShape(const c10::IntArrayRef& torch_shape) { NVTEShape ret; @@ -113,10 +123,7 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor( transformer_engine::TensorWrapper 
makeTransformerEngineTensor(at::Tensor tensor) { transformer_engine::DType dtype = GetTransformerEngineDType(tensor.scalar_type()); - std::vector shape; - for (auto s : tensor.sizes()) { - shape.push_back(s); - } + NVTEShape shape = getTensorShape(tensor); return makeTransformerEngineTensor(tensor.data_ptr(), shape, dtype); } @@ -179,7 +186,9 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor( TensorWrapper ret(scaling_mode); ret.set_rowwise_data(data_ptr, type, shape); const size_t meta_shape_data[1] = {1}; - const NVTEShape meta_shape = nvte_make_shape(meta_shape_data, 1); + NVTEShape meta_shape; + meta_shape.ndim = 1; + meta_shape.data[0] = 1; ret.set_amax(amax_ptr, DType::kFloat32, meta_shape); ret.set_scale(scale_ptr, DType::kFloat32, meta_shape); auto scale_inv_dtype = @@ -194,8 +203,9 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor( NVTEScalingMode scaling_mode) { TensorWrapper ret(scaling_mode); ret.set_rowwise_data(data_ptr, type, shape); - const size_t meta_shape_data[1] = {1}; - const NVTEShape meta_shape = nvte_make_shape(meta_shape_data, 1); + NVTEShape meta_shape; + meta_shape.ndim = 1; + meta_shape.data[0] = 1; ret.set_amax(amax_ptr, DType::kFloat32, meta_shape); ret.set_scale(scale_ptr, DType::kFloat32, meta_shape); auto scale_inv_dtype = @@ -234,8 +244,9 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor( TensorWrapper ret(scaling_mode); ret.set_rowwise_data(data_ptr, type, shape); ret.set_columnwise_data(columnwise_data_ptr, type, columnwise_shape); - const size_t meta_shape_data[1] = {1}; - const NVTEShape meta_shape = nvte_make_shape(meta_shape_data, 1); + NVTEShape meta_shape; + meta_shape.ndim = 1; + meta_shape.data[0] = 1; ret.set_amax(amax_ptr, DType::kFloat32, meta_shape); ret.set_scale(scale_ptr, DType::kFloat32, meta_shape); auto scale_inv_dtype = (scaling_mode == NVTE_MXFP8_1D_SCALING) ? DType::kFloat8E8M0 diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index e6c22880323..a9e7d895192 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -141,6 +141,13 @@ class NoneQuantizer : public Quantizer { std::pair create_tensor(const std::vector& shape, DType dtype, at::Tensor data) const; + std::pair create_tensor(const NVTEShape& shape, + DType dtype) const; + + /*! 
@brief Construct a tensor with pre-initialized data */ + std::pair create_tensor(const NVTEShape& shape, DType dtype, + at::Tensor data) const; + std::pair convert_and_update_tensor(py::object tensor) const override; void quantize(const TensorWrapper& input, TensorWrapper& out, @@ -341,6 +348,8 @@ std::unique_ptr convert_quantizer(py::handle quantizer); NVTEShape getTensorShape(const at::Tensor& t); +std::vector getTensorShapeVector(const at::Tensor& t); + transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid, const std::string& fp8_recipe); diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index 2480d9aba9b..804a4667d71 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -479,9 +479,9 @@ std::vector fused_attn_bwd( std::vector cu_seqlens_kv_padded_shape{cu_seqlens_kv_padded_sizes.begin(), cu_seqlens_kv_padded_sizes.end()}; te_cu_seqlens_q_padded = makeTransformerEngineTensor(cu_seqlens_q_padded.value().data_ptr(), - cu_seqlens_q_padded_shape, DType::kInt32); + nvte_make_shape(cu_seqlens_q_padded_shape.data(), cu_seqlens_q_padded_shape.size()), DType::kInt32); te_cu_seqlens_kv_padded = makeTransformerEngineTensor( - cu_seqlens_kv_padded.value().data_ptr(), cu_seqlens_kv_padded_shape, DType::kInt32); + cu_seqlens_kv_padded.value().data_ptr(), nvte_make_shape(cu_seqlens_kv_padded_shape.data(), cu_seqlens_kv_padded_shape.size()), DType::kInt32); } // convert auxiliary tensors from forward to NVTETensors diff --git a/transformer_engine/pytorch/csrc/extensions/bias.cpp b/transformer_engine/pytorch/csrc/extensions/bias.cpp index 2eef7438068..c3e89ed0856 100644 --- a/transformer_engine/pytorch/csrc/extensions/bias.cpp +++ b/transformer_engine/pytorch/csrc/extensions/bias.cpp @@ -26,8 +26,7 @@ std::vector bgrad_quantize(const at::Tensor &grad_output, py::handle // Grad output tensor auto grad_output_torch = grad_output.contiguous(); const TensorWrapper &grad_output_nvte = makeTransformerEngineTensor(grad_output_torch); - const auto shape_nvte = getTensorShape(grad_output_torch); - const auto shape = convertShape(shape_nvte); + const auto shape = getTensorShapeVector(grad_output_torch); auto grad_output_dtype = GetTransformerEngineDType(grad_output_torch.scalar_type()); // Construct grad bias tensor @@ -117,13 +116,11 @@ std::vector dact_dbias( // Grad output and activation input tensors grad_output_torch = grad_output_torch.contiguous(); const TensorWrapper &grad_output_nvte = makeTransformerEngineTensor(grad_output_torch); - const auto output_shape_nvte = getTensorShape(grad_output_torch); - const auto output_shape = convertShape(output_shape_nvte); + const auto output_shape = getTensorShapeVector(grad_output_torch); auto grad_output_dtype = GetTransformerEngineDType(grad_output_torch.scalar_type()); act_input_torch = act_input_torch.contiguous(); const TensorWrapper &act_input_nvte = makeTransformerEngineTensor(act_input_torch); - const auto input_shape_nvte = getTensorShape(act_input_torch); - const auto input_shape = convertShape(input_shape_nvte); + const auto input_shape = getTensorShapeVector(act_input_torch); // Construct tensors auto quantizer_cpp = convert_quantizer(quantizer_py); diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index b12da7542bb..3f107f443c7 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ 
b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -334,12 +334,12 @@ std::tuple, std::vector> bulk_allocate_fp tensor_cpp_list.emplace_back(makeTransformerEngineTensor( rowwise_usage ? rowwise_data_list[i].data_ptr() : nullptr, columnwise_usage ? columnwise_data_list[i].data_ptr() : nullptr, - rowwise_usage ? rowwise_data_shapes[i] : std::vector{}, - columnwise_usage ? columnwise_data_shapes[i] : std::vector{}, fp8_dtype, nullptr, + rowwise_usage ? nvte_make_shape(rowwise_data_shapes[i].data(), rowwise_data_shapes[i].size()) : NVTEShape{}, + columnwise_usage ? nvte_make_shape(columnwise_data_shapes[i].data(), columnwise_data_shapes[i].size()) : NVTEShape{}, fp8_dtype, nullptr, nullptr, rowwise_usage ? rowwise_scale_list[i].data_ptr() : nullptr, columnwise_usage ? columnwise_scale_list[i].data_ptr() : nullptr, - rowwise_usage ? rowwise_scale_shapes[i] : std::vector{}, - columnwise_usage ? columnwise_scale_shapes[i] : std::vector{}, scaling_mode)); + rowwise_usage ? nvte_make_shape(rowwise_scale_shapes[i].data(), rowwise_scale_shapes[i].size()) : NVTEShape{}, + columnwise_usage ? nvte_make_shape(columnwise_scale_shapes[i].data(), columnwise_scale_shapes[i].size()) : NVTEShape{}, scaling_mode)); } return retval; diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp index 35b523b5192..11be2d4e2fe 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp @@ -40,8 +40,8 @@ bool is_low_precision(const DType type) { return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2; } -std::vector getGemmOutputShape(const NVTEShape& A_shape, const bool transa, - const NVTEShape& B_shape, const bool transb) { +NVTEShape getGemmOutputShape(const NVTEShape& A_shape, const bool transa, + const NVTEShape& B_shape, const bool transb) { // Flatten outer dims to get 2D matrices const size_t A0 = product(A_shape, 0, A_shape.ndim - 1); const size_t A1 = A_shape.data[A_shape.ndim - 1]; @@ -53,27 +53,29 @@ std::vector getGemmOutputShape(const NVTEShape& A_shape, const bool tran A0, ",", A1, "), transa=", transa, ", B=(", B0, ",", B1, "), transb=", transb, ")"); // Construct output dims - std::vector ret; + NVTEShape ret; + size_t idx = 0; if (transb) { - ret.emplace_back(B1); + ret.data[idx++] = B1; } else { // Unflatten B0 for (size_t i = 0; i < B_shape.ndim - 1; ++i) { - ret.emplace_back(B_shape.data[i]); + ret.data[idx++] = B_shape.data[i]; } } if (transa) { - ret.emplace_back(A0); + ret.data[idx++] = A0; } else { - ret.emplace_back(A1); + ret.data[idx++] = A1; } + ret.ndim = idx; return ret; } -bool checkGemmShape(const std::vector& expected, const NVTEShape& actual) { - if (expected.size() != actual.ndim) return false; - for (size_t i = 0; i < expected.size(); ++i) { - if (expected[i] != actual.data[i]) return false; +bool checkGemmShape(const NVTEShape& expected, const NVTEShape& actual) { + if (expected.ndim != actual.ndim) return false; + for (size_t i = 0; i < expected.ndim; ++i) { + if (expected.data[i] != actual.data[i]) return false; } return true; } @@ -117,7 +119,7 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans // Check tensor dimensions const auto& A_shape = A_tensor.shape(); const auto& B_shape = B_tensor.shape(); - const auto& D_shape = detail::getGemmOutputShape(A_shape, transa, B_shape, transb); + const NVTEShape D_shape = detail::getGemmOutputShape(A_shape, transa, B_shape, transb); NVTE_CHECK(A_shape.ndim >= 1, "Tensor A needs to 
have at least 1 dimension"); NVTE_CHECK(B_shape.ndim >= 1, "Tensor B needs to have at least 1 dimension"); @@ -138,7 +140,7 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans // Output tensor TensorWrapper D_tensor; if (D.is_none()) { - std::tie(D_tensor, D) = createOutputTensor(D_shape, output_dtype, quantizer); + std::tie(D_tensor, D) = createOutputTensor(convertShape(D_shape), output_dtype, quantizer); } else { D_tensor = makeTransformerEngineTensor(D, quantizer); NVTE_CHECK(detail::checkGemmShape(D_shape, D_tensor.shape()), @@ -168,7 +170,7 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans if (unfused_quantization_needed) { NoneQuantizer q{none}; - std::tie(unquantized_D_tensor, unquantized_out) = q.create_tensor(D_shape, output_dtype); + std::tie(unquantized_D_tensor, unquantized_out) = q.create_tensor(convertShape(D_shape), output_dtype); } TensorWrapper& out_tensor = unfused_quantization_needed ? unquantized_D_tensor : D_tensor; @@ -197,8 +199,8 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans auto dtype = GetATenDType(gelu_type); auto opts = torch::TensorOptions().dtype(dtype).device(torch::kCUDA); std::vector torch_shape; - for (auto v : D_shape) { - torch_shape.push_back(v); + for (size_t i = 0; i < D_shape.ndim; ++i) { + torch_shape.push_back(static_cast(D_shape.data[i])); } pre_gelu_out = at::empty(torch_shape, opts); } else { @@ -207,14 +209,21 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans } } } - const auto gelu_shape = gelu ? D_shape : std::vector{0}; + NVTEShape gelu_shape; + gelu_shape.ndim = 1; + gelu_shape.data[0] = 0; + if (gelu) { + gelu_shape = D_shape; + } auto te_pre_gelu_out = makeTransformerEngineTensor(get_data_ptr(pre_gelu_out), gelu_shape, gelu_type); // Workspace - auto te_workspace = makeTransformerEngineTensor(workspace.data_ptr(), - std::vector{workspaceSize}, DType::kByte); + NVTEShape workspace_shape; + workspace_shape.ndim = 1; + workspace_shape.data[0] = workspaceSize; + auto te_workspace = makeTransformerEngineTensor(workspace.data_ptr(), workspace_shape, DType::kByte); // Set an external SM Margin to all the GEMMs. 
// This comes in handy when DP is overlapped with GEMMs @@ -263,8 +272,10 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans if (extra_output.has_value()) { extra_output_tensor = makeTransformerEngineTensor(*extra_output); } else { + NVTEShape extra_output_shape; + extra_output_shape.ndim = 0; extra_output_tensor = - makeTransformerEngineTensor(nullptr, std::vector{0}, DType::kByte); + makeTransformerEngineTensor(nullptr, extra_output_shape, DType::kByte); } // Direct GEMM call to the correct overlap @@ -367,28 +378,47 @@ void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, DType A_type, const size_t A_shape_data[2] = {static_cast(A.size(0)), static_cast(A.size(1))}; const NVTEShape A_shape = nvte_make_shape(A_shape_data, 2); - auto te_A = makeTransformerEngineTensor(A.data_ptr(), A_shape, A_type, nullptr, nullptr, - A_scale_inverse.data_ptr(), - getTensorShape(A_scale_inverse), nvte_scaling_modeA); + auto te_A = makeTransformerEngineTensor( + A.data_ptr(), A_shape, A_type, + nullptr, nullptr, A_scale_inverse.data_ptr(), getTensorShape(A_scale_inverse), + nvte_scaling_modeA); const size_t B_shape_data[2] = {static_cast(B.size(0)), static_cast(B.size(1))}; const NVTEShape B_shape = nvte_make_shape(B_shape_data, 2); - auto te_B = makeTransformerEngineTensor(B.data_ptr(), B_shape, B_type, nullptr, nullptr, - B_scale_inverse.data_ptr(), - getTensorShape(B_scale_inverse), nvte_scaling_modeB); + auto te_B = makeTransformerEngineTensor( + B.data_ptr(), B_shape, B_type, + nullptr, nullptr, B_scale_inverse.data_ptr(), getTensorShape(B_scale_inverse), + nvte_scaling_modeB); // TODO: D_scale_inv cannot be nullptr when D_type is FP8. + NVTEShape D_shape, D_scale_inv_shape; + D_shape.ndim = 2; + D_scale_inv_shape.ndim = 1; + D_scale_inv_shape.data[0] = 1; + D_shape.data[0] = static_cast(D.size(0)); + D_shape.data[1] = static_cast(D.size(1)); auto te_D = makeTransformerEngineTensor( D.data_ptr(), - std::vector{static_cast(D.size(0)), static_cast(D.size(1))}, D_type, - D_amax.data_ptr(), D_scale.data_ptr(), nullptr); + D_shape, D_type, + D_amax.data_ptr(), D_scale.data_ptr(), nullptr, D_scale_inv_shape); + NVTEShape bias_shape; + bias_shape.ndim = 1; + bias_shape.data[0] = static_cast(bias.size(0)); auto te_bias = makeTransformerEngineTensor( - bias.data_ptr(), std::vector{static_cast(bias.size(0))}, bias_type); + bias.data_ptr(), bias_shape, bias_type); + NVTEShape counter_shape; + counter_shape.ndim = 1; + counter_shape.data[0] = static_cast(counter.size(0)); auto te_counter = makeTransformerEngineTensor( - counter.data_ptr(), std::vector{static_cast(counter.size(0))}, DType::kInt32); + counter.data_ptr(), counter_shape, DType::kInt32); - const auto gelu_shape = pre_gelu_out.data_ptr() == nullptr - ? 
std::vector{static_cast(pre_gelu_out.size(0))} - : std::vector{static_cast(pre_gelu_out.size(0)), - static_cast(pre_gelu_out.size(1))}; + NVTEShape gelu_shape; + if (pre_gelu_out.data_ptr() == nullptr) { + gelu_shape.ndim = 1; + gelu_shape.data[0] = static_cast(pre_gelu_out.size(0)); + } else { + gelu_shape.ndim = 2; + gelu_shape.data[0] = static_cast(pre_gelu_out.size(0)); + gelu_shape.data[1] = static_cast(pre_gelu_out.size(1)); + } auto te_pre_gelu_out = makeTransformerEngineTensor( pre_gelu_out.data_ptr(), gelu_shape, GetTransformerEngineDType(pre_gelu_out.scalar_type())); auto te_workspace = makeTransformerEngineTensor(workspace.data_ptr(), @@ -432,12 +462,13 @@ std::optional> te_general_grouped_gemm( // if there is single output at::Tensor out_tensor; - auto size_t_shape = + const NVTEShape nvte_D_shape = pytorch::detail::getGemmOutputShape(te_A.shape(), transa, te_B.shape(), transb); bool D_numel_is_zero = false; std::vector D_shape; - for (size_t t : size_t_shape) { - D_shape.push_back(t); + for (size_t j = 0; j < nvte_D_shape.ndim; ++j) { + const size_t t = nvte_D_shape.data[j]; + D_shape.push_back(static_cast(t)); if (t == 0) { D_numel_is_zero = true; } diff --git a/transformer_engine/pytorch/csrc/extensions/padding.cpp b/transformer_engine/pytorch/csrc/extensions/padding.cpp index d4b64a485c1..389308405b1 100644 --- a/transformer_engine/pytorch/csrc/extensions/padding.cpp +++ b/transformer_engine/pytorch/csrc/extensions/padding.cpp @@ -34,7 +34,7 @@ void fused_multi_row_padding(at::Tensor input, at::Tensor output, input_row_list[tensor_id] * input.size(1) * input.element_size(); input_char_ptr += input_dptr_offset; d_input_ptr = reinterpret_cast(input_char_ptr); - + NVTEShape input_shape = {input_row_list[tensor_id], static_cast(input.size(1))}; input_shape_list.push_back({input_row_list[tensor_id], static_cast(input.size(1))}); input_type_list.push_back(GetTransformerEngineDType(input.scalar_type())); diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index aa8416121d0..00f43433435 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -77,6 +77,16 @@ std::pair NoneQuantizer::create_tensor(const std::vec return create_tensor(shape, dtype, at::empty(shape_int64, opts)); } +std::pair NoneQuantizer::create_tensor(const NVTEShape& shape, + DType dtype) const { + std::vector shape_int64; + for (size_t i = 0; i < shape.ndim; ++i) { + shape_int64.push_back(static_cast(shape.data[i])); + } + const auto opts = at::TensorOptions().dtype(GetATenDType(dtype)).device(torch::kCUDA); + return create_tensor(shape, dtype, at::empty(shape_int64, opts)); +} + std::pair NoneQuantizer::create_tensor(const std::vector& shape, DType dtype, at::Tensor data) const { @@ -86,6 +96,15 @@ std::pair NoneQuantizer::create_tensor(const std::vec return {std::move(out_cpp), py::cast(data)}; } +std::pair NoneQuantizer::create_tensor(const NVTEShape& shape, + DType dtype, + at::Tensor data) const { +TensorWrapper out_cpp; +out_cpp.set_rowwise_data(data.data_ptr(), dtype, shape); +set_quantization_params(&out_cpp); +return {std::move(out_cpp), py::cast(data)}; +} + std::pair NoneQuantizer::convert_and_update_tensor( py::object tensor) const { auto tensor_pyt = tensor.cast(); @@ -209,8 +228,7 @@ std::pair Float8Quantizer::convert_and_update_tensor( // Tensor dimensions std::vector shape; if (has_transpose) { - const auto transpose_shape_nvte = getTensorShape(*transpose_tensor); - const auto 
transpose_shape = convertShape(transpose_shape_nvte); + const auto transpose_shape = getTensorShapeVector(*transpose_tensor); if (transpose_shape.size() > 0) { for (size_t i = 1; i < transpose_shape.size(); ++i) { shape.push_back(transpose_shape[i]); @@ -218,13 +236,12 @@ std::pair Float8Quantizer::convert_and_update_tensor( shape.push_back(transpose_shape.front()); } if (has_data) { - const auto expected_shape_nvte = getTensorShape(*data_tensor); - const auto expected_shape = convertShape(expected_shape_nvte); + const auto expected_shape = getTensorShapeVector(*data_tensor); NVTE_CHECK(shape == expected_shape, "FP8 data (shape=", expected_shape, ") and transpose (shape=", transpose_shape, ") do not match"); } } else { // Already checked has_data == true - shape = convertShape(getTensorShape(*data_tensor)); + shape = getTensorShapeVector(*data_tensor); } // Coerce data tensor @@ -432,8 +449,7 @@ std::pair Float8CurrentScalingQuantizer::convert_and_ // Tensor dimensions std::vector shape; if (has_transpose) { - const auto transpose_shape_nvte = getTensorShape(*transpose_tensor); - const auto transpose_shape = convertShape(transpose_shape_nvte); + const auto transpose_shape = getTensorShapeVector(*transpose_tensor); if (transpose_shape.size() > 0) { for (size_t i = 1; i < transpose_shape.size(); ++i) { shape.push_back(transpose_shape[i]); @@ -441,13 +457,12 @@ std::pair Float8CurrentScalingQuantizer::convert_and_ shape.push_back(transpose_shape.front()); } if (has_data) { - const auto expected_shape_nvte = getTensorShape(*data_tensor); - const auto expected_shape = convertShape(expected_shape_nvte); + const auto expected_shape = getTensorShapeVector(*data_tensor); NVTE_CHECK(shape == expected_shape, "FP8 data (shape=", expected_shape, ") and transpose (shape=", transpose_shape, ") do not match"); } } else { // Already checked has_data == true - shape = convertShape(getTensorShape(*data_tensor)); + shape = getTensorShapeVector(*data_tensor); } // Coerce data tensor in Python tensor @@ -684,9 +699,9 @@ std::pair Float8BlockQuantizer::convert_and_update_te return std::vector(); } if (all_gather_usage) { - return convertShape(getTensorShape(*columnwise_data)); + return getTensorShapeVector(*columnwise_data); } - std::vector shape = convertShape(getTensorShape(*columnwise_data)); + std::vector shape = getTensorShapeVector(*columnwise_data); std::vector shape_transposed(shape.size()); for (size_t i = 0; i + 1 < shape.size(); ++i) { shape_transposed[i] = shape[i + 1]; @@ -698,7 +713,7 @@ std::pair Float8BlockQuantizer::convert_and_update_te }; std::vector shape; if (rowwise_data) { - shape = convertShape(getTensorShape(*rowwise_data)); + shape = getTensorShapeVector(*rowwise_data); if (columnwise_data) { auto expected_shape = get_columnwise_shape(all_gather_usage); NVTE_CHECK(shape == expected_shape, "BlockwiseFP8 row-wise data (shape=", shape, @@ -1008,14 +1023,14 @@ std::pair MXFP8Quantizer::convert_and_update_tensor( // Tensor dimensions std::vector shape; if (columnwise_data) { - shape = convertShape(getTensorShape(*columnwise_data)); + shape = getTensorShapeVector(*columnwise_data); if (rowwise_data) { - const auto expected_shape = convertShape(getTensorShape(*rowwise_data)); + const auto expected_shape = getTensorShapeVector(*rowwise_data); NVTE_CHECK(shape == expected_shape, "MXFP8 row-wise data (shape=", expected_shape, ") and column-wise data (shape=", shape, ") do not match"); } } else { // Already checked columnwise_data_tensor == true - shape = 
convertShape(getTensorShape(*rowwise_data)); + shape = getTensorShapeVector(*rowwise_data); } // Coerce row-wise data @@ -1324,15 +1339,14 @@ std::pair NVFP4Quantizer::convert_and_update_tensor( // Tensor dimensions, shape means original shape std::vector shape; if (columnwise_data) { - shape = convert_shape_back_from_fp4(convertShape(getTensorShape(*columnwise_data)), true); + shape = convert_shape_back_from_fp4(getTensorShapeVector(*columnwise_data), true); if (rowwise_data) { - auto expected_shape = - convert_shape_back_from_fp4(convertShape(getTensorShape(*rowwise_data)), false); + auto expected_shape = convert_shape_back_from_fp4(getTensorShapeVector(*rowwise_data), false); NVTE_CHECK(shape == expected_shape, "NVFP4 row-wise data (shape=", expected_shape, ") and column-wise data (shape=", shape, ") do not match"); } } else { // Already checked columnwise_data_tensor == true - shape = convert_shape_back_from_fp4(convertShape(getTensorShape(*rowwise_data)), false); + shape = convert_shape_back_from_fp4(getTensorShapeVector(*rowwise_data), false); } size_t flat_first_dim = 1; diff --git a/transformer_engine/pytorch/csrc/type_converters.cpp b/transformer_engine/pytorch/csrc/type_converters.cpp index 780a08da7f8..48e9f06cc40 100644 --- a/transformer_engine/pytorch/csrc/type_converters.cpp +++ b/transformer_engine/pytorch/csrc/type_converters.cpp @@ -132,7 +132,7 @@ TensorWrapper NVTETensorFromNVFP4Tensor(py::handle tensor, Quantizer *quantizer) const auto &scale_inv = tensor.attr("_rowwise_scale_inv").cast(); const auto &amax_rowwise = tensor.attr("_amax_rowwise").cast(); ret.set_rowwise_data(data.data_ptr(), dtype, - convert_shape_back_from_fp4(convertShape(getTensorShape(data)), false)); + convert_shape_back_from_fp4(getTensorShapeVector(data), false)); ret.set_rowwise_scale_inv(scale_inv.data_ptr(), DType::kFloat8E4M3, getTensorShape(scale_inv)); ret.set_amax(amax_rowwise.data_ptr(), DType::kFloat32, getTensorShape(amax_rowwise)); } @@ -143,7 +143,7 @@ TensorWrapper NVTETensorFromNVFP4Tensor(py::handle tensor, Quantizer *quantizer) const auto &scale_inv = tensor.attr("_columnwise_scale_inv").cast(); const auto &amax_columnwise = tensor.attr("_amax_columnwise").cast(); ret.set_columnwise_data(data.data_ptr(), DType::kFloat4E2M1, - convert_shape_back_from_fp4(convertShape(getTensorShape(data)), false)); + convert_shape_back_from_fp4(getTensorShapeVector(data), false)); ret.set_columnwise_scale_inv(scale_inv.data_ptr(), DType::kFloat8E4M3, getTensorShape(scale_inv)); ret.set_columnwise_amax(amax_columnwise.data_ptr(), DType::kFloat32, diff --git a/transformer_engine/pytorch/csrc/util.cpp b/transformer_engine/pytorch/csrc/util.cpp index 134185ac823..7fc04801e49 100644 --- a/transformer_engine/pytorch/csrc/util.cpp +++ b/transformer_engine/pytorch/csrc/util.cpp @@ -15,7 +15,7 @@ std::optional swizzle_scaling_factors(transformer_engine::TensorWrap if (input.scaling_mode() == NVTE_INVALID_SCALING) { NVTE_ERROR("Invalid scaling mode for swizzle."); - } else if (input.scaling_mode() != NVTE_MXFP8_1D_SCALING && + } else if (input.scaling_mode() != NVTE_MXFP8_1D_SCALING || input.scaling_mode() != NVTE_NVFP4_1D_SCALING) { return std::nullopt; } @@ -59,24 +59,24 @@ std::optional swizzle_scaling_factors(transformer_engine::TensorWrap (nvfp4) ? 
transformer_engine::DType::kFloat8E4M3 : transformer_engine::DType::kFloat8E8M0; if (rowwise) { - input_cu.set_rowwise_data(input.dptr(), input_dtype, input_shape); - input_cu.set_rowwise_scale_inv(scale_inv_dptr, scale_inv_dtype, scale_inv_shape); - output_cu.set_rowwise_data(input.dptr(), input_dtype, input_shape); - output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, scale_inv_shape); + input_cu.set_rowwise_data(input.dptr(), input_dtype, nvte_input_shape); + input_cu.set_rowwise_scale_inv(scale_inv_dptr, scale_inv_dtype, scale_inv.shape); + output_cu.set_rowwise_data(input.dptr(), input_dtype, nvte_input_shape); + output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, scale_inv.shape); } else { - input_cu.set_columnwise_data(input.columnwise_dptr(), input_dtype, input_shape); - input_cu.set_columnwise_scale_inv(scale_inv_dptr, scale_inv_dtype, scale_inv_shape); - output_cu.set_columnwise_data(input.columnwise_dptr(), input_dtype, input_shape); - output_cu.set_columnwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, scale_inv_shape); + input_cu.set_columnwise_data(input.columnwise_dptr(), input_dtype, nvte_input_shape); + input_cu.set_columnwise_scale_inv(scale_inv_dptr, scale_inv_dtype, scale_inv.shape); + output_cu.set_columnwise_data(input.columnwise_dptr(), input_dtype, nvte_input_shape); + output_cu.set_columnwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, scale_inv.shape); } // Launch kernel nvte_swizzle_scaling_factors(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); if (rowwise) { - input.set_rowwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, scale_inv_shape); + input.set_rowwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, scale_inv.shape); } else { - input.set_columnwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, scale_inv_shape); + input.set_columnwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, scale_inv.shape); } return swizzled_scale_inv; From d6ac3f1b28b2d3d2e8abbe0ab51cba7f782712b6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 24 Dec 2025 04:47:56 +0000 Subject: [PATCH 05/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/csrc/common.cpp | 4 +-- transformer_engine/pytorch/csrc/common.h | 9 +++-- .../pytorch/csrc/extensions/attention.cpp | 10 ++++-- .../pytorch/csrc/extensions/gemm.cpp | 36 +++++++++---------- transformer_engine/pytorch/csrc/quantizer.cpp | 14 ++++---- 5 files changed, 35 insertions(+), 38 deletions(-) diff --git a/transformer_engine/pytorch/csrc/common.cpp b/transformer_engine/pytorch/csrc/common.cpp index b6a3853f6fd..d4ce064facf 100644 --- a/transformer_engine/pytorch/csrc/common.cpp +++ b/transformer_engine/pytorch/csrc/common.cpp @@ -26,9 +26,7 @@ std::vector convert_shape_back_from_fp4(const std::vector& shape return ret; } -NVTEShape getTensorShape(const at::Tensor& t) { - return convertTorchShape(t.sizes()); -} +NVTEShape getTensorShape(const at::Tensor& t) { return convertTorchShape(t.sizes()); } std::vector getTensorShapeVector(const at::Tensor& t) { std::vector shape; diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index a9e7d895192..58e2acb6959 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -141,13 +141,12 @@ class NoneQuantizer : public Quantizer { std::pair 
create_tensor(const std::vector& shape, DType dtype, at::Tensor data) const; - std::pair create_tensor(const NVTEShape& shape, - DType dtype) const; - + std::pair create_tensor(const NVTEShape& shape, DType dtype) const; + /*! @brief Construct a tensor with pre-initialized data */ std::pair create_tensor(const NVTEShape& shape, DType dtype, - at::Tensor data) const; - + at::Tensor data) const; + std::pair convert_and_update_tensor(py::object tensor) const override; void quantize(const TensorWrapper& input, TensorWrapper& out, diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index 804a4667d71..1007dcb80c6 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -478,10 +478,14 @@ std::vector fused_attn_bwd( auto cu_seqlens_kv_padded_sizes = cu_seqlens_kv_padded.value().sizes().vec(); std::vector cu_seqlens_kv_padded_shape{cu_seqlens_kv_padded_sizes.begin(), cu_seqlens_kv_padded_sizes.end()}; - te_cu_seqlens_q_padded = makeTransformerEngineTensor(cu_seqlens_q_padded.value().data_ptr(), - nvte_make_shape(cu_seqlens_q_padded_shape.data(), cu_seqlens_q_padded_shape.size()), DType::kInt32); + te_cu_seqlens_q_padded = makeTransformerEngineTensor( + cu_seqlens_q_padded.value().data_ptr(), + nvte_make_shape(cu_seqlens_q_padded_shape.data(), cu_seqlens_q_padded_shape.size()), + DType::kInt32); te_cu_seqlens_kv_padded = makeTransformerEngineTensor( - cu_seqlens_kv_padded.value().data_ptr(), nvte_make_shape(cu_seqlens_kv_padded_shape.data(), cu_seqlens_kv_padded_shape.size()), DType::kInt32); + cu_seqlens_kv_padded.value().data_ptr(), + nvte_make_shape(cu_seqlens_kv_padded_shape.data(), cu_seqlens_kv_padded_shape.size()), + DType::kInt32); } // convert auxiliary tensors from forward to NVTETensors diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp index b8928053d77..0e478ecd3ce 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp @@ -40,8 +40,8 @@ bool is_low_precision(const DType type) { return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2; } -NVTEShape getGemmOutputShape(const NVTEShape& A_shape, const bool transa, - const NVTEShape& B_shape, const bool transb) { +NVTEShape getGemmOutputShape(const NVTEShape& A_shape, const bool transa, const NVTEShape& B_shape, + const bool transb) { // Flatten outer dims to get 2D matrices const size_t A0 = A_shape.ndim > 0 ? product(A_shape, 0, A_shape.ndim - 1) : 1; const size_t A1 = A_shape.ndim > 0 ? A_shape.data[A_shape.ndim - 1] : 1; @@ -170,7 +170,8 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans if (unfused_quantization_needed) { NoneQuantizer q{none}; - std::tie(unquantized_D_tensor, unquantized_out) = q.create_tensor(convertShape(D_shape), output_dtype); + std::tie(unquantized_D_tensor, unquantized_out) = + q.create_tensor(convertShape(D_shape), output_dtype); } TensorWrapper& out_tensor = unfused_quantization_needed ? 
unquantized_D_tensor : D_tensor; @@ -223,7 +224,8 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans NVTEShape workspace_shape; workspace_shape.ndim = 1; workspace_shape.data[0] = workspaceSize; - auto te_workspace = makeTransformerEngineTensor(workspace.data_ptr(), workspace_shape, DType::kByte); + auto te_workspace = + makeTransformerEngineTensor(workspace.data_ptr(), workspace_shape, DType::kByte); // Set an external SM Margin to all the GEMMs. // This comes in handy when DP is overlapped with GEMMs @@ -378,16 +380,14 @@ void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, DType A_type, const size_t A_shape_data[2] = {static_cast(A.size(0)), static_cast(A.size(1))}; const NVTEShape A_shape = nvte_make_shape(A_shape_data, 2); - auto te_A = makeTransformerEngineTensor( - A.data_ptr(), A_shape, A_type, - nullptr, nullptr, A_scale_inverse.data_ptr(), getTensorShape(A_scale_inverse), - nvte_scaling_modeA); + auto te_A = makeTransformerEngineTensor(A.data_ptr(), A_shape, A_type, nullptr, nullptr, + A_scale_inverse.data_ptr(), + getTensorShape(A_scale_inverse), nvte_scaling_modeA); const size_t B_shape_data[2] = {static_cast(B.size(0)), static_cast(B.size(1))}; const NVTEShape B_shape = nvte_make_shape(B_shape_data, 2); - auto te_B = makeTransformerEngineTensor( - B.data_ptr(), B_shape, B_type, - nullptr, nullptr, B_scale_inverse.data_ptr(), getTensorShape(B_scale_inverse), - nvte_scaling_modeB); + auto te_B = makeTransformerEngineTensor(B.data_ptr(), B_shape, B_type, nullptr, nullptr, + B_scale_inverse.data_ptr(), + getTensorShape(B_scale_inverse), nvte_scaling_modeB); // TODO: D_scale_inv cannot be nullptr when D_type is FP8. NVTEShape D_shape, D_scale_inv_shape; D_shape.ndim = 2; @@ -395,20 +395,16 @@ void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, DType A_type, D_scale_inv_shape.data[0] = 1; D_shape.data[0] = static_cast(D.size(0)); D_shape.data[1] = static_cast(D.size(1)); - auto te_D = makeTransformerEngineTensor( - D.data_ptr(), - D_shape, D_type, - D_amax.data_ptr(), D_scale.data_ptr(), nullptr, D_scale_inv_shape); + auto te_D = makeTransformerEngineTensor(D.data_ptr(), D_shape, D_type, D_amax.data_ptr(), + D_scale.data_ptr(), nullptr, D_scale_inv_shape); NVTEShape bias_shape; bias_shape.ndim = 1; bias_shape.data[0] = static_cast(bias.size(0)); - auto te_bias = makeTransformerEngineTensor( - bias.data_ptr(), bias_shape, bias_type); + auto te_bias = makeTransformerEngineTensor(bias.data_ptr(), bias_shape, bias_type); NVTEShape counter_shape; counter_shape.ndim = 1; counter_shape.data[0] = static_cast(counter.size(0)); - auto te_counter = makeTransformerEngineTensor( - counter.data_ptr(), counter_shape, DType::kInt32); + auto te_counter = makeTransformerEngineTensor(counter.data_ptr(), counter_shape, DType::kInt32); NVTEShape gelu_shape; if (pre_gelu_out.data_ptr() == nullptr) { diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 64a6fa84766..0f8aa8381a8 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -78,7 +78,7 @@ std::pair NoneQuantizer::create_tensor(const std::vec } std::pair NoneQuantizer::create_tensor(const NVTEShape& shape, - DType dtype) const { + DType dtype) const { std::vector shape_int64; for (size_t i = 0; i < shape.ndim; ++i) { shape_int64.push_back(static_cast(shape.data[i])); @@ -97,12 +97,12 @@ std::pair NoneQuantizer::create_tensor(const std::vec } std::pair NoneQuantizer::create_tensor(const 
NVTEShape& shape, - DType dtype, - at::Tensor data) const { -TensorWrapper out_cpp; -out_cpp.set_rowwise_data(data.data_ptr(), dtype, shape); -set_quantization_params(&out_cpp); -return {std::move(out_cpp), py::cast(data)}; + DType dtype, + at::Tensor data) const { + TensorWrapper out_cpp; + out_cpp.set_rowwise_data(data.data_ptr(), dtype, shape); + set_quantization_params(&out_cpp); + return {std::move(out_cpp), py::cast(data)}; } std::pair NoneQuantizer::convert_and_update_tensor( From 51dd309ed4c4ece54350196d27351c163eb3ca9d Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Sun, 28 Dec 2025 08:49:18 +0000 Subject: [PATCH 06/10] some additional changes Signed-off-by: Varun Thumbe --- .../common/gemm/cublaslt_gemm.cu | 10 +++++---- .../common/transformer_engine.cpp | 2 +- transformer_engine/pytorch/csrc/common.cpp | 2 +- transformer_engine/pytorch/csrc/common.h | 2 +- transformer_engine/pytorch/csrc/quantizer.cpp | 22 ++++++++++--------- transformer_engine/pytorch/module/linear.py | 6 +---- 6 files changed, 22 insertions(+), 22 deletions(-) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 118bf193353..899f5fe5e6d 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -129,7 +129,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla ret.Atype = A.data.dtype; ret.A_scale_inv = A.scale_inv.dptr; ret.lda = is_A_transposed ? k : m; - if (!nvte_is_non_tn_fp8_gemm_supported() && !is_A_transposed) { + int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported(); + if (!is_nvte_non_tn_fp8_gemm_supported && !is_A_transposed) { // Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data. if (A.has_columnwise_data() && is_fp8_dtype(A.columnwise_data.dtype)) { ret.A = A.columnwise_data.dptr; @@ -140,7 +141,7 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla } else { NVTE_CHECK(!is_fp8_dtype(ret.Atype), "Input A is missing column-wise usage"); } - } else if (nvte_is_non_tn_fp8_gemm_supported() && !A.has_data()) { + } else if (is_nvte_non_tn_fp8_gemm_supported && !A.has_data()) { // Blackwell supports any GEMM layout for FP8, so we can use column-wise/transposed // data with the mirrored transpose-flag if we don't have row-wise data. NVTE_CHECK(A.has_columnwise_data() && is_fp8_dtype(A.columnwise_data.dtype), @@ -220,7 +221,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla ret.Btype = B.data.dtype; ret.B_scale_inv = B.scale_inv.dptr; ret.ldb = is_B_transposed ? n : k; - if (!nvte_is_non_tn_fp8_gemm_supported() && is_B_transposed) { + int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported(); + if (!is_nvte_non_tn_fp8_gemm_supported && is_B_transposed) { // Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data. if (B.has_columnwise_data() && is_fp8_dtype(B.columnwise_data.dtype)) { ret.B = B.columnwise_data.dptr; @@ -231,7 +233,7 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla } else { NVTE_CHECK(!is_fp8_dtype(ret.Btype), "Input B is missing column-wise usage"); } - } else if (nvte_is_non_tn_fp8_gemm_supported() && !B.has_data()) { + } else if (is_nvte_non_tn_fp8_gemm_supported && !B.has_data()) { // Blackwell supports any GEMM layout for FP8, so we can use column-wise/transposed // data with the mirrored transpose-flag if we don't have row-wise data. 
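  // The hunks above hoist nvte_is_non_tn_fp8_gemm_supported() into a local so the
  // Hopper and Blackwell branches reuse a single query result instead of repeating
  // the call. A minimal sketch of the same pattern (the local name below is
  // illustrative, not taken from this patch):
  //   const int non_tn_fp8_ok = nvte_is_non_tn_fp8_gemm_supported();
  //   if (!non_tn_fp8_ok && is_B_transposed) {
  //     // Hopper: only TN FP8 GEMMs, so fall back to column-wise (transposed) data.
  //   } else if (non_tn_fp8_ok && !B.has_data()) {
  //     // Blackwell: any FP8 layout works, so column-wise data can be used directly.
  //   }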
NVTE_CHECK(B.has_columnwise_data() && is_fp8_dtype(B.columnwise_data.dtype), diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index 4a140b4376d..0c4c9456c60 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -961,7 +961,7 @@ void nvte_destroy_quantization_config(NVTEQuantizationConfig config) { } int nvte_is_non_tn_fp8_gemm_supported() { - int num_devices = transformer_engine::cuda::num_devices(); + static int num_devices = transformer_engine::cuda::num_devices(); static std::vector cache(num_devices, -1); static std::vector flags(num_devices); int device_id = transformer_engine::cuda::current_device(); diff --git a/transformer_engine/pytorch/csrc/common.cpp b/transformer_engine/pytorch/csrc/common.cpp index d4ce064facf..04ae78c0b78 100644 --- a/transformer_engine/pytorch/csrc/common.cpp +++ b/transformer_engine/pytorch/csrc/common.cpp @@ -36,7 +36,7 @@ std::vector getTensorShapeVector(const at::Tensor& t) { return shape; } -NVTEShape convertTorchShape(const c10::IntArrayRef& torch_shape) { +NVTEShape convertTorchShape(const c10::IntArrayRef torch_shape) { NVTEShape ret; ret.ndim = torch_shape.size(); constexpr int max_dimensions = sizeof(ret.data) / sizeof(size_t); diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 58e2acb6959..dd36e178ce7 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -504,7 +504,7 @@ std::vector convertShape(const NVTEShape& shape); size_t roundup(const size_t value, const size_t multiple); -NVTEShape convertTorchShape(const c10::IntArrayRef& torch_shape); +NVTEShape convertTorchShape(const c10::IntArrayRef torch_shape); std::vector convert_shape_back_from_fp4(const std::vector& shape, bool transpose); diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 0f8aa8381a8..4dc776b4546 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -141,8 +141,9 @@ std::pair Float8Quantizer::create_tensor( std::optional transpose, std::optional scale_inv) const { using namespace pybind11::literals; + int is_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported(); // Initialize data tensor - const bool with_data = rowwise_usage || nvte_is_non_tn_fp8_gemm_supported(); + const bool with_data = rowwise_usage || is_non_tn_fp8_gemm_supported; if (with_data && !data) { const std::vector shape_int64(shape.begin(), shape.end()); const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); @@ -153,7 +154,7 @@ std::pair Float8Quantizer::create_tensor( py::object data_py = with_data ? 
py::cast(*data) : py::none(); // Initialize transpose tensor - const bool with_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported(); + const bool with_transpose = columnwise_usage && !is_non_tn_fp8_gemm_supported; if (with_transpose && !transpose) { const auto transpose_shape = make_transpose_shape(shape); const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); @@ -204,10 +205,10 @@ std::pair Float8Quantizer::create_tensor( std::pair Float8Quantizer::convert_and_update_tensor( py::object tensor) const { NVTE_CHECK(detail::IsFloat8Tensor(tensor.ptr()), "Float8Quantizer must output to Float8Tensor."); - + int is_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported(); // Expected buffers - const bool need_data = rowwise_usage || nvte_is_non_tn_fp8_gemm_supported(); - const bool need_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported(); + const bool need_data = rowwise_usage || is_non_tn_fp8_gemm_supported; + const bool need_transpose = columnwise_usage && !is_non_tn_fp8_gemm_supported; NVTE_CHECK(need_data || need_transpose, "Invalid usages for Float8Quantizer."); // Extract buffers from Python tensor @@ -347,7 +348,8 @@ std::pair Float8CurrentScalingQuantizer::create_tenso // Initialize data tensor at::Tensor data_tensor; - const bool with_data = rowwise_usage || nvte_is_non_tn_fp8_gemm_supported(); + int is_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported(); + const bool with_data = rowwise_usage || is_non_tn_fp8_gemm_supported; if (with_data) { const std::vector shape_int64(shape.begin(), shape.end()); const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); @@ -356,7 +358,7 @@ std::pair Float8CurrentScalingQuantizer::create_tenso // Initialize transpose tensor at::Tensor transpose_tensor; - const bool with_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported(); + const bool with_transpose = columnwise_usage && !is_non_tn_fp8_gemm_supported; if (with_transpose) { const auto transpose_shape = make_transpose_shape(shape); const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); @@ -425,10 +427,10 @@ std::pair Float8CurrentScalingQuantizer::convert_and_ py::object tensor) const { NVTE_CHECK(detail::IsFloat8Tensor(tensor.ptr()), "Float8CurrentScalingQuantizer must output to Float8Tensor."); - + int is_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported(); // Expected buffers - const bool need_data = rowwise_usage || nvte_is_non_tn_fp8_gemm_supported(); - const bool need_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported(); + const bool need_data = rowwise_usage || is_non_tn_fp8_gemm_supported; + const bool need_transpose = columnwise_usage && !is_non_tn_fp8_gemm_supported; NVTE_CHECK(need_data || need_transpose, "Invalid quantizer usages."); // Extract buffers from Python tensor diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 965367ac31b..f71f780be3c 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -1377,6 +1377,7 @@ def reset_parameters(self, defer_init=False): elif self.parallel_mode == "column": set_tensor_model_parallel_attributes(getattr(self, bias), True, 0, 1) + @no_torch_dynamo() def forward( self, inp: torch.Tensor, @@ -1676,8 +1677,3 @@ def _customize_quantizers_float8_blockwise_scaling(self, fwd: bool, recipe: Reci ].all_gather_usage = True -# disable torch dynamo just once to reduce wrapped 
function overhead on each -# forward call of te Linear. -if torch.__version__ >= "2": - Linear.forward._torchdynamo_disable = True - Linear.forward._torchdynamo_disable_msg = None From 425182f7db96dab01e078fd54252415bd179cbd8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 28 Dec 2025 08:50:00 +0000 Subject: [PATCH 07/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/module/linear.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index f71f780be3c..4722d51f59d 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -1675,5 +1675,3 @@ def _customize_quantizers_float8_blockwise_scaling(self, fwd: bool, recipe: Reci self.quantizers["scaling_bwd"][ tex.FP8BwdTensors.GRAD_OUTPUT1 ].all_gather_usage = True - - From f8748c068313a5e1d6479a1669762560754296c0 Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Sun, 28 Dec 2025 09:50:55 +0000 Subject: [PATCH 08/10] some optimizations Signed-off-by: Varun Thumbe --- transformer_engine/pytorch/csrc/common.h | 20 +++- .../pytorch/csrc/extensions/gemm.cpp | 10 +- .../pytorch/csrc/extensions/transpose.cpp | 3 +- transformer_engine/pytorch/csrc/quantizer.cpp | 51 +++++++- transformer_engine/pytorch/module/linear.py | 110 ++++++++++-------- 5 files changed, 127 insertions(+), 67 deletions(-) diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index dd36e178ce7..c5670357fb9 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -102,7 +102,8 @@ class Quantizer { /*! @brief Construct a tensor with uninitialized data */ virtual std::pair create_tensor(const std::vector& shape, DType dtype) const = 0; - + virtual std::pair create_tensor(const NVTEShape& shape, + DType dtype) const = 0; /*! @brief Convert a PyTorch tensor into a Transformer Engine C++ tensor * * The PyTorch tensor's attributes are modified to match the @@ -141,7 +142,7 @@ class NoneQuantizer : public Quantizer { std::pair create_tensor(const std::vector& shape, DType dtype, at::Tensor data) const; - std::pair create_tensor(const NVTEShape& shape, DType dtype) const; + std::pair create_tensor(const NVTEShape& shape, DType dtype) const override; /*! @brief Construct a tensor with pre-initialized data */ std::pair create_tensor(const NVTEShape& shape, DType dtype, @@ -168,7 +169,8 @@ class Float8Quantizer : public Quantizer { std::pair create_tensor(const std::vector& shape, DType dtype) const override; - + std::pair create_tensor(const NVTEShape& shape, + DType dtype) const override; /*! @brief Construct a tensor with pre-initialized data */ std::pair create_tensor(const std::vector& shape, DType dtype, std::optional data, @@ -200,7 +202,8 @@ class Float8CurrentScalingQuantizer : public Quantizer { std::pair create_tensor(const std::vector& shape, DType dtype) const override; - + std::pair create_tensor(const NVTEShape& shape, + DType dtype) const override; /*! @brief Construct an unquantized tensor that shares the quantizer's amax pointer. * * The amax is zeroed out. Most TE kernels that output amax expect @@ -259,6 +262,8 @@ class Float8BlockQuantizer : public Quantizer { // and optionally columnwise usage. 
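  // The NVTEShape-based create_tensor overloads declared below mirror the existing
  // std::vector overloads, so callers that already hold an NVTEShape (for example a
  // GEMM output shape) can skip building a temporary std::vector<size_t>. A minimal
  // usage sketch, assuming `quantizer` points to any Quantizer subclass and
  // `dims`/`ndim` are illustrative:
  //   const NVTEShape shape = nvte_make_shape(dims, ndim);
  //   auto [out_cpp, out_py] = quantizer->create_tensor(shape, DType::kFloat32);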
std::pair create_tensor(const std::vector& shape, DType dtype) const override; + std::pair create_tensor(const NVTEShape& shape, + DType dtype) const override; std::pair convert_and_update_tensor(py::object shape) const override; @@ -280,6 +285,8 @@ class MXFP8Quantizer : public Quantizer { std::pair create_tensor(const std::vector& shape, DType dtype) const override; + std::pair create_tensor(const NVTEShape& shape, + DType dtype) const override; std::pair convert_and_update_tensor(py::object shape) const override; @@ -314,7 +321,8 @@ class NVFP4Quantizer : public Quantizer { std::pair create_tensor(const std::vector& shape, DType dtype) const override; - + std::pair create_tensor(const NVTEShape& shape, + DType dtype) const override; /*! @brief Construct an unquantized tensor that shares NVFP4 tensor's amax pointer * * The amax is zeroed out. Most TE kernels that output amax expect @@ -560,4 +568,4 @@ inline string to_string(const NVTEShape& s) { } } // namespace std -#endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_COMMON_H_ +#endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_COMMON_H_ \ No newline at end of file diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp index 0e478ecd3ce..fa01af53fe8 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp @@ -82,7 +82,7 @@ bool checkGemmShape(const NVTEShape& expected, const NVTEShape& actual) { } // namespace detail -std::pair createOutputTensor(const std::vector& shape, +std::pair createOutputTensor(const NVTEShape& shape, DType dtype, py::handle quantizer) { std::unique_ptr my_quantizer = convert_quantizer(quantizer); return my_quantizer->create_tensor(shape, dtype); @@ -119,7 +119,7 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans // Check tensor dimensions const auto& A_shape = A_tensor.shape(); const auto& B_shape = B_tensor.shape(); - const NVTEShape D_shape = detail::getGemmOutputShape(A_shape, transa, B_shape, transb); + const auto& D_shape = detail::getGemmOutputShape(A_shape, transa, B_shape, transb); NVTE_CHECK(A_shape.ndim >= 1, "Tensor A needs to have at least 1 dimension"); NVTE_CHECK(B_shape.ndim >= 1, "Tensor B needs to have at least 1 dimension"); @@ -140,7 +140,7 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans // Output tensor TensorWrapper D_tensor; if (D.is_none()) { - std::tie(D_tensor, D) = createOutputTensor(convertShape(D_shape), output_dtype, quantizer); + std::tie(D_tensor, D) = createOutputTensor(D_shape, output_dtype, quantizer); } else { D_tensor = makeTransformerEngineTensor(D, quantizer); NVTE_CHECK(detail::checkGemmShape(D_shape, D_tensor.shape()), @@ -171,7 +171,7 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans if (unfused_quantization_needed) { NoneQuantizer q{none}; std::tie(unquantized_D_tensor, unquantized_out) = - q.create_tensor(convertShape(D_shape), output_dtype); + q.create_tensor(D_shape, output_dtype); } TensorWrapper& out_tensor = unfused_quantization_needed ? 
unquantized_D_tensor : D_tensor; @@ -596,4 +596,4 @@ std::optional> te_general_grouped_gemm( return bias; } -} // namespace transformer_engine::pytorch +} // namespace transformer_engine::pytorch \ No newline at end of file diff --git a/transformer_engine/pytorch/csrc/extensions/transpose.cpp b/transformer_engine/pytorch/csrc/extensions/transpose.cpp index 5ace996afcc..55c7fd57d79 100644 --- a/transformer_engine/pytorch/csrc/extensions/transpose.cpp +++ b/transformer_engine/pytorch/csrc/extensions/transpose.cpp @@ -61,8 +61,7 @@ at::Tensor swap_first_dims(at::Tensor tensor, std::optional out) { // Allocate output tensor if needed if (!out) { - const auto in_shape_nvte = getTensorShape(input); - const auto in_shape = convertShape(in_shape_nvte); + const auto in_shape = getTensorShapeVector(input); NVTE_CHECK(in_shape.size() >= 2, "Invalid input tensor dimensions (shape=", in_shape, ")"); std::vector out_shape_int64(in_shape.begin(), in_shape.end()); out_shape_int64[0] = static_cast(in_shape[1]); diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 4dc776b4546..6345ae3894c 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -135,12 +135,19 @@ std::pair Float8Quantizer::create_tensor( at::Tensor scale_inv = at::empty(std::vector{1}, opts); return create_tensor(shape, dtype, std::nullopt, std::nullopt, std::move(scale_inv)); } +std::pair Float8Quantizer::create_tensor( + const NVTEShape& shape, DType dtype) const { + std::vector shape_vec; + for (size_t i = 0; i < shape.ndim; ++i) { + shape_vec.push_back(shape.data[i]); + } + return create_tensor(shape_vec, dtype); +} std::pair Float8Quantizer::create_tensor( const std::vector& shape, DType dtype, std::optional data, std::optional transpose, std::optional scale_inv) const { using namespace pybind11::literals; - int is_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported(); // Initialize data tensor const bool with_data = rowwise_usage || is_non_tn_fp8_gemm_supported; @@ -205,8 +212,9 @@ std::pair Float8Quantizer::create_tensor( std::pair Float8Quantizer::convert_and_update_tensor( py::object tensor) const { NVTE_CHECK(detail::IsFloat8Tensor(tensor.ptr()), "Float8Quantizer must output to Float8Tensor."); - int is_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported(); + // Expected buffers + int is_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported(); const bool need_data = rowwise_usage || is_non_tn_fp8_gemm_supported; const bool need_transpose = columnwise_usage && !is_non_tn_fp8_gemm_supported; NVTE_CHECK(need_data || need_transpose, "Invalid usages for Float8Quantizer."); @@ -341,7 +349,14 @@ void Float8CurrentScalingQuantizer::set_quantization_params(TensorWrapper* tenso tensor->set_amax(amax.data_ptr(), GetTransformerEngineDType(amax.scalar_type()), getTensorShape(amax)); } - +std::pair Float8CurrentScalingQuantizer::create_tensor( + const NVTEShape& shape, DType dtype) const { + std::vector shape_vec; + for (size_t i = 0; i < shape.ndim; ++i) { + shape_vec.push_back(shape.data[i]); + } + return create_tensor(shape_vec, dtype); +} std::pair Float8CurrentScalingQuantizer::create_tensor( const std::vector& shape, DType dtype) const { using namespace pybind11::literals; @@ -427,8 +442,9 @@ std::pair Float8CurrentScalingQuantizer::convert_and_ py::object tensor) const { NVTE_CHECK(detail::IsFloat8Tensor(tensor.ptr()), "Float8CurrentScalingQuantizer must output to Float8Tensor."); - int 
is_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported(); + // Expected buffers + int is_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported(); const bool need_data = rowwise_usage || is_non_tn_fp8_gemm_supported; const bool need_transpose = columnwise_usage && !is_non_tn_fp8_gemm_supported; NVTE_CHECK(need_data || need_transpose, "Invalid quantizer usages."); @@ -581,6 +597,15 @@ Float8BlockQuantizer::Float8BlockQuantizer(const py::handle& quantizer) : Quanti void Float8BlockQuantizer::set_quantization_params(TensorWrapper* tensor) const {} +std::pair Float8BlockQuantizer::create_tensor( + const NVTEShape& shape, DType dtype) const { + std::vector shape_vec; + for (size_t i = 0; i < shape.ndim; ++i) { + shape_vec.push_back(shape.data[i]); + } + return create_tensor(shape_vec, dtype); +} + std::pair Float8BlockQuantizer::create_tensor( const std::vector& shape, DType dtype) const { using namespace pybind11::literals; @@ -923,6 +948,15 @@ MXFP8Quantizer::MXFP8Quantizer(const py::handle& quantizer) : Quantizer(quantize void MXFP8Quantizer::set_quantization_params(TensorWrapper* tensor) const {} +std::pair MXFP8Quantizer::create_tensor( + const NVTEShape& shape, DType dtype) const { + std::vector shape_vec; + for (size_t i = 0; i < shape.ndim; ++i) { + shape_vec.push_back(shape.data[i]); + } + return create_tensor(shape_vec, dtype); +} + std::pair MXFP8Quantizer::create_tensor(const std::vector& shape, DType dtype) const { using namespace pybind11::literals; @@ -1189,7 +1223,14 @@ void NVFP4Quantizer::set_quantization_params(TensorWrapper* tensor) const { tensor->set_columnwise_data(columnwise_data.data_ptr, static_cast(columnwise_data.dtype), columnwise_data.shape); } - +std::pair NVFP4Quantizer::create_tensor( + const NVTEShape& shape, DType dtype) const { + std::vector shape_vec; + for (size_t i = 0; i < shape.ndim; ++i) { + shape_vec.push_back(shape.data[i]); + } + return create_tensor(shape_vec, dtype); +} std::pair NVFP4Quantizer::create_tensor(const std::vector& shape, DType dtype) const { using namespace pybind11::literals; diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index f71f780be3c..09b01e288a5 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -96,26 +96,24 @@ def forward( ( is_first_microbatch, - cpu_offloading, - is_grad_enabled, - fp8_output, - fp8_grad, - module, - skip_fp8_weight_update, - debug, - ) = non_tensor_args - - ( fp8, fp8_calibration, wgrad_store, + input_quantizer, + weight_quantizer, + output_quantizer, + grad_input_quantizer, + grad_weight_quantizer, + grad_output_quantizer, fuse_wgrad_accumulation, + cpu_offloading, tp_group, tp_size, sequence_parallel, tensor_parallel, activation_dtype, parallel_mode, + is_grad_enabled, ub_overlap_rs_fprop, ub_overlap_ag_dgrad, ub_overlap_ag_fprop, @@ -123,46 +121,14 @@ def forward( ub_bulk_dgrad, ub_bulk_wgrad, ub_name, + fp8_output, # pylint: disable=unused-variable fsdp_group, + module, + skip_fp8_weight_update, symmetric_ar_type, save_original_input, - ) = ( - module.fp8, - module.fp8_calibration, - module.wgrad_store, - module.fuse_wgrad_accumulation, - module.tp_group, - module.tp_size, - module.sequence_parallel, - module.tp_size > 1, - module.activation_dtype, - module.parallel_mode, - module.ub_overlap_rs_fprop, - module.ub_overlap_ag_dgrad, - module.ub_overlap_ag_fprop, - module.ub_overlap_rs_dgrad, - module.ub_bulk_dgrad, - module.ub_bulk_wgrad, - module.ub_name, - 
module.fsdp_group, - module.symmetric_ar_type, - module.save_original_input, - ) - quantizers = module._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) - - if debug: - quantizers = module._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) - if module.no_debug_features_active(quantizers): - debug = False - quantizers = module._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) - ( - input_quantizer, - weight_quantizer, - output_quantizer, - grad_input_quantizer, - grad_weight_quantizer, - grad_output_quantizer, - ) = quantizers + debug, + ) = non_tensor_args # NVTX label for profiling nvtx_label = "transformer_engine._Linear.forward" @@ -1377,7 +1343,7 @@ def reset_parameters(self, defer_init=False): elif self.parallel_mode == "column": set_tensor_model_parallel_attributes(getattr(self, bias), True, 0, 1) - @no_torch_dynamo() + @no_torch_dynamo(recursive=False) def forward( self, inp: torch.Tensor, @@ -1435,7 +1401,28 @@ def forward( inp, allow_non_contiguous=isinstance(inp, QuantizedTensor), ) as inp: + weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() + + quantizers = ( + self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) + if not debug + else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled) + ) + if debug: + if self.no_debug_features_active(quantizers): + debug = False + quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) + + ( + input_quantizer, + weight_quantizer, + output_quantizer, + grad_input_quantizer, + grad_weight_quantizer, + grad_output_quantizer, + ) = quantizers + if is_grad_enabled: linear_fn = _Linear.apply autograd_ctx = [] @@ -1445,12 +1432,37 @@ def forward( non_tensor_args = ( is_first_microbatch, + self.fp8, + self.fp8_calibration, + self.wgrad_store, + input_quantizer, + weight_quantizer, + output_quantizer, + grad_input_quantizer, + grad_weight_quantizer, + grad_output_quantizer, + self.fuse_wgrad_accumulation, is_cpu_offload_enabled(), + self.tp_group, + self.tp_size, + self.sequence_parallel, + self.tp_size > 1, + self.activation_dtype, + self.parallel_mode, is_grad_enabled, + self.ub_overlap_rs_fprop, + self.ub_overlap_ag_dgrad, + self.ub_overlap_ag_fprop, + self.ub_overlap_rs_dgrad, + self.ub_bulk_dgrad, + self.ub_bulk_wgrad, + self.ub_name, fp8_output, - fp8_grad, + self.fsdp_group, self, skip_fp8_weight_update, + self.symmetric_ar_type, + self.save_original_input, debug, ) out = linear_fn( From 72886287f48b1501d74ea3d10724c4249806fea7 Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Tue, 30 Dec 2025 12:40:16 +0000 Subject: [PATCH 09/10] got rid of vector in makeTransformerEngineTensor Signed-off-by: Varun Thumbe --- .../transformer_engine/transformer_engine.h | 63 +++++++++ transformer_engine/pytorch/csrc/common.cpp | 60 +-------- transformer_engine/pytorch/csrc/common.h | 36 ++--- .../pytorch/csrc/extensions/attention.cpp | 73 +++++----- .../pytorch/csrc/extensions/cast.cpp | 126 ++++++++++-------- .../pytorch/csrc/extensions/gemm.cpp | 23 +++- .../pytorch/csrc/extensions/padding.cpp | 36 +++-- .../pytorch/csrc/extensions/permutation.cpp | 47 +++++-- .../pytorch/csrc/extensions/transpose.cpp | 11 +- transformer_engine/pytorch/csrc/quantizer.cpp | 112 ++++++++++------ 10 files changed, 343 insertions(+), 244 deletions(-) diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index 19cb646be29..f9eb244cd9c 100644 --- 
a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -518,6 +518,69 @@ NVTEShape nvte_get_grouped_tensor_logical_shape(const NVTEGroupedTensor tensor); */ namespace transformer_engine { +/*! \class NVTEShapeWrapper + * \brief C++ wrapper for NVTEShape with container-like interface. + */ +class NVTEShapeWrapper { + private: + NVTEShape data; + + public: + // Default constructor + NVTEShapeWrapper() { + data.ndim = 0; + } + + // Constructor from NVTEShape (direct assignment by reference) + NVTEShapeWrapper(const NVTEShape& shape) { + data = shape; + } + + // Constructor from vector (creates a copy) + template NVTEShapeWrapper(const std::vector& shape_vec) { + data.ndim = shape_vec.size(); + for (size_t i = 0; i < data.ndim; ++i) { + data.data[i] = static_cast(shape_vec[i]); + } + } + + operator NVTEShape&() { return data; } + operator const NVTEShape&() const { return data; } + + // Iterator support + size_t* begin() { return data.data; } + const size_t* begin() const { return data.data; } + size_t* end() { return data.data + data.ndim; } + const size_t* end() const { return data.data + data.ndim; } + + // Index access + size_t& operator[](size_t idx) { return data.data[idx]; } + const size_t& operator[](size_t idx) const { return data.data[idx]; } + + // Back access + size_t& back() { return data.data[data.ndim - 1]; } + const size_t& back() const { return data.data[data.ndim - 1]; } + + // Size access + size_t size() const { return data.ndim; } + bool empty() const { return data.ndim == 0; } + + // Container operations + void push_back(size_t value) { + if (data.ndim < 15) { + data.data[data.ndim++] = value; + } + } + + void clear() { data.ndim = 0; } + + void resize(size_t new_size) { + if (new_size <= 15) { + data.ndim = new_size; + } + } +}; + /*! \enum DType * \brief TE datatype. */ diff --git a/transformer_engine/pytorch/csrc/common.cpp b/transformer_engine/pytorch/csrc/common.cpp index 04ae78c0b78..66b1e227c25 100644 --- a/transformer_engine/pytorch/csrc/common.cpp +++ b/transformer_engine/pytorch/csrc/common.cpp @@ -114,11 +114,6 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor( return transformer_engine::TensorWrapper(data_ptr, shape, type); } -transformer_engine::TensorWrapper makeTransformerEngineTensor( - void* data_ptr, const std::vector& shape, const transformer_engine::DType type) { - return transformer_engine::TensorWrapper(data_ptr, shape, type); -} - transformer_engine::TensorWrapper makeTransformerEngineTensor(at::Tensor tensor) { transformer_engine::DType dtype = GetTransformerEngineDType(tensor.scalar_type()); NVTEShape shape = getTensorShape(tensor); @@ -162,21 +157,6 @@ makeTransformerEngineTensorList(std::vector> at_tensor_l std::move(nvte_tensor_list_ptrs), num_lists, num_tensors); } -transformer_engine::TensorWrapper makeTransformerEngineTensor( - void* data_ptr, const std::vector& shape, const transformer_engine::DType type, - void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, std::vector scale_inv_shape, - NVTEScalingMode scaling_mode) { - TensorWrapper ret(scaling_mode); - ret.set_rowwise_data(data_ptr, type, shape); - const std::vector meta_shape{1}; - ret.set_amax(amax_ptr, DType::kFloat32, meta_shape); - ret.set_scale(scale_ptr, DType::kFloat32, meta_shape); - auto scale_inv_dtype = - (scaling_mode == NVTE_MXFP8_1D_SCALING) ? 
DType::kFloat8E8M0 : DType::kFloat32; - ret.set_rowwise_scale_inv(scale_inv_ptr, scale_inv_dtype, scale_inv_shape); - return ret; -} - transformer_engine::TensorWrapper makeTransformerEngineTensor( void* data_ptr, const NVTEShape& shape, const transformer_engine::DType type, void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, const NVTEShape& scale_inv_shape, @@ -195,43 +175,6 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor( return ret; } -transformer_engine::TensorWrapper makeTransformerEngineTensor( - void* data_ptr, const std::vector& shape, const transformer_engine::DType type, - void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, const NVTEShape& scale_inv_shape, - NVTEScalingMode scaling_mode) { - TensorWrapper ret(scaling_mode); - ret.set_rowwise_data(data_ptr, type, shape); - NVTEShape meta_shape; - meta_shape.ndim = 1; - meta_shape.data[0] = 1; - ret.set_amax(amax_ptr, DType::kFloat32, meta_shape); - ret.set_scale(scale_ptr, DType::kFloat32, meta_shape); - auto scale_inv_dtype = - (scaling_mode == NVTE_MXFP8_1D_SCALING) ? DType::kFloat8E8M0 : DType::kFloat32; - ret.set_rowwise_scale_inv(scale_inv_ptr, scale_inv_dtype, scale_inv_shape); - return ret; -} - -transformer_engine::TensorWrapper makeTransformerEngineTensor( - void* data_ptr, void* columnwise_data_ptr, const std::vector& shape, - const std::vector& columnwise_shape, const transformer_engine::DType type, - void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, void* columnwise_scale_inv_ptr, - const std::vector& scale_inv_shape, - const std::vector& columnwise_scale_inv_shape, NVTEScalingMode scaling_mode) { - TensorWrapper ret(scaling_mode); - ret.set_rowwise_data(data_ptr, type, shape); - ret.set_columnwise_data(columnwise_data_ptr, type, columnwise_shape); - const std::vector meta_shape{1}; - ret.set_amax(amax_ptr, DType::kFloat32, meta_shape); - ret.set_scale(scale_ptr, DType::kFloat32, meta_shape); - auto scale_inv_dtype = (scaling_mode == NVTE_MXFP8_1D_SCALING) ? DType::kFloat8E8M0 - : (scaling_mode == NVTE_NVFP4_1D_SCALING) ? 
DType::kFloat8E4M3 - : DType::kFloat32; - ret.set_rowwise_scale_inv(scale_inv_ptr, scale_inv_dtype, scale_inv_shape); - ret.set_columnwise_scale_inv(columnwise_scale_inv_ptr, scale_inv_dtype, - columnwise_scale_inv_shape); - return ret; -} transformer_engine::TensorWrapper makeTransformerEngineTensor( void* data_ptr, void* columnwise_data_ptr, const NVTEShape& shape, @@ -287,6 +230,9 @@ template size_t product(const std::vector& shape); template int64_t product(const std::vector& shape); size_t product(const NVTEShape& shape, size_t begin, size_t end) { + if(end == -1) { + end = shape.ndim; + } NVTE_CHECK(begin <= end && end <= shape.ndim, "Attempted to access entries ", begin, " to ", end, " in a shape with ", shape.ndim, " entries"); size_t ret = 1; diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index c5670357fb9..b703cbc6810 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -271,6 +271,11 @@ class Float8BlockQuantizer : public Quantizer { const std::optional& noop_flag = std::nullopt) override; std::vector get_scale_shape(const std::vector& shape, bool columnwise) const; + NVTEShapeWrapper get_scale_shape(const NVTEShapeWrapper& shape, bool columnwise) const; + + private: + template + ShapeT get_scale_shape_impl(const ShapeT& shape, bool columnwise) const; }; class MXFP8Quantizer : public Quantizer { @@ -294,6 +299,11 @@ class MXFP8Quantizer : public Quantizer { const std::optional& noop_flag = std::nullopt) override; std::vector get_scale_shape(const std::vector& shape, bool columnwise) const; + NVTEShapeWrapper get_scale_shape(const NVTEShapeWrapper& shape, bool columnwise) const; + + private: + template + ShapeT get_scale_shape_impl(const ShapeT& shape, bool columnwise) const; }; class NVFP4Quantizer : public Quantizer { @@ -345,8 +355,11 @@ class NVFP4Quantizer : public Quantizer { void quantize_with_amax(TensorWrapper& input, TensorWrapper& out); std::vector get_scale_shape(const std::vector& shape, bool columnwise) const; + NVTEShapeWrapper get_scale_shape(const NVTEShapeWrapper& shape, bool columnwise) const; private: + template + ShapeT get_scale_shape_impl(const ShapeT& shape, bool columnwise) const; void quantize_impl(const TensorWrapper& input, TensorWrapper& out, const std::optional& noop_flag, bool compute_amax); }; @@ -439,32 +452,11 @@ inline transformer_engine::DType GetTransformerEngineDType(int DType_value) { return static_cast(DType_value); } -transformer_engine::TensorWrapper makeTransformerEngineTensor(void* data_ptr, - const std::vector& shape, - const transformer_engine::DType type); - -transformer_engine::TensorWrapper makeTransformerEngineTensor( - void* data_ptr, const std::vector& shape, const transformer_engine::DType type, - void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, std::vector scale_inv_shape = {1}, - NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING); - transformer_engine::TensorWrapper makeTransformerEngineTensor( void* data_ptr, const NVTEShape& shape, const transformer_engine::DType type, void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, const NVTEShape& scale_inv_shape, NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING); -transformer_engine::TensorWrapper makeTransformerEngineTensor( - void* data_ptr, const std::vector& shape, const transformer_engine::DType type, - void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, const NVTEShape& scale_inv_shape, - NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING); - 
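// With the std::vector-based makeTransformerEngineTensor overloads removed in this
// hunk, call sites pass an NVTEShape directly. A minimal before/after sketch
// (pointer and dimension names are illustrative):
//   // before: builds a heap-allocated std::vector per call
//   std::vector<size_t> shape_vec{rows, cols};
//   auto t_before = makeTransformerEngineTensor(ptr, shape_vec, DType::kFloat32);
//   // after: stack-only NVTEShape constructed with nvte_make_shape
//   const size_t dims[2] = {rows, cols};
//   auto t_after = makeTransformerEngineTensor(ptr, nvte_make_shape(dims, 2), DType::kFloat32);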
-transformer_engine::TensorWrapper makeTransformerEngineTensor( - void* data_ptr, void* columnwise_data_ptr, const std::vector& shape, - const std::vector& columnwise_shape, const transformer_engine::DType type, - void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, void* columnwise_scale_inv_ptr, - const std::vector& scale_inv_shape = {1}, - const std::vector& columnwise_scale_inv_shape = {1}, - NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING); transformer_engine::TensorWrapper makeTransformerEngineTensor( void* data_ptr, void* columnwise_data_ptr, const NVTEShape& shape, @@ -492,7 +484,7 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor( template T product(const std::vector& shape); -size_t product(const NVTEShape& shape, size_t begin, size_t end); +size_t product(const NVTEShape& shape, size_t begin=0, size_t end=-1); std::vector nvte_shape_to_vector(const NVTEShape& nvte_shape); diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index 1007dcb80c6..6d6effce6d9 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -163,52 +163,52 @@ std::vector fused_attn_fwd( } if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI) && (Bias.has_value())) { auto bias_sizes = Bias.value().sizes().vec(); - std::vector bias_shape{bias_sizes.begin(), bias_sizes.end()}; - te_Bias = makeTransformerEngineTensor(Bias.value().data_ptr(), bias_shape, DType::kFloat32); + NVTEShapeWrapper bias_shape{bias_sizes}; + te_Bias = makeTransformerEngineTensor(Bias.value().data_ptr(), static_cast(bias_shape), DType::kFloat32); } auto cu_seqlens_q_sizes = cu_seqlens_q.sizes().vec(); - std::vector cu_seqlens_q_shape{cu_seqlens_q_sizes.begin(), cu_seqlens_q_sizes.end()}; + NVTEShapeWrapper cu_seqlens_q_shape{cu_seqlens_q_sizes}; auto cu_seqlens_kv_sizes = cu_seqlens_kv.sizes().vec(); - std::vector cu_seqlens_kv_shape{cu_seqlens_kv_sizes.begin(), cu_seqlens_kv_sizes.end()}; + NVTEShapeWrapper cu_seqlens_kv_shape{cu_seqlens_kv_sizes}; te_cu_seqlens_q = - makeTransformerEngineTensor(cu_seqlens_q.data_ptr(), cu_seqlens_q_shape, DType::kInt32); + makeTransformerEngineTensor(cu_seqlens_q.data_ptr(), static_cast(cu_seqlens_q_shape), DType::kInt32); te_cu_seqlens_kv = - makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), cu_seqlens_kv_shape, DType::kInt32); + makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), static_cast(cu_seqlens_kv_shape), DType::kInt32); if ((cu_seqlens_q_padded.has_value()) && (cu_seqlens_kv_padded.has_value())) { auto cu_seqlens_q_padded_sizes = cu_seqlens_q_padded.value().sizes().vec(); - std::vector cu_seqlens_q_padded_shape{cu_seqlens_q_padded_sizes.begin(), - cu_seqlens_q_padded_sizes.end()}; + NVTEShapeWrapper cu_seqlens_q_padded_shape{cu_seqlens_q_padded_sizes}; auto cu_seqlens_kv_padded_sizes = cu_seqlens_kv_padded.value().sizes().vec(); - std::vector cu_seqlens_kv_padded_shape{cu_seqlens_kv_padded_sizes.begin(), - cu_seqlens_kv_padded_sizes.end()}; + NVTEShapeWrapper cu_seqlens_kv_padded_shape{cu_seqlens_kv_padded_sizes}; te_cu_seqlens_q_padded = makeTransformerEngineTensor(cu_seqlens_q_padded.value().data_ptr(), - cu_seqlens_q_padded_shape, DType::kInt32); + static_cast(cu_seqlens_q_padded_shape), DType::kInt32); te_cu_seqlens_kv_padded = makeTransformerEngineTensor( - cu_seqlens_kv_padded.value().data_ptr(), cu_seqlens_kv_padded_shape, DType::kInt32); + cu_seqlens_kv_padded.value().data_ptr(), 
static_cast(cu_seqlens_kv_padded_shape), DType::kInt32); } - + NVTEShape default_scale_inv_shape; + default_scale_inv_shape.ndim = 1; + default_scale_inv_shape.data[0] = 1; if ((page_table_k.has_value()) && (page_table_v.has_value())) { auto page_table_k_sizes = page_table_k.value().sizes().vec(); - std::vector page_table_k_shape{page_table_k_sizes.begin(), page_table_k_sizes.end()}; + NVTEShapeWrapper page_table_k_shape{page_table_k_sizes}; auto page_table_v_sizes = page_table_v.value().sizes().vec(); - std::vector page_table_v_shape{page_table_v_sizes.begin(), page_table_v_sizes.end()}; + NVTEShapeWrapper page_table_v_shape{page_table_v_sizes}; te_page_table_k = - makeTransformerEngineTensor(page_table_k.value().data_ptr(), page_table_k_shape, - DType::kInt32, nullptr, nullptr, nullptr); + makeTransformerEngineTensor(page_table_k.value().data_ptr(), static_cast(page_table_k_shape), + DType::kInt32, nullptr, nullptr, nullptr, default_scale_inv_shape); te_page_table_v = - makeTransformerEngineTensor(page_table_v.value().data_ptr(), page_table_v_shape, - DType::kInt32, nullptr, nullptr, nullptr); + makeTransformerEngineTensor(page_table_v.value().data_ptr(), static_cast(page_table_v_shape), + DType::kInt32, nullptr, nullptr, nullptr, default_scale_inv_shape); } // softmax offset TensorWrapper te_SoftmaxOffset; if ((softmax_type != NVTE_VANILLA_SOFTMAX) && (SoftmaxOffset.has_value())) { auto SoftmaxOffset_sizes = SoftmaxOffset.value().sizes().vec(); - std::vector SoftmaxOffset_shape{SoftmaxOffset_sizes.begin(), SoftmaxOffset_sizes.end()}; + NVTEShapeWrapper SoftmaxOffset_shape{SoftmaxOffset_sizes}; te_SoftmaxOffset = - makeTransformerEngineTensor(SoftmaxOffset.value().data_ptr(), SoftmaxOffset_shape, - DType::kFloat32, nullptr, nullptr, nullptr); + makeTransformerEngineTensor(SoftmaxOffset.value().data_ptr(), static_cast(SoftmaxOffset_shape), + DType::kFloat32, nullptr, nullptr, nullptr, default_scale_inv_shape); } // extract rng seed and offset @@ -461,30 +461,31 @@ std::vector fused_attn_bwd( // create cu_seqlens tensorwrappers auto cu_seqlens_q_sizes = cu_seqlens_q.sizes().vec(); - std::vector cu_seqlens_q_shape{cu_seqlens_q_sizes.begin(), cu_seqlens_q_sizes.end()}; + NVTEShapeWrapper cu_seqlens_q_shape{cu_seqlens_q_sizes}; auto cu_seqlens_kv_sizes = cu_seqlens_kv.sizes().vec(); - std::vector cu_seqlens_kv_shape{cu_seqlens_kv_sizes.begin(), cu_seqlens_kv_sizes.end()}; + NVTEShapeWrapper cu_seqlens_kv_shape{cu_seqlens_kv_sizes}; + NVTEShape zero_scale_inv_shape; + zero_scale_inv_shape.ndim = 1; + zero_scale_inv_shape.data[0] = 0; TensorWrapper te_cu_seqlens_q, te_cu_seqlens_kv; - te_cu_seqlens_q = makeTransformerEngineTensor(cu_seqlens_q.data_ptr(), cu_seqlens_q_shape, - DType::kInt32, nullptr, nullptr, nullptr); - te_cu_seqlens_kv = makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), cu_seqlens_kv_shape, - DType::kInt32, nullptr, nullptr, nullptr); + te_cu_seqlens_q = makeTransformerEngineTensor(cu_seqlens_q.data_ptr(), static_cast(cu_seqlens_q_shape), + DType::kInt32, nullptr, nullptr, nullptr, zero_scale_inv_shape); + te_cu_seqlens_kv = makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), static_cast(cu_seqlens_kv_shape), + DType::kInt32, nullptr, nullptr, nullptr, zero_scale_inv_shape); TensorWrapper te_cu_seqlens_q_padded, te_cu_seqlens_kv_padded; if ((cu_seqlens_q_padded.has_value()) && (cu_seqlens_kv_padded.has_value())) { auto cu_seqlens_q_padded_sizes = cu_seqlens_q_padded.value().sizes().vec(); - std::vector cu_seqlens_q_padded_shape{cu_seqlens_q_padded_sizes.begin(), - 
cu_seqlens_q_padded_sizes.end()}; + NVTEShapeWrapper cu_seqlens_q_padded_shape{cu_seqlens_q_padded_sizes}; auto cu_seqlens_kv_padded_sizes = cu_seqlens_kv_padded.value().sizes().vec(); - std::vector cu_seqlens_kv_padded_shape{cu_seqlens_kv_padded_sizes.begin(), - cu_seqlens_kv_padded_sizes.end()}; + NVTEShapeWrapper cu_seqlens_kv_padded_shape{cu_seqlens_kv_padded_sizes}; te_cu_seqlens_q_padded = makeTransformerEngineTensor( cu_seqlens_q_padded.value().data_ptr(), - nvte_make_shape(cu_seqlens_q_padded_shape.data(), cu_seqlens_q_padded_shape.size()), + static_cast(cu_seqlens_q_padded_shape), DType::kInt32); te_cu_seqlens_kv_padded = makeTransformerEngineTensor( cu_seqlens_kv_padded.value().data_ptr(), - nvte_make_shape(cu_seqlens_kv_padded_shape.data(), cu_seqlens_kv_padded_shape.size()), + static_cast(cu_seqlens_kv_padded_shape), DType::kInt32); } @@ -494,12 +495,12 @@ std::vector fused_attn_bwd( nvte_aux_tensor_pack.size = Aux_CTX_Tensors.size(); for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) { const std::vector &signed_shape = Aux_CTX_Tensors[i].sizes().vec(); - const std::vector tmp(signed_shape.begin(), signed_shape.end()); + NVTEShapeWrapper tmp(signed_shape); NVTEBasicTensor temp_data = { Aux_CTX_Tensors[i].data_ptr(), static_cast(GetTransformerEngineDType(Aux_CTX_Tensors[i].scalar_type())), - nvte_make_shape(tmp.data(), tmp.size())}; + static_cast(tmp)}; nvte_set_tensor_param(&nvte_aux_tensor_pack.tensors[i], kNVTERowwiseData, &temp_data); } diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index aa9d800c7bb..af04328948b 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -56,9 +56,8 @@ py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::ob TensorWrapper output_cpp; py::object output_py; if (output.is_none()) { - const auto shape = get_tensor_shape(input_cpp); const auto fake_dtype = input_cpp.dtype(); - std::tie(output_cpp, output_py) = quantizer_cpp->create_tensor(shape, fake_dtype); + std::tie(output_cpp, output_py) = quantizer_cpp->create_tensor(input_cpp.shape(), fake_dtype); } else { std::tie(output_cpp, output_py) = quantizer_cpp->convert_and_update_tensor(output); } @@ -180,8 +179,8 @@ std::vector multi_tensor_quantize(const std::vector &ten const auto input_dtype = GetTransformerEngineDType(input_py.scalar_type()); // Construct output tensor - std::vector output_shape(input_shape.data, input_shape.data + input_shape.ndim); - auto [output_cpp, output_py] = quantizer_cpp_list[i]->create_tensor(output_shape, input_dtype); + // std::vector output_shape(input_shape.data, input_shape.data + input_shape.ndim); + auto [output_cpp, output_py] = quantizer_cpp_list[i]->create_tensor(input_shape, input_dtype); output_cpp_list.emplace_back(std::move(output_cpp)); output_py_list.emplace_back(std::move(output_py)); } @@ -195,7 +194,7 @@ std::vector multi_tensor_quantize(const std::vector &ten namespace { std::tuple, std::vector> bulk_allocate_fp8_blockwise_tensors( - std::vector> &shape_list, std::vector &quantizer_py_list, + std::vector &shape_list, std::vector &quantizer_py_list, std::vector &quantizer_cpp_list) { init_extension(); std::tuple, std::vector> retval; @@ -220,7 +219,7 @@ std::tuple, std::vector> bulk_allocate_fp // Helper function to construct tensor view // Note: Deleter holds a shared_ptr for the buffer, so the buffer // will survive until all views are deleted. 
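  // A minimal sketch of that ownership pattern, assuming one byte buffer allocated
  // with at::empty and views carved out at byte offsets (all names illustrative):
  //   auto buffer = std::make_shared<at::Tensor>(at::empty(
  //       {static_cast<int64_t>(total_bytes)},
  //       at::TensorOptions().dtype(at::kByte).device(at::kCUDA)));
  //   at::Tensor view = torch::from_blob(
  //       static_cast<char*>(buffer->data_ptr()) + offset, view_shape_int64,
  //       [buffer](void*) { /* the captured shared_ptr keeps the buffer alive */ },
  //       at::TensorOptions().dtype(at::kByte).device(at::kCUDA));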
- auto make_torch_view = [](std::shared_ptr &buffer, const std::vector &shape, + auto make_torch_view = [](std::shared_ptr &buffer, const NVTEShapeWrapper &shape, size_t offset, at::ScalarType dtype) -> at::Tensor { std::vector shape_int64(shape.begin(), shape.end()); bool is_empty_shape = product(shape) == 0; @@ -235,13 +234,13 @@ std::tuple, std::vector> bulk_allocate_fp // Allocate row-wise data std::vector rowwise_data_list, rowwise_scale_list; - std::vector> rowwise_data_shapes, rowwise_scale_shapes; + std::vector rowwise_data_shapes, rowwise_scale_shapes; if (rowwise_usage) { // Tensor sizes for (size_t i = 0; i < num_tensors; ++i) { rowwise_data_shapes.emplace_back(shape_list[i]); rowwise_scale_shapes.emplace_back( - quantizer_cpp_list[i]->get_scale_shape(shape_list[i], false)); + quantizer_cpp_list[i]->get_scale_shape(rowwise_data_shapes[i], false)); } // Offsets in full buffer @@ -273,7 +272,7 @@ std::tuple, std::vector> bulk_allocate_fp // Allocate column-wise data std::vector columnwise_data_list, columnwise_scale_list; - std::vector> columnwise_data_shapes, columnwise_scale_shapes; + std::vector columnwise_data_shapes, columnwise_scale_shapes; if (columnwise_usage) { // Tensor sizes for (size_t i = 0; i < num_tensors; ++i) { @@ -330,24 +329,26 @@ std::tuple, std::vector> bulk_allocate_fp tensor_py_list.emplace_back(Float8BlockwiseQTensorClass( rowwise_data, rowwise_scale, columnwise_data, columnwise_scale, fp8_dtype, quantizer_py_list[i], is_2D_scaled, Float8BlockScaleTensorFormat::GEMM_READY)); - + NVTEShape zero_shape; + zero_shape.ndim = 1; + zero_shape.data[0] = 0; // Construct C++ tensor tensor_cpp_list.emplace_back(makeTransformerEngineTensor( rowwise_usage ? rowwise_data_list[i].data_ptr() : nullptr, columnwise_usage ? columnwise_data_list[i].data_ptr() : nullptr, - rowwise_usage ? rowwise_data_shapes[i] : std::vector{0}, - columnwise_usage ? columnwise_data_shapes[i] : std::vector{0}, fp8_dtype, nullptr, + rowwise_usage ? static_cast(rowwise_data_shapes[i]) : zero_shape, + columnwise_usage ? static_cast(columnwise_data_shapes[i]) : zero_shape, fp8_dtype, nullptr, nullptr, rowwise_usage ? rowwise_scale_list[i].data_ptr() : nullptr, columnwise_usage ? columnwise_scale_list[i].data_ptr() : nullptr, - rowwise_usage ? rowwise_scale_shapes[i] : std::vector{0}, - columnwise_usage ? columnwise_scale_shapes[i] : std::vector{0}, scaling_mode)); + rowwise_usage ? static_cast(rowwise_scale_shapes[i]) : zero_shape, + columnwise_usage ? static_cast(columnwise_scale_shapes[i]) : zero_shape, scaling_mode)); } return retval; } std::tuple, std::vector> bulk_allocate_mxfp8_tensors( - std::vector> &shape_list, std::vector &quantizer_py_list, + std::vector &shape_list, std::vector &quantizer_py_list, std::vector &quantizer_cpp_list) { init_extension(); std::tuple, std::vector> retval; @@ -371,7 +372,7 @@ std::tuple, std::vector> bulk_allocate_mx // Helper function to construct tensor view // Note: Deleter holds a shared_ptr for the buffer, so the buffer // will survive until all views are deleted. 
- auto make_torch_view = [](std::shared_ptr &buffer, const std::vector &shape, + auto make_torch_view = [](std::shared_ptr &buffer, const NVTEShapeWrapper &shape, size_t offset, at::ScalarType dtype) -> at::Tensor { std::vector shape_int64(shape.begin(), shape.end()); bool is_empty_shape = product(shape) == 0; @@ -386,13 +387,13 @@ std::tuple, std::vector> bulk_allocate_mx // Allocate row-wise data std::vector rowwise_data_list, rowwise_scale_list; - std::vector> rowwise_data_shapes, rowwise_scale_shapes; + std::vector rowwise_data_shapes, rowwise_scale_shapes; if (rowwise_usage) { // Tensor sizes for (size_t i = 0; i < num_tensors; ++i) { rowwise_data_shapes.emplace_back(shape_list[i]); rowwise_scale_shapes.emplace_back( - quantizer_cpp_list[i]->get_scale_shape(shape_list[i], false)); + quantizer_cpp_list[i]->get_scale_shape(rowwise_data_shapes[i], false)); } // Offsets in full buffer @@ -424,7 +425,7 @@ std::tuple, std::vector> bulk_allocate_mx // Allocate column-wise data std::vector columnwise_data_list, columnwise_scale_list; - std::vector> columnwise_data_shapes, columnwise_scale_shapes; + std::vector columnwise_data_shapes, columnwise_scale_shapes; if (columnwise_usage) { // Tensor sizes for (size_t i = 0; i < num_tensors; ++i) { @@ -477,17 +478,19 @@ std::tuple, std::vector> bulk_allocate_mx tensor_py_list.emplace_back(MXFP8TensorClass(rowwise_data, rowwise_scale, columnwise_data, columnwise_scale, fp8_dtype, quantizer_py_list[i])); - + NVTEShape zero_shape; + zero_shape.ndim = 1; + zero_shape.data[0] = 0; // Construct C++ tensor tensor_cpp_list.emplace_back(makeTransformerEngineTensor( rowwise_usage ? rowwise_data_list[i].data_ptr() : nullptr, columnwise_usage ? columnwise_data_list[i].data_ptr() : nullptr, - rowwise_usage ? rowwise_data_shapes[i] : std::vector{0}, - columnwise_usage ? columnwise_data_shapes[i] : std::vector{0}, fp8_dtype, nullptr, + rowwise_usage ? static_cast(rowwise_data_shapes[i]) : zero_shape, + columnwise_usage ? static_cast(columnwise_data_shapes[i]) : zero_shape, fp8_dtype, nullptr, nullptr, rowwise_usage ? rowwise_scale_list[i].data_ptr() : nullptr, columnwise_usage ? columnwise_scale_list[i].data_ptr() : nullptr, - rowwise_usage ? rowwise_scale_shapes[i] : std::vector{0}, - columnwise_usage ? columnwise_scale_shapes[i] : std::vector{0}, scaling_mode)); + rowwise_usage ? static_cast(rowwise_scale_shapes[i]) : zero_shape, + columnwise_usage ? static_cast(columnwise_scale_shapes[i]) : zero_shape, scaling_mode)); } return retval; @@ -497,7 +500,7 @@ std::tuple, std::vector> bulk_allocate_mx // layout: [fp4_data0, ..., fp4_dataN, fp8_scaling0, ..., fp8_scalingN, amax0, ..., amaxN] // amax buffer will be zeroed out by later amax kernels, so we can use empty to allocate std::tuple, std::vector, bool> bulk_allocate_nvfp4_tensors( - std::vector> &shape_list, std::vector &quantizer_py_list, + std::vector &shape_list, std::vector &quantizer_py_list, std::vector &quantizer_cpp_list) { init_extension(); std::tuple, std::vector, bool> retval; @@ -522,7 +525,7 @@ std::tuple, std::vector, bool> bulk_alloc // Helper function to construct tensor view // Note: Deleter holds a shared_ptr for the buffer, so the buffer // will survive until all views are deleted. 
- auto make_torch_view = [](std::shared_ptr &buffer, const std::vector &shape, + auto make_torch_view = [](std::shared_ptr &buffer, const NVTEShapeWrapper &shape, size_t offset, at::ScalarType dtype) -> at::Tensor { std::vector shape_int64(shape.begin(), shape.end()); bool is_empty_shape = product(shape) == 0; @@ -535,9 +538,9 @@ std::tuple, std::vector, bool> bulk_alloc at::device(at::kCUDA).dtype(dtype)); }; - // Lambda function for converting std::vector shape to NVFP4 shape (last dim divided by 2) - auto to_fp4_shape = [](const std::vector &shape) { - std::vector fp4_shape(shape.begin(), shape.end()); + // Lambda function for converting NVTEShapeWrapper shape to NVFP4 shape (last dim divided by 2) + auto to_fp4_shape = [](const NVTEShapeWrapper &shape) { + NVTEShapeWrapper fp4_shape(shape); if (!fp4_shape.empty()) { fp4_shape.back() /= 2; } @@ -546,13 +549,13 @@ std::tuple, std::vector, bool> bulk_alloc // Allocate row-wise data std::vector rowwise_data_list, rowwise_scale_list, amax_rowwise_list; - std::vector> rowwise_data_shapes, rowwise_scale_shapes; + std::vector rowwise_data_shapes, rowwise_scale_shapes; if (rowwise_usage) { // Tensor sizes for (size_t i = 0; i < num_tensors; ++i) { rowwise_data_shapes.emplace_back(shape_list[i]); rowwise_scale_shapes.emplace_back( - quantizer_cpp_list[i]->get_scale_shape(shape_list[i], false)); + quantizer_cpp_list[i]->get_scale_shape(rowwise_data_shapes[i], false)); } // Offsets in full buffer @@ -587,7 +590,9 @@ std::tuple, std::vector, bool> bulk_alloc // Allocate full buffer auto buffer = std::make_shared( at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8))); - + NVTEShape amax_shape; + amax_shape.ndim = 1; + amax_shape.data[0] = 1; // Construct tensor views for (size_t i = 0; i < num_tensors; ++i) { rowwise_data_list.emplace_back(make_torch_view(buffer, to_fp4_shape(rowwise_data_shapes[i]), @@ -595,13 +600,13 @@ std::tuple, std::vector, bool> bulk_alloc rowwise_scale_list.emplace_back( make_torch_view(buffer, rowwise_scale_shapes[i], scale_offsets[i], torch::kUInt8)); amax_rowwise_list.emplace_back( - make_torch_view(buffer, std::vector{1}, amax_offsets[i], torch::kFloat32)); + make_torch_view(buffer, amax_shape, amax_offsets[i], torch::kFloat32)); } } // Allocate column-wise data std::vector columnwise_data_list, columnwise_scale_list, amax_columnwise_list; - std::vector> columnwise_data_shapes, columnwise_scale_shapes; + std::vector columnwise_data_shapes, columnwise_scale_shapes; if (columnwise_usage) { // Tensor sizes for (size_t i = 0; i < num_tensors; ++i) { @@ -649,7 +654,9 @@ std::tuple, std::vector, bool> bulk_alloc // Allocate full buffer auto buffer = std::make_shared( at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8))); - + NVTEShape amax_shape; + amax_shape.ndim = 1; + amax_shape.data[0] = 1; // Construct tensor views for (size_t i = 0; i < num_tensors; ++i) { columnwise_data_list.emplace_back(make_torch_view( @@ -657,7 +664,7 @@ std::tuple, std::vector, bool> bulk_alloc columnwise_scale_list.emplace_back( make_torch_view(buffer, columnwise_scale_shapes[i], scale_offsets[i], torch::kUInt8)); amax_columnwise_list.emplace_back( - make_torch_view(buffer, std::vector{1}, amax_offsets[i], torch::kFloat32)); + make_torch_view(buffer, amax_shape, amax_offsets[i], torch::kFloat32)); } } @@ -682,26 +689,31 @@ std::tuple, std::vector, bool> bulk_alloc // Construct C++ tensor // Use a TensorWrapper variable to hold the output of makeTransformerEngineTensor, // then set the amax and 
amax_columnwise values. + NVTEShape zero_shape, amax_shape; + zero_shape.ndim = 1; + amax_shape.ndim = 1; + zero_shape.data[0] = 0; + amax_shape.data[0] = 1; { auto tensor_wrapper = makeTransformerEngineTensor( rowwise_usage ? rowwise_data_list[i].data_ptr() : nullptr, columnwise_usage ? columnwise_data_list[i].data_ptr() : nullptr, - rowwise_usage ? rowwise_data_shapes[i] : std::vector{0}, - columnwise_usage ? columnwise_data_shapes[i] : std::vector{0}, fp4_dtype, + rowwise_usage ? static_cast(rowwise_data_shapes[i]) : zero_shape, + columnwise_usage ? static_cast(columnwise_data_shapes[i]) : zero_shape, fp4_dtype, /*amax_ptr=*/nullptr, /*scale_ptr=*/nullptr, rowwise_usage ? rowwise_scale_list[i].data_ptr() : nullptr, columnwise_usage ? columnwise_scale_list[i].data_ptr() : nullptr, - rowwise_usage ? rowwise_scale_shapes[i] : std::vector{0}, - columnwise_usage ? columnwise_scale_shapes[i] : std::vector{0}, scaling_mode); - + rowwise_usage ? static_cast(rowwise_scale_shapes[i]) : zero_shape, + columnwise_usage ? static_cast(columnwise_scale_shapes[i]) : zero_shape, scaling_mode); + // Set the amax rowwise and amax columnwise if available if (rowwise_usage) { tensor_wrapper.set_amax(amax_rowwise_list[i].data_ptr(), DType::kFloat32, - std::vector{1}); + amax_shape); } if (columnwise_usage) { tensor_wrapper.set_columnwise_amax(amax_columnwise_list[i].data_ptr(), DType::kFloat32, - std::vector{1}); + amax_shape); } tensor_cpp_list.emplace_back(std::move(tensor_wrapper)); } @@ -765,9 +777,11 @@ static StochasticRngStateResources setup_stochastic_rounding_rng_states_helper( at::PhiloxCudaState philox_args = init_philox_state(gen, rng_elts_per_thread); int64_t *rng_state_ptr = static_cast(res.rng_states_tensor.data_ptr()) + i * 2; philox_unpack(philox_args, rng_state_ptr); - + NVTEShape rng_state_shape; + rng_state_shape.ndim = 1; + rng_state_shape.data[0] = 2; res.te_rng_state_list.push_back(makeTransformerEngineTensor( - static_cast(rng_state_ptr), std::vector{2}, DType::kInt64)); + static_cast(rng_state_ptr), rng_state_shape, DType::kInt64)); quant_config_list_rowwise[i].set_rng_state(res.te_rng_state_list[i].data()); quant_config_list_rowwise[i].set_stochastic_rounding(true); @@ -781,7 +795,7 @@ static StochasticRngStateResources setup_stochastic_rounding_rng_states_helper( philox_unpack(philox_args_col, rng_state_ptr_colwise); res.te_rng_state_list_colwise.push_back(makeTransformerEngineTensor( - static_cast(rng_state_ptr_colwise), std::vector{2}, DType::kInt64)); + static_cast(rng_state_ptr_colwise), rng_state_shape, DType::kInt64)); quant_config_list_colwise[i].set_rng_state(res.te_rng_state_list_colwise[i].data()); quant_config_list_colwise[i].set_stochastic_rounding(true); } @@ -997,18 +1011,21 @@ void split_quantize_nvfp4_impl_helper(const TensorWrapper &input, // Note that the multi compute amax API expects rowwise amax pointer to be not null // So we need to set the pointer accordingly to make colwise-only quantization work std::vector orig_amax_ptr_list; + NVTEShape amax_shape; + amax_shape.ndim = 1; + amax_shape.data[0] = 1; for (size_t i = 0; i < num_tensors; i++) { auto rowwise_amax_ptr = output_list[i].get_amax().data_ptr; orig_amax_ptr_list.push_back(rowwise_amax_ptr); auto columnwise_amax_ptr = output_list[i].get_columnwise_amax().data_ptr; void *amax_ptr = rowwise_amax_ptr != nullptr ? 
rowwise_amax_ptr : columnwise_amax_ptr; NVTE_CHECK(amax_ptr != nullptr, "Could not find amax pointer"); - output_list[i].set_amax(amax_ptr, DType::kFloat32, std::vector{1}); + output_list[i].set_amax(amax_ptr, DType::kFloat32, amax_shape); } nvte_group_amax(input.data(), reinterpret_cast(nvte_tensor_output_list.data()), split_sections.data(), num_tensors, stream); for (size_t i = 0; i < num_tensors; i++) { - output_list[i].set_amax(orig_amax_ptr_list[i], DType::kFloat32, std::vector{1}); + output_list[i].set_amax(orig_amax_ptr_list[i], DType::kFloat32, amax_shape); } // Quantize tensors individually @@ -1104,7 +1121,7 @@ std::vector split_quantize(const at::Tensor &tensor, auto input_py = tensor.contiguous(); uint8_t *input_dptr = reinterpret_cast(input_py.data_ptr()); auto input_dtype = GetTransformerEngineDType(input_py.scalar_type()); - std::vector input_shape; + NVTEShapeWrapper input_shape; size_t input_size = 1; for (const auto &d : input_py.sizes()) { input_shape.push_back(d); @@ -1114,7 +1131,7 @@ std::vector split_quantize(const at::Tensor &tensor, // Split input tensor along dim 0 std::vector input_list; - std::vector> split_shapes; + std::vector split_shapes; size_t dim0_offset = 0; const size_t dim0_stride = input_shape[0] == 0 ? 0 : input_py.element_size() * input_size / input_shape[0]; @@ -1122,11 +1139,14 @@ std::vector split_quantize(const at::Tensor &tensor, NVTE_CHECK(dim0_offset + split_sections[i] <= input_shape[0], "Attempted to split tensor with shape=", input_shape, " along dim 0 with split_sections=", split_sections); - split_shapes.push_back(input_shape); + split_shapes.emplace_back(); auto &split_shape = split_shapes.back(); - split_shape[0] = split_sections[i]; + split_shape.push_back(split_sections[i]); + for (size_t j = 1; j < input_shape.size(); ++j) { + split_shape.push_back(input_shape[j]); + } void *split_dptr = static_cast(input_dptr + dim0_offset * dim0_stride); - input_list.emplace_back(makeTransformerEngineTensor(split_dptr, split_shape, input_dtype)); + input_list.emplace_back(makeTransformerEngineTensor(split_dptr, static_cast(split_shape), input_dtype)); dim0_offset += split_sections[i]; } diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp index fa01af53fe8..07acd44170a 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp @@ -406,7 +406,7 @@ void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, DType A_type, counter_shape.data[0] = static_cast(counter.size(0)); auto te_counter = makeTransformerEngineTensor(counter.data_ptr(), counter_shape, DType::kInt32); - NVTEShape gelu_shape; + NVTEShape gelu_shape, workspace_shape; if (pre_gelu_out.data_ptr() == nullptr) { gelu_shape.ndim = 1; gelu_shape.data[0] = static_cast(pre_gelu_out.size(0)); @@ -415,10 +415,12 @@ void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, DType A_type, gelu_shape.data[0] = static_cast(pre_gelu_out.size(0)); gelu_shape.data[1] = static_cast(pre_gelu_out.size(1)); } + workspace_shape.ndim = 1; + workspace_shape.data[0] = workspaceSize; auto te_pre_gelu_out = makeTransformerEngineTensor( pre_gelu_out.data_ptr(), gelu_shape, GetTransformerEngineDType(pre_gelu_out.scalar_type())); auto te_workspace = makeTransformerEngineTensor(workspace.data_ptr(), - std::vector{workspaceSize}, DType::kByte); + workspace_shape, DType::kByte); NVTE_SCOPED_GIL_RELEASE({ nvte_cublas_atomic_gemm(te_A.data(), te_B.data(), te_D.data(), te_bias.data(), 
@@ -509,10 +511,14 @@ std::optional> te_general_grouped_gemm( auto te_bias = makeTransformerEngineTensor(bias[i]); auto te_pre_gelu_out = makeTransformerEngineTensor(pre_gelu_out[i]); - const auto gelu_shape = pre_gelu_out[i].data_ptr() == nullptr - ? std::vector{static_cast(te_pre_gelu_out.size(0))} - : std::vector{static_cast(te_pre_gelu_out.size(0)), - static_cast(te_pre_gelu_out.size(1))}; + NVTEShape gelu_shape; + gelu_shape.data[0] = te_pre_gelu_out.size(0); + if (pre_gelu_out[i].data_ptr() == nullptr) { + gelu_shape.ndim = 1; + } else { + gelu_shape.ndim = 2; + gelu_shape.data[1] = te_pre_gelu_out.size(1); + } DType gelu_type = bias_type; te_pre_gelu_out = @@ -579,9 +585,12 @@ std::optional> te_general_grouped_gemm( std::vector te_workspace_vector; std::vector te_workspace_wrappers; + NVTEShape workspace_shape; + workspace_shape.ndim = 1; + workspace_shape.data[0] = workspaceSize; for (size_t i = 0; i < workspace.size(); i++) { auto wsp = makeTransformerEngineTensor(workspace[i].data_ptr(), - std::vector{workspaceSize}, DType::kByte); + workspace_shape, DType::kByte); te_workspace_vector.emplace_back(wsp.data()); te_workspace_wrappers.emplace_back(std::move(wsp)); } diff --git a/transformer_engine/pytorch/csrc/extensions/padding.cpp b/transformer_engine/pytorch/csrc/extensions/padding.cpp index 389308405b1..cabb65233f7 100644 --- a/transformer_engine/pytorch/csrc/extensions/padding.cpp +++ b/transformer_engine/pytorch/csrc/extensions/padding.cpp @@ -20,7 +20,7 @@ void fused_multi_row_padding(at::Tensor input, at::Tensor output, const auto num_tensors = input_row_list.size(); // Extract properties from PyTorch tensors std::vector input_dptr_list, output_dptr_list; - std::vector> input_shape_list, output_shape_list; + std::vector input_shape_list, output_shape_list; std::vector input_type_list; void* d_input_ptr = reinterpret_cast(input.data_ptr()); void* d_output_ptr = reinterpret_cast(output.data_ptr()); @@ -34,8 +34,11 @@ void fused_multi_row_padding(at::Tensor input, at::Tensor output, input_row_list[tensor_id] * input.size(1) * input.element_size(); input_char_ptr += input_dptr_offset; d_input_ptr = reinterpret_cast(input_char_ptr); - NVTEShape input_shape = {input_row_list[tensor_id], static_cast(input.size(1))}; - input_shape_list.push_back({input_row_list[tensor_id], static_cast(input.size(1))}); + NVTEShape input_shape; + input_shape.ndim = 2; + input_shape.data[0] = input_row_list[tensor_id]; + input_shape.data[1] = static_cast(input.size(1)); + input_shape_list.push_back(input_shape); input_type_list.push_back(GetTransformerEngineDType(input.scalar_type())); // Move the output pointer to the next split. 
@@ -45,14 +48,17 @@ void fused_multi_row_padding(at::Tensor input, at::Tensor output, output_char_ptr += output_dptr_offset; d_output_ptr = reinterpret_cast(output_char_ptr); - output_shape_list.push_back( - {padded_input_row_list[tensor_id], static_cast(output.size(1))}); + NVTEShape output_shape; + output_shape.ndim = 2; + output_shape.data[0] = padded_input_row_list[tensor_id]; + output_shape.data[1] = static_cast(output.size(1)); + output_shape_list.push_back(output_shape); } // Construct TE tensors std::vector nvte_input_list, nvte_output_list; std::vector tensor_wrappers; - auto make_tensor = [&tensor_wrappers](void* dptr, const std::vector& shape, + auto make_tensor = [&tensor_wrappers](void* dptr, const NVTEShape& shape, DType dtype) -> NVTETensor { tensor_wrappers.emplace_back(makeTransformerEngineTensor(dptr, shape, dtype)); return tensor_wrappers.back().data(); @@ -95,7 +101,7 @@ void fused_multi_row_unpadding(at::Tensor input, at::Tensor output, const auto num_tensors = input_row_list.size(); // Extract properties from PyTorch tensors std::vector input_dptr_list, output_dptr_list; - std::vector> input_shape_list, output_shape_list; + std::vector input_shape_list, output_shape_list; std::vector input_type_list; void* d_input_ptr = reinterpret_cast(input.data_ptr()); void* d_output_ptr = reinterpret_cast(output.data_ptr()); @@ -109,8 +115,11 @@ void fused_multi_row_unpadding(at::Tensor input, at::Tensor output, input_row_list[tensor_id] * input.size(1) * input.element_size(); input_char_ptr += input_dptr_offset; d_input_ptr = reinterpret_cast(input_char_ptr); - - input_shape_list.push_back({input_row_list[tensor_id], static_cast(input.size(1))}); + NVTEShape input_shape; + input_shape.ndim = 2; + input_shape.data[0] = input_row_list[tensor_id]; + input_shape.data[1] = static_cast(input.size(1)); + input_shape_list.push_back(input_shape); input_type_list.push_back(GetTransformerEngineDType(input.scalar_type())); // Move the output pointer to the next split. 
@@ -120,14 +129,17 @@ void fused_multi_row_unpadding(at::Tensor input, at::Tensor output, output_char_ptr += output_dptr_offset; d_output_ptr = reinterpret_cast(output_char_ptr); - output_shape_list.push_back( - {unpadded_input_row_list[tensor_id], static_cast(output.size(1))}); + NVTEShape output_shape; + output_shape.ndim = 2; + output_shape.data[0] = unpadded_input_row_list[tensor_id]; + output_shape.data[1] = static_cast(output.size(1)); + output_shape_list.push_back(output_shape); } // Construct TE tensors std::vector nvte_input_list, nvte_output_list; std::vector tensor_wrappers; - auto make_tensor = [&tensor_wrappers](void* dptr, const std::vector& shape, + auto make_tensor = [&tensor_wrappers](void* dptr, const NVTEShape& shape, transformer_engine::DType dtype) -> NVTETensor { tensor_wrappers.emplace_back(makeTransformerEngineTensor(dptr, shape, dtype)); return tensor_wrappers.back().data(); diff --git a/transformer_engine/pytorch/csrc/extensions/permutation.cpp b/transformer_engine/pytorch/csrc/extensions/permutation.cpp index 97cf4008511..189c7127731 100644 --- a/transformer_engine/pytorch/csrc/extensions/permutation.cpp +++ b/transformer_engine/pytorch/csrc/extensions/permutation.cpp @@ -60,18 +60,25 @@ std::tuple> moe_permute_fwd( {num_tokens * topK}, torch::dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false)); auto stream = at::cuda::getCurrentCUDAStream().stream(); - + NVTEShape input_shape, permuted_output_shape, sorted_row_id_cu_shape; + input_shape.ndim = 2; + permuted_output_shape.ndim = 2; + sorted_row_id_cu_shape.ndim = 1; + input_shape.data[0] = static_cast(input.size(0)); + input_shape.data[1] = static_cast(input.size(1)); + permuted_output_shape.data[0] = static_cast(permuted_output.size(0)); + permuted_output_shape.data[1] = static_cast(permuted_output.size(1)); + sorted_row_id_cu_shape.data[0] = static_cast(num_tokens * topK); auto input_cu = makeTransformerEngineTensor( input.data_ptr(), - std::vector{static_cast(input.size(0)), static_cast(num_cols)}, + input_shape, dtype); auto permuted_output_cu = makeTransformerEngineTensor(permuted_output.data_ptr(), - std::vector{static_cast(permuted_output.size(0)), - static_cast(num_cols)}, + permuted_output_shape, dtype); auto sorted_row_id_cu = makeTransformerEngineTensor( - sorted_row_id_ptr, std::vector{static_cast(num_tokens * topK)}, + sorted_row_id_ptr, sorted_row_id_cu_shape, DType::kInt32); auto row_id_map_cu = makeTransformerEngineTensor(row_id_map); @@ -97,15 +104,20 @@ at::Tensor moe_unpermute_fwd(at::Tensor input, const DType dtype, at::Tensor row torch::dtype(input.scalar_type()).device(torch::kCUDA).requires_grad(false)); auto stream = at::cuda::getCurrentCUDAStream().stream(); - + NVTEShape input_shape, unpermuted_output_shape; + input_shape.ndim = 2; + unpermuted_output_shape.ndim = 2; + input_shape.data[0] = static_cast(input.size(0)); + input_shape.data[1] = static_cast(input.size(1)); + unpermuted_output_shape.data[0] = static_cast(unpermuted_output.size(0)); + unpermuted_output_shape.data[1] = static_cast(num_cols); auto input_cu = makeTransformerEngineTensor( input.data_ptr(), - std::vector{static_cast(input.size(0)), static_cast(num_cols)}, + input_shape, dtype); auto unpermuted_output_cu = makeTransformerEngineTensor( unpermuted_output.data_ptr(), - std::vector{static_cast(unpermuted_output.size(0)), - static_cast(num_cols)}, + unpermuted_output_shape, dtype); auto row_id_map_cu = makeTransformerEngineTensor(row_id_map); auto prob_cu = makeTransformerEngineTensor(prob); @@ -131,18 +143,27 
@@ std::tuple moe_unpermute_bwd(at::Tensor input_bwd, at::T {num_tokens, topK}, torch::dtype(torch::kFloat32).device(torch::kCUDA).requires_grad(false)); auto stream = at::cuda::getCurrentCUDAStream().stream(); - + NVTEShape input_bwd_shape, act_grad_shape, input_fwd_shape; + input_bwd_shape.ndim = 2; + act_grad_shape.ndim = 2; + input_fwd_shape.ndim = 2; + input_bwd_shape.data[0] = static_cast(input_bwd.size(0)); + input_bwd_shape.data[1] = static_cast(num_cols); + act_grad_shape.data[0] = static_cast(act_grad.size(0)); + act_grad_shape.data[1] = static_cast(num_cols); + input_fwd_shape.data[0] = static_cast(input_fwd.size(0)); + input_fwd_shape.data[1] = static_cast(num_cols); auto input_bwd_cu = makeTransformerEngineTensor( input_bwd.data_ptr(), - std::vector{static_cast(input_bwd.size(0)), static_cast(num_cols)}, + input_bwd_shape, dtype); auto act_grad_cu = makeTransformerEngineTensor( act_grad.data_ptr(), - std::vector{static_cast(act_grad.size(0)), static_cast(num_cols)}, + act_grad_shape, dtype); auto input_fwd_cu = makeTransformerEngineTensor( input_fwd.data_ptr(), - std::vector{static_cast(input_fwd.size(0)), static_cast(num_cols)}, + input_fwd_shape, dtype); auto row_id_map_cu = makeTransformerEngineTensor(row_id_map); auto prob_cu = makeTransformerEngineTensor(prob); diff --git a/transformer_engine/pytorch/csrc/extensions/transpose.cpp b/transformer_engine/pytorch/csrc/extensions/transpose.cpp index 55c7fd57d79..6c1ff313c5c 100644 --- a/transformer_engine/pytorch/csrc/extensions/transpose.cpp +++ b/transformer_engine/pytorch/csrc/extensions/transpose.cpp @@ -45,9 +45,16 @@ at::Tensor fp8_transpose(at::Tensor input, DType otype, std::optional{M, N}, otype); - auto output_cu = makeTransformerEngineTensor(out.data_ptr(), std::vector{N, M}, otype); + auto input_cu = makeTransformerEngineTensor(input.data_ptr(), input_shape, otype); + auto output_cu = makeTransformerEngineTensor(out.data_ptr(), output_shape, otype); nvte_transpose(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); return out; diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 6345ae3894c..650847dd86d 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -869,12 +869,24 @@ void Float8BlockQuantizer::quantize(const TensorWrapper& input, TensorWrapper& o std::vector Float8BlockQuantizer::get_scale_shape(const std::vector& shape, bool columnwise) const { + return get_scale_shape_impl(shape, columnwise); +} + +NVTEShapeWrapper Float8BlockQuantizer::get_scale_shape(const NVTEShapeWrapper& shape, + bool columnwise) const { + return get_scale_shape_impl(shape, columnwise); +} + +template +ShapeT Float8BlockQuantizer::get_scale_shape_impl(const ShapeT& shape, bool columnwise) const { size_t numel = 1; + size_t k_dim; + for (auto s : shape) { numel *= s; } + k_dim = shape.size() == 0 ? 1u : shape.back(); - size_t k_dim = shape.size() == 0 ? 
1u : shape.back(); size_t m_dim = numel / k_dim; constexpr size_t kBlockLen = 128; @@ -882,27 +894,20 @@ std::vector Float8BlockQuantizer::get_scale_shape(const std::vector scale_shape; - + size_t sinv0 = 0; + size_t sinv1 = 0; bool rowwise_usage = !columnwise; if (rowwise_usage) { - // rowwise scaling factor shape - size_t sinv0 = 0; - size_t sinv1 = 0; if (block_scaling_dim == 2) { - // 2D scaling is always GEMM_READY for now NVTE_CHECK(data_format == Float8BlockScaleTensorFormat::GEMM_READY, "2D scaling is always GEMM_READY for now."); sinv0 = (m_dim + kBlockLen - 1) / kBlockLen; sinv1 = roundup((k_dim + kBlockLen - 1) / kBlockLen, 4); } else if (block_scaling_dim == 1) { - // 1D scaling can be GEMM_READY or COMPACT bool rowwise_compact = data_format == Float8BlockScaleTensorFormat::COMPACT; - // default rowwise scaling factor shape already transpose the scaling factor so it's GEMM_READY sinv0 = (k_dim + kBlockLen - 1) / kBlockLen; sinv1 = rowwise_compact ? m_dim : roundup(m_dim, 4); - // if the rowwise format is compact, the scaling factor is not be transposed if (rowwise_compact) { std::swap(sinv0, sinv1); } @@ -912,13 +917,8 @@ std::vector Float8BlockQuantizer::get_scale_shape(const std::vector Float8BlockQuantizer::get_scale_shape(const std::vector MXFP8Quantizer::get_scale_shape(const std::vector& shape, bool columnwise) const { + return get_scale_shape_impl(shape, columnwise); +} + +NVTEShapeWrapper MXFP8Quantizer::get_scale_shape(const NVTEShapeWrapper& shape, + bool columnwise) const { + return get_scale_shape_impl(shape, columnwise); +} + +template +ShapeT MXFP8Quantizer::get_scale_shape_impl(const ShapeT& shape, bool columnwise) const { size_t numel = 1; + size_t last_dim; + for (auto s : shape) { numel *= s; } - - auto last_dim = shape.back(); + last_dim = shape.back(); NVTE_CHECK(last_dim % MXFP8_BLOCK_SIZE == 0 && (numel / last_dim) % MXFP8_BLOCK_SIZE == 0, "MXFP8 requires tensor dims that are divisible by ", MXFP8_BLOCK_SIZE, " (got shape=", shape, ")"); - std::vector scale_shape; - + size_t sinv0 = 0; + size_t sinv1 = 0; bool rowwise_usage = !columnwise; if (rowwise_usage) { - // rowwise scaling factor shape - size_t sinv0 = roundup(numel / last_dim, 128); - size_t sinv1 = roundup(last_dim / MXFP8_BLOCK_SIZE, 4); - scale_shape = {sinv0, sinv1}; + sinv0 = roundup(numel / last_dim, 128); + sinv1 = roundup(last_dim / MXFP8_BLOCK_SIZE, 4); } else { - // columnwise scaling factor shape - size_t sinv0 = roundup(numel / (last_dim * MXFP8_BLOCK_SIZE), 4); - size_t sinv1 = roundup(last_dim, 128); - scale_shape = {sinv0, sinv1}; + sinv0 = roundup(numel / (last_dim * MXFP8_BLOCK_SIZE), 4); + sinv1 = roundup(last_dim, 128); } - return scale_shape; + + ShapeT result; + result.resize(2); + result[0] = sinv0; + result[1] = sinv1; + return result; } NVFP4Quantizer::NVFP4Quantizer(const py::handle& quantizer) : Quantizer(quantizer) { @@ -1766,12 +1781,24 @@ void NVFP4Quantizer::quantize_with_amax(TensorWrapper& input, TensorWrapper& out std::vector NVFP4Quantizer::get_scale_shape(const std::vector& shape, bool columnwise) const { + return get_scale_shape_impl(shape, columnwise); +} + +NVTEShapeWrapper NVFP4Quantizer::get_scale_shape(const NVTEShapeWrapper& shape, + bool columnwise) const { + return get_scale_shape_impl(shape, columnwise); +} + +template +ShapeT NVFP4Quantizer::get_scale_shape_impl(const ShapeT& shape, bool columnwise) const { size_t numel = 1; + size_t last_dim; + for (auto s : shape) { numel *= s; } + last_dim = shape.back(); - auto last_dim = shape.back(); auto 
flat_first_dim = numel / last_dim; NVTE_CHECK(last_dim % NVFP4_BLOCK_SIZE == 0, "Last dim for NVFP4 must be divisible by ", @@ -1780,22 +1807,23 @@ std::vector NVFP4Quantizer::get_scale_shape(const std::vector& s "NVFP4 requires tensor dims that are divisible by ", NVFP4_BLOCK_SIZE, " (got shape=", shape, ")"); - std::vector scale_shape; - + size_t sinv0 = 0; + size_t sinv1 = 0; bool rowwise_usage = !columnwise; if (rowwise_usage) { - // rowwise scaling factor shape - size_t sinv0 = roundup(flat_first_dim, 128); - size_t sinv1 = roundup(last_dim / NVFP4_BLOCK_SIZE, 4); - scale_shape = {sinv0, sinv1}; + sinv0 = roundup(flat_first_dim, 128); + sinv1 = roundup(last_dim / NVFP4_BLOCK_SIZE, 4); } else { - // columnwise scaling factor shape - size_t sinv0 = roundup(last_dim, 128); - size_t sinv1 = roundup(flat_first_dim / NVFP4_BLOCK_SIZE, 4); - scale_shape = {sinv0, sinv1}; + sinv0 = roundup(last_dim, 128); + sinv1 = roundup(flat_first_dim / NVFP4_BLOCK_SIZE, 4); } - return scale_shape; + + ShapeT result; + result.resize(2); + result[0] = sinv0; + result[1] = sinv1; + return result; } } // namespace transformer_engine::pytorch From a66f46b831f9665d6edf0ed9109ac6c52c075b73 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 30 Dec 2025 12:42:50 +0000 Subject: [PATCH 10/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../transformer_engine/transformer_engine.h | 31 ++++---- transformer_engine/pytorch/csrc/common.cpp | 3 +- transformer_engine/pytorch/csrc/common.h | 8 +-- .../pytorch/csrc/extensions/attention.cpp | 55 +++++++------- .../pytorch/csrc/extensions/cast.cpp | 44 ++++++------ .../pytorch/csrc/extensions/gemm.cpp | 16 ++--- .../pytorch/csrc/extensions/permutation.cpp | 40 +++-------- transformer_engine/pytorch/csrc/quantizer.cpp | 72 +++++++++---------- 8 files changed, 125 insertions(+), 144 deletions(-) diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index f9eb244cd9c..26a07c707a6 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -527,39 +527,36 @@ class NVTEShapeWrapper { public: // Default constructor - NVTEShapeWrapper() { - data.ndim = 0; - } + NVTEShapeWrapper() { data.ndim = 0; } // Constructor from NVTEShape (direct assignment by reference) - NVTEShapeWrapper(const NVTEShape& shape) { - data = shape; - } + NVTEShapeWrapper(const NVTEShape &shape) { data = shape; } // Constructor from vector (creates a copy) - template NVTEShapeWrapper(const std::vector& shape_vec) { + template + NVTEShapeWrapper(const std::vector &shape_vec) { data.ndim = shape_vec.size(); for (size_t i = 0; i < data.ndim; ++i) { data.data[i] = static_cast(shape_vec[i]); } } - operator NVTEShape&() { return data; } - operator const NVTEShape&() const { return data; } + operator NVTEShape &() { return data; } + operator const NVTEShape &() const { return data; } // Iterator support - size_t* begin() { return data.data; } - const size_t* begin() const { return data.data; } - size_t* end() { return data.data + data.ndim; } - const size_t* end() const { return data.data + data.ndim; } + size_t *begin() { return data.data; } + const size_t *begin() const { return data.data; } + size_t *end() { return data.data + data.ndim; } + const size_t *end() const { return 
data.data + data.ndim; } // Index access - size_t& operator[](size_t idx) { return data.data[idx]; } - const size_t& operator[](size_t idx) const { return data.data[idx]; } + size_t &operator[](size_t idx) { return data.data[idx]; } + const size_t &operator[](size_t idx) const { return data.data[idx]; } // Back access - size_t& back() { return data.data[data.ndim - 1]; } - const size_t& back() const { return data.data[data.ndim - 1]; } + size_t &back() { return data.data[data.ndim - 1]; } + const size_t &back() const { return data.data[data.ndim - 1]; } // Size access size_t size() const { return data.ndim; } diff --git a/transformer_engine/pytorch/csrc/common.cpp b/transformer_engine/pytorch/csrc/common.cpp index 66b1e227c25..33060732777 100644 --- a/transformer_engine/pytorch/csrc/common.cpp +++ b/transformer_engine/pytorch/csrc/common.cpp @@ -175,7 +175,6 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor( return ret; } - transformer_engine::TensorWrapper makeTransformerEngineTensor( void* data_ptr, void* columnwise_data_ptr, const NVTEShape& shape, const NVTEShape& columnwise_shape, const transformer_engine::DType type, void* amax_ptr, @@ -230,7 +229,7 @@ template size_t product(const std::vector& shape); template int64_t product(const std::vector& shape); size_t product(const NVTEShape& shape, size_t begin, size_t end) { - if(end == -1) { + if (end == -1) { end = shape.ndim; } NVTE_CHECK(begin <= end && end <= shape.ndim, "Attempted to access entries ", begin, " to ", end, diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index b703cbc6810..39b73c96443 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -142,7 +142,8 @@ class NoneQuantizer : public Quantizer { std::pair create_tensor(const std::vector& shape, DType dtype, at::Tensor data) const; - std::pair create_tensor(const NVTEShape& shape, DType dtype) const override; + std::pair create_tensor(const NVTEShape& shape, + DType dtype) const override; /*! 
@brief Construct a tensor with pre-initialized data */ std::pair create_tensor(const NVTEShape& shape, DType dtype, @@ -457,7 +458,6 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor( void* scale_ptr, void* scale_inv_ptr, const NVTEShape& scale_inv_shape, NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING); - transformer_engine::TensorWrapper makeTransformerEngineTensor( void* data_ptr, void* columnwise_data_ptr, const NVTEShape& shape, const NVTEShape& columnwise_shape, const transformer_engine::DType type, void* amax_ptr, @@ -484,7 +484,7 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor( template T product(const std::vector& shape); -size_t product(const NVTEShape& shape, size_t begin=0, size_t end=-1); +size_t product(const NVTEShape& shape, size_t begin = 0, size_t end = -1); std::vector nvte_shape_to_vector(const NVTEShape& nvte_shape); @@ -560,4 +560,4 @@ inline string to_string(const NVTEShape& s) { } } // namespace std -#endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_COMMON_H_ \ No newline at end of file +#endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_COMMON_H_ diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index 6d6effce6d9..bfa989c26c0 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -164,26 +164,29 @@ std::vector fused_attn_fwd( if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI) && (Bias.has_value())) { auto bias_sizes = Bias.value().sizes().vec(); NVTEShapeWrapper bias_shape{bias_sizes}; - te_Bias = makeTransformerEngineTensor(Bias.value().data_ptr(), static_cast(bias_shape), DType::kFloat32); + te_Bias = makeTransformerEngineTensor(Bias.value().data_ptr(), + static_cast(bias_shape), DType::kFloat32); } auto cu_seqlens_q_sizes = cu_seqlens_q.sizes().vec(); NVTEShapeWrapper cu_seqlens_q_shape{cu_seqlens_q_sizes}; auto cu_seqlens_kv_sizes = cu_seqlens_kv.sizes().vec(); NVTEShapeWrapper cu_seqlens_kv_shape{cu_seqlens_kv_sizes}; - te_cu_seqlens_q = - makeTransformerEngineTensor(cu_seqlens_q.data_ptr(), static_cast(cu_seqlens_q_shape), DType::kInt32); - te_cu_seqlens_kv = - makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), static_cast(cu_seqlens_kv_shape), DType::kInt32); + te_cu_seqlens_q = makeTransformerEngineTensor( + cu_seqlens_q.data_ptr(), static_cast(cu_seqlens_q_shape), DType::kInt32); + te_cu_seqlens_kv = makeTransformerEngineTensor( + cu_seqlens_kv.data_ptr(), static_cast(cu_seqlens_kv_shape), DType::kInt32); if ((cu_seqlens_q_padded.has_value()) && (cu_seqlens_kv_padded.has_value())) { auto cu_seqlens_q_padded_sizes = cu_seqlens_q_padded.value().sizes().vec(); NVTEShapeWrapper cu_seqlens_q_padded_shape{cu_seqlens_q_padded_sizes}; auto cu_seqlens_kv_padded_sizes = cu_seqlens_kv_padded.value().sizes().vec(); NVTEShapeWrapper cu_seqlens_kv_padded_shape{cu_seqlens_kv_padded_sizes}; - te_cu_seqlens_q_padded = makeTransformerEngineTensor(cu_seqlens_q_padded.value().data_ptr(), - static_cast(cu_seqlens_q_padded_shape), DType::kInt32); + te_cu_seqlens_q_padded = makeTransformerEngineTensor( + cu_seqlens_q_padded.value().data_ptr(), static_cast(cu_seqlens_q_padded_shape), + DType::kInt32); te_cu_seqlens_kv_padded = makeTransformerEngineTensor( - cu_seqlens_kv_padded.value().data_ptr(), static_cast(cu_seqlens_kv_padded_shape), DType::kInt32); + cu_seqlens_kv_padded.value().data_ptr(), + static_cast(cu_seqlens_kv_padded_shape), DType::kInt32); } NVTEShape 
default_scale_inv_shape; default_scale_inv_shape.ndim = 1; @@ -193,12 +196,12 @@ std::vector fused_attn_fwd( NVTEShapeWrapper page_table_k_shape{page_table_k_sizes}; auto page_table_v_sizes = page_table_v.value().sizes().vec(); NVTEShapeWrapper page_table_v_shape{page_table_v_sizes}; - te_page_table_k = - makeTransformerEngineTensor(page_table_k.value().data_ptr(), static_cast(page_table_k_shape), - DType::kInt32, nullptr, nullptr, nullptr, default_scale_inv_shape); - te_page_table_v = - makeTransformerEngineTensor(page_table_v.value().data_ptr(), static_cast(page_table_v_shape), - DType::kInt32, nullptr, nullptr, nullptr, default_scale_inv_shape); + te_page_table_k = makeTransformerEngineTensor( + page_table_k.value().data_ptr(), static_cast(page_table_k_shape), + DType::kInt32, nullptr, nullptr, nullptr, default_scale_inv_shape); + te_page_table_v = makeTransformerEngineTensor( + page_table_v.value().data_ptr(), static_cast(page_table_v_shape), + DType::kInt32, nullptr, nullptr, nullptr, default_scale_inv_shape); } // softmax offset @@ -206,9 +209,9 @@ std::vector fused_attn_fwd( if ((softmax_type != NVTE_VANILLA_SOFTMAX) && (SoftmaxOffset.has_value())) { auto SoftmaxOffset_sizes = SoftmaxOffset.value().sizes().vec(); NVTEShapeWrapper SoftmaxOffset_shape{SoftmaxOffset_sizes}; - te_SoftmaxOffset = - makeTransformerEngineTensor(SoftmaxOffset.value().data_ptr(), static_cast(SoftmaxOffset_shape), - DType::kFloat32, nullptr, nullptr, nullptr, default_scale_inv_shape); + te_SoftmaxOffset = makeTransformerEngineTensor( + SoftmaxOffset.value().data_ptr(), static_cast(SoftmaxOffset_shape), + DType::kFloat32, nullptr, nullptr, nullptr, default_scale_inv_shape); } // extract rng seed and offset @@ -468,10 +471,12 @@ std::vector fused_attn_bwd( zero_scale_inv_shape.ndim = 1; zero_scale_inv_shape.data[0] = 0; TensorWrapper te_cu_seqlens_q, te_cu_seqlens_kv; - te_cu_seqlens_q = makeTransformerEngineTensor(cu_seqlens_q.data_ptr(), static_cast(cu_seqlens_q_shape), - DType::kInt32, nullptr, nullptr, nullptr, zero_scale_inv_shape); - te_cu_seqlens_kv = makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), static_cast(cu_seqlens_kv_shape), - DType::kInt32, nullptr, nullptr, nullptr, zero_scale_inv_shape); + te_cu_seqlens_q = makeTransformerEngineTensor( + cu_seqlens_q.data_ptr(), static_cast(cu_seqlens_q_shape), DType::kInt32, nullptr, + nullptr, nullptr, zero_scale_inv_shape); + te_cu_seqlens_kv = makeTransformerEngineTensor( + cu_seqlens_kv.data_ptr(), static_cast(cu_seqlens_kv_shape), DType::kInt32, + nullptr, nullptr, nullptr, zero_scale_inv_shape); TensorWrapper te_cu_seqlens_q_padded, te_cu_seqlens_kv_padded; if ((cu_seqlens_q_padded.has_value()) && (cu_seqlens_kv_padded.has_value())) { @@ -480,13 +485,11 @@ std::vector fused_attn_bwd( auto cu_seqlens_kv_padded_sizes = cu_seqlens_kv_padded.value().sizes().vec(); NVTEShapeWrapper cu_seqlens_kv_padded_shape{cu_seqlens_kv_padded_sizes}; te_cu_seqlens_q_padded = makeTransformerEngineTensor( - cu_seqlens_q_padded.value().data_ptr(), - static_cast(cu_seqlens_q_padded_shape), + cu_seqlens_q_padded.value().data_ptr(), static_cast(cu_seqlens_q_padded_shape), DType::kInt32); te_cu_seqlens_kv_padded = makeTransformerEngineTensor( cu_seqlens_kv_padded.value().data_ptr(), - static_cast(cu_seqlens_kv_padded_shape), - DType::kInt32); + static_cast(cu_seqlens_kv_padded_shape), DType::kInt32); } // convert auxiliary tensors from forward to NVTETensors @@ -500,7 +503,7 @@ std::vector fused_attn_bwd( NVTEBasicTensor temp_data = { Aux_CTX_Tensors[i].data_ptr(), 
static_cast(GetTransformerEngineDType(Aux_CTX_Tensors[i].scalar_type())), - static_cast(tmp)}; + static_cast(tmp)}; nvte_set_tensor_param(&nvte_aux_tensor_pack.tensors[i], kNVTERowwiseData, &temp_data); } diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index af04328948b..95e06278dff 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -336,12 +336,13 @@ std::tuple, std::vector> bulk_allocate_fp tensor_cpp_list.emplace_back(makeTransformerEngineTensor( rowwise_usage ? rowwise_data_list[i].data_ptr() : nullptr, columnwise_usage ? columnwise_data_list[i].data_ptr() : nullptr, - rowwise_usage ? static_cast(rowwise_data_shapes[i]) : zero_shape, - columnwise_usage ? static_cast(columnwise_data_shapes[i]) : zero_shape, fp8_dtype, nullptr, - nullptr, rowwise_usage ? rowwise_scale_list[i].data_ptr() : nullptr, + rowwise_usage ? static_cast(rowwise_data_shapes[i]) : zero_shape, + columnwise_usage ? static_cast(columnwise_data_shapes[i]) : zero_shape, + fp8_dtype, nullptr, nullptr, rowwise_usage ? rowwise_scale_list[i].data_ptr() : nullptr, columnwise_usage ? columnwise_scale_list[i].data_ptr() : nullptr, - rowwise_usage ? static_cast(rowwise_scale_shapes[i]) : zero_shape, - columnwise_usage ? static_cast(columnwise_scale_shapes[i]) : zero_shape, scaling_mode)); + rowwise_usage ? static_cast(rowwise_scale_shapes[i]) : zero_shape, + columnwise_usage ? static_cast(columnwise_scale_shapes[i]) : zero_shape, + scaling_mode)); } return retval; @@ -485,12 +486,13 @@ std::tuple, std::vector> bulk_allocate_mx tensor_cpp_list.emplace_back(makeTransformerEngineTensor( rowwise_usage ? rowwise_data_list[i].data_ptr() : nullptr, columnwise_usage ? columnwise_data_list[i].data_ptr() : nullptr, - rowwise_usage ? static_cast(rowwise_data_shapes[i]) : zero_shape, - columnwise_usage ? static_cast(columnwise_data_shapes[i]) : zero_shape, fp8_dtype, nullptr, - nullptr, rowwise_usage ? rowwise_scale_list[i].data_ptr() : nullptr, + rowwise_usage ? static_cast(rowwise_data_shapes[i]) : zero_shape, + columnwise_usage ? static_cast(columnwise_data_shapes[i]) : zero_shape, + fp8_dtype, nullptr, nullptr, rowwise_usage ? rowwise_scale_list[i].data_ptr() : nullptr, columnwise_usage ? columnwise_scale_list[i].data_ptr() : nullptr, - rowwise_usage ? static_cast(rowwise_scale_shapes[i]) : zero_shape, - columnwise_usage ? static_cast(columnwise_scale_shapes[i]) : zero_shape, scaling_mode)); + rowwise_usage ? static_cast(rowwise_scale_shapes[i]) : zero_shape, + columnwise_usage ? static_cast(columnwise_scale_shapes[i]) : zero_shape, + scaling_mode)); } return retval; @@ -698,18 +700,19 @@ std::tuple, std::vector, bool> bulk_alloc auto tensor_wrapper = makeTransformerEngineTensor( rowwise_usage ? rowwise_data_list[i].data_ptr() : nullptr, columnwise_usage ? columnwise_data_list[i].data_ptr() : nullptr, - rowwise_usage ? static_cast(rowwise_data_shapes[i]) : zero_shape, - columnwise_usage ? static_cast(columnwise_data_shapes[i]) : zero_shape, fp4_dtype, + rowwise_usage ? static_cast(rowwise_data_shapes[i]) : zero_shape, + columnwise_usage ? static_cast(columnwise_data_shapes[i]) : zero_shape, + fp4_dtype, /*amax_ptr=*/nullptr, /*scale_ptr=*/nullptr, rowwise_usage ? rowwise_scale_list[i].data_ptr() : nullptr, columnwise_usage ? columnwise_scale_list[i].data_ptr() : nullptr, - rowwise_usage ? static_cast(rowwise_scale_shapes[i]) : zero_shape, - columnwise_usage ? 
static_cast(columnwise_scale_shapes[i]) : zero_shape, scaling_mode); - + rowwise_usage ? static_cast(rowwise_scale_shapes[i]) : zero_shape, + columnwise_usage ? static_cast(columnwise_scale_shapes[i]) : zero_shape, + scaling_mode); + // Set the amax rowwise and amax columnwise if available if (rowwise_usage) { - tensor_wrapper.set_amax(amax_rowwise_list[i].data_ptr(), DType::kFloat32, - amax_shape); + tensor_wrapper.set_amax(amax_rowwise_list[i].data_ptr(), DType::kFloat32, amax_shape); } if (columnwise_usage) { tensor_wrapper.set_columnwise_amax(amax_columnwise_list[i].data_ptr(), DType::kFloat32, @@ -780,8 +783,8 @@ static StochasticRngStateResources setup_stochastic_rounding_rng_states_helper( NVTEShape rng_state_shape; rng_state_shape.ndim = 1; rng_state_shape.data[0] = 2; - res.te_rng_state_list.push_back(makeTransformerEngineTensor( - static_cast(rng_state_ptr), rng_state_shape, DType::kInt64)); + res.te_rng_state_list.push_back(makeTransformerEngineTensor(static_cast(rng_state_ptr), + rng_state_shape, DType::kInt64)); quant_config_list_rowwise[i].set_rng_state(res.te_rng_state_list[i].data()); quant_config_list_rowwise[i].set_stochastic_rounding(true); @@ -1146,7 +1149,8 @@ std::vector split_quantize(const at::Tensor &tensor, split_shape.push_back(input_shape[j]); } void *split_dptr = static_cast(input_dptr + dim0_offset * dim0_stride); - input_list.emplace_back(makeTransformerEngineTensor(split_dptr, static_cast(split_shape), input_dtype)); + input_list.emplace_back(makeTransformerEngineTensor( + split_dptr, static_cast(split_shape), input_dtype)); dim0_offset += split_sections[i]; } diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp index 07acd44170a..82c1ce1e7b6 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp @@ -82,8 +82,8 @@ bool checkGemmShape(const NVTEShape& expected, const NVTEShape& actual) { } // namespace detail -std::pair createOutputTensor(const NVTEShape& shape, - DType dtype, py::handle quantizer) { +std::pair createOutputTensor(const NVTEShape& shape, DType dtype, + py::handle quantizer) { std::unique_ptr my_quantizer = convert_quantizer(quantizer); return my_quantizer->create_tensor(shape, dtype); } @@ -170,8 +170,7 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans if (unfused_quantization_needed) { NoneQuantizer q{none}; - std::tie(unquantized_D_tensor, unquantized_out) = - q.create_tensor(D_shape, output_dtype); + std::tie(unquantized_D_tensor, unquantized_out) = q.create_tensor(D_shape, output_dtype); } TensorWrapper& out_tensor = unfused_quantization_needed ? 
unquantized_D_tensor : D_tensor; @@ -419,8 +418,8 @@ void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, DType A_type, workspace_shape.data[0] = workspaceSize; auto te_pre_gelu_out = makeTransformerEngineTensor( pre_gelu_out.data_ptr(), gelu_shape, GetTransformerEngineDType(pre_gelu_out.scalar_type())); - auto te_workspace = makeTransformerEngineTensor(workspace.data_ptr(), - workspace_shape, DType::kByte); + auto te_workspace = + makeTransformerEngineTensor(workspace.data_ptr(), workspace_shape, DType::kByte); NVTE_SCOPED_GIL_RELEASE({ nvte_cublas_atomic_gemm(te_A.data(), te_B.data(), te_D.data(), te_bias.data(), @@ -589,8 +588,7 @@ std::optional> te_general_grouped_gemm( workspace_shape.ndim = 1; workspace_shape.data[0] = workspaceSize; for (size_t i = 0; i < workspace.size(); i++) { - auto wsp = makeTransformerEngineTensor(workspace[i].data_ptr(), - workspace_shape, DType::kByte); + auto wsp = makeTransformerEngineTensor(workspace[i].data_ptr(), workspace_shape, DType::kByte); te_workspace_vector.emplace_back(wsp.data()); te_workspace_wrappers.emplace_back(std::move(wsp)); } @@ -605,4 +603,4 @@ std::optional> te_general_grouped_gemm( return bias; } -} // namespace transformer_engine::pytorch \ No newline at end of file +} // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/extensions/permutation.cpp b/transformer_engine/pytorch/csrc/extensions/permutation.cpp index 189c7127731..b0654c326e7 100644 --- a/transformer_engine/pytorch/csrc/extensions/permutation.cpp +++ b/transformer_engine/pytorch/csrc/extensions/permutation.cpp @@ -69,17 +69,11 @@ std::tuple> moe_permute_fwd( permuted_output_shape.data[0] = static_cast(permuted_output.size(0)); permuted_output_shape.data[1] = static_cast(permuted_output.size(1)); sorted_row_id_cu_shape.data[0] = static_cast(num_tokens * topK); - auto input_cu = makeTransformerEngineTensor( - input.data_ptr(), - input_shape, - dtype); + auto input_cu = makeTransformerEngineTensor(input.data_ptr(), input_shape, dtype); auto permuted_output_cu = - makeTransformerEngineTensor(permuted_output.data_ptr(), - permuted_output_shape, - dtype); - auto sorted_row_id_cu = makeTransformerEngineTensor( - sorted_row_id_ptr, sorted_row_id_cu_shape, - DType::kInt32); + makeTransformerEngineTensor(permuted_output.data_ptr(), permuted_output_shape, dtype); + auto sorted_row_id_cu = + makeTransformerEngineTensor(sorted_row_id_ptr, sorted_row_id_cu_shape, DType::kInt32); auto row_id_map_cu = makeTransformerEngineTensor(row_id_map); nvte_permute(input_cu.data(), permuted_output_cu.data(), sorted_row_id_cu.data(), @@ -111,14 +105,9 @@ at::Tensor moe_unpermute_fwd(at::Tensor input, const DType dtype, at::Tensor row input_shape.data[1] = static_cast(input.size(1)); unpermuted_output_shape.data[0] = static_cast(unpermuted_output.size(0)); unpermuted_output_shape.data[1] = static_cast(num_cols); - auto input_cu = makeTransformerEngineTensor( - input.data_ptr(), - input_shape, - dtype); - auto unpermuted_output_cu = makeTransformerEngineTensor( - unpermuted_output.data_ptr(), - unpermuted_output_shape, - dtype); + auto input_cu = makeTransformerEngineTensor(input.data_ptr(), input_shape, dtype); + auto unpermuted_output_cu = + makeTransformerEngineTensor(unpermuted_output.data_ptr(), unpermuted_output_shape, dtype); auto row_id_map_cu = makeTransformerEngineTensor(row_id_map); auto prob_cu = makeTransformerEngineTensor(prob); @@ -153,18 +142,9 @@ std::tuple moe_unpermute_bwd(at::Tensor input_bwd, at::T act_grad_shape.data[1] = 
static_cast(num_cols); input_fwd_shape.data[0] = static_cast(input_fwd.size(0)); input_fwd_shape.data[1] = static_cast(num_cols); - auto input_bwd_cu = makeTransformerEngineTensor( - input_bwd.data_ptr(), - input_bwd_shape, - dtype); - auto act_grad_cu = makeTransformerEngineTensor( - act_grad.data_ptr(), - act_grad_shape, - dtype); - auto input_fwd_cu = makeTransformerEngineTensor( - input_fwd.data_ptr(), - input_fwd_shape, - dtype); + auto input_bwd_cu = makeTransformerEngineTensor(input_bwd.data_ptr(), input_bwd_shape, dtype); + auto act_grad_cu = makeTransformerEngineTensor(act_grad.data_ptr(), act_grad_shape, dtype); + auto input_fwd_cu = makeTransformerEngineTensor(input_fwd.data_ptr(), input_fwd_shape, dtype); auto row_id_map_cu = makeTransformerEngineTensor(row_id_map); auto prob_cu = makeTransformerEngineTensor(prob); auto prob_grad_cu = makeTransformerEngineTensor(prob_grad); diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 650847dd86d..1520541ade7 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -135,13 +135,13 @@ std::pair Float8Quantizer::create_tensor( at::Tensor scale_inv = at::empty(std::vector{1}, opts); return create_tensor(shape, dtype, std::nullopt, std::nullopt, std::move(scale_inv)); } -std::pair Float8Quantizer::create_tensor( - const NVTEShape& shape, DType dtype) const { - std::vector shape_vec; - for (size_t i = 0; i < shape.ndim; ++i) { - shape_vec.push_back(shape.data[i]); - } - return create_tensor(shape_vec, dtype); +std::pair Float8Quantizer::create_tensor(const NVTEShape& shape, + DType dtype) const { + std::vector shape_vec; + for (size_t i = 0; i < shape.ndim; ++i) { + shape_vec.push_back(shape.data[i]); + } + return create_tensor(shape_vec, dtype); } std::pair Float8Quantizer::create_tensor( @@ -351,11 +351,11 @@ void Float8CurrentScalingQuantizer::set_quantization_params(TensorWrapper* tenso } std::pair Float8CurrentScalingQuantizer::create_tensor( const NVTEShape& shape, DType dtype) const { - std::vector shape_vec; - for (size_t i = 0; i < shape.ndim; ++i) { - shape_vec.push_back(shape.data[i]); - } - return create_tensor(shape_vec, dtype); + std::vector shape_vec; + for (size_t i = 0; i < shape.ndim; ++i) { + shape_vec.push_back(shape.data[i]); + } + return create_tensor(shape_vec, dtype); } std::pair Float8CurrentScalingQuantizer::create_tensor( const std::vector& shape, DType dtype) const { @@ -597,13 +597,13 @@ Float8BlockQuantizer::Float8BlockQuantizer(const py::handle& quantizer) : Quanti void Float8BlockQuantizer::set_quantization_params(TensorWrapper* tensor) const {} -std::pair Float8BlockQuantizer::create_tensor( - const NVTEShape& shape, DType dtype) const { - std::vector shape_vec; - for (size_t i = 0; i < shape.ndim; ++i) { - shape_vec.push_back(shape.data[i]); - } - return create_tensor(shape_vec, dtype); +std::pair Float8BlockQuantizer::create_tensor(const NVTEShape& shape, + DType dtype) const { + std::vector shape_vec; + for (size_t i = 0; i < shape.ndim; ++i) { + shape_vec.push_back(shape.data[i]); + } + return create_tensor(shape_vec, dtype); } std::pair Float8BlockQuantizer::create_tensor( @@ -873,7 +873,7 @@ std::vector Float8BlockQuantizer::get_scale_shape(const std::vector MXFP8Quantizer::create_tensor( - const NVTEShape& shape, DType dtype) const { - std::vector shape_vec; - for (size_t i = 0; i < shape.ndim; ++i) { - shape_vec.push_back(shape.data[i]); - } - return create_tensor(shape_vec, dtype); 
+std::pair MXFP8Quantizer::create_tensor(const NVTEShape& shape, + DType dtype) const { + std::vector shape_vec; + for (size_t i = 0; i < shape.ndim; ++i) { + shape_vec.push_back(shape.data[i]); + } + return create_tensor(shape_vec, dtype); } std::pair MXFP8Quantizer::create_tensor(const std::vector& shape, @@ -1166,7 +1166,7 @@ std::vector MXFP8Quantizer::get_scale_shape(const std::vector& s } NVTEShapeWrapper MXFP8Quantizer::get_scale_shape(const NVTEShapeWrapper& shape, - bool columnwise) const { + bool columnwise) const { return get_scale_shape_impl(shape, columnwise); } @@ -1238,13 +1238,13 @@ void NVFP4Quantizer::set_quantization_params(TensorWrapper* tensor) const { tensor->set_columnwise_data(columnwise_data.data_ptr, static_cast(columnwise_data.dtype), columnwise_data.shape); } -std::pair NVFP4Quantizer::create_tensor( - const NVTEShape& shape, DType dtype) const { - std::vector shape_vec; - for (size_t i = 0; i < shape.ndim; ++i) { - shape_vec.push_back(shape.data[i]); - } - return create_tensor(shape_vec, dtype); +std::pair NVFP4Quantizer::create_tensor(const NVTEShape& shape, + DType dtype) const { + std::vector shape_vec; + for (size_t i = 0; i < shape.ndim; ++i) { + shape_vec.push_back(shape.data[i]); + } + return create_tensor(shape_vec, dtype); } std::pair NVFP4Quantizer::create_tensor(const std::vector& shape, DType dtype) const { @@ -1785,7 +1785,7 @@ std::vector NVFP4Quantizer::get_scale_shape(const std::vector& s } NVTEShapeWrapper NVFP4Quantizer::get_scale_shape(const NVTEShapeWrapper& shape, - bool columnwise) const { + bool columnwise) const { return get_scale_shape_impl(shape, columnwise); }
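// --- Illustrative sketch (not part of the patch) --------------------------------------
// A minimal, self-contained example of the shape-handling change applied throughout the
// diffs above: write sizes into a fixed-capacity NVTEShape-style struct instead of
// copying them into a heap-allocated std::vector on every call. "ShapeStandIn" and its
// capacity of 8 are assumptions standing in for the real NVTEShape; the real struct and
// the NVTEShapeWrapper convenience class are the ones declared in transformer_engine.h.
#include <cstddef>
#include <cstdint>
#include <vector>

struct ShapeStandIn {
  size_t data[8];   // fixed maximum rank; 8 is an assumed capacity for this sketch only
  size_t ndim = 0;
};

// Before the patch: each call site materialized a std::vector<size_t> copy of the torch
// sizes, paying a heap allocation per tensor on the hot path.
std::vector<size_t> to_vector_shape(const std::vector<int64_t>& torch_sizes) {
  return std::vector<size_t>(torch_sizes.begin(), torch_sizes.end());
}

// After the patch: sizes are written straight into the fixed-size struct (caller must
// ensure the rank fits the capacity, as the real NVTEShape enforces a maximum rank),
// so no allocation is needed and the struct can be handed to the TE C API directly.
ShapeStandIn to_nvte_shape(const std::vector<int64_t>& torch_sizes) {
  ShapeStandIn s;
  s.ndim = torch_sizes.size();
  for (size_t i = 0; i < s.ndim; ++i) {
    s.data[i] = static_cast<size_t>(torch_sizes[i]);
  }
  return s;
}

int main() {
  const std::vector<int64_t> sizes{2, 3, 4};
  const auto vec_shape = to_vector_shape(sizes);
  const auto te_shape = to_nvte_shape(sizes);
  return (vec_shape.size() == te_shape.ndim && te_shape.data[2] == 4) ? 0 : 1;
}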
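// --- Illustrative sketch (not part of the patch) --------------------------------------
// The templated get_scale_shape_impl helpers above derive a 2-D scaling-factor shape
// from the data shape. This sketch reproduces the MXFP8 row-wise/column-wise formulas
// visible in the diff; the block size of 32 is an assumption standing in for
// MXFP8_BLOCK_SIZE, and roundup() mirrors the helper used in quantizer.cpp.
#include <cstddef>
#include <cstdio>
#include <utility>

constexpr size_t kBlockSize = 32;  // assumed value of MXFP8_BLOCK_SIZE

constexpr size_t roundup(size_t x, size_t multiple) {
  return ((x + multiple - 1) / multiple) * multiple;
}

// The data shape is flattened to (numel / last_dim, last_dim) before the scale shape
// is computed, matching the numel / last_dim bookkeeping in the quantizer code above.
std::pair<size_t, size_t> mxfp8_scale_shape(size_t numel, size_t last_dim, bool columnwise) {
  if (!columnwise) {
    return {roundup(numel / last_dim, 128), roundup(last_dim / kBlockSize, 4)};
  }
  return {roundup(numel / (last_dim * kBlockSize), 4), roundup(last_dim, 128)};
}

int main() {
  // A (256, 1024) tensor: row-wise scales -> (256, 32), column-wise scales -> (8, 1024).
  const auto rowwise = mxfp8_scale_shape(256 * 1024, 1024, /*columnwise=*/false);
  const auto colwise = mxfp8_scale_shape(256 * 1024, 1024, /*columnwise=*/true);
  std::printf("rowwise=(%zu,%zu) colwise=(%zu,%zu)\n",
              rowwise.first, rowwise.second, colwise.first, colwise.second);
  return 0;
}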