diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 118bf19335..899f5fe5e6 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -129,7 +129,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla ret.Atype = A.data.dtype; ret.A_scale_inv = A.scale_inv.dptr; ret.lda = is_A_transposed ? k : m; - if (!nvte_is_non_tn_fp8_gemm_supported() && !is_A_transposed) { + int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported(); + if (!is_nvte_non_tn_fp8_gemm_supported && !is_A_transposed) { // Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data. if (A.has_columnwise_data() && is_fp8_dtype(A.columnwise_data.dtype)) { ret.A = A.columnwise_data.dptr; @@ -140,7 +141,7 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla } else { NVTE_CHECK(!is_fp8_dtype(ret.Atype), "Input A is missing column-wise usage"); } - } else if (nvte_is_non_tn_fp8_gemm_supported() && !A.has_data()) { + } else if (is_nvte_non_tn_fp8_gemm_supported && !A.has_data()) { // Blackwell supports any GEMM layout for FP8, so we can use column-wise/transposed // data with the mirrored transpose-flag if we don't have row-wise data. NVTE_CHECK(A.has_columnwise_data() && is_fp8_dtype(A.columnwise_data.dtype), @@ -220,7 +221,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla ret.Btype = B.data.dtype; ret.B_scale_inv = B.scale_inv.dptr; ret.ldb = is_B_transposed ? n : k; - if (!nvte_is_non_tn_fp8_gemm_supported() && is_B_transposed) { + int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported(); + if (!is_nvte_non_tn_fp8_gemm_supported && is_B_transposed) { // Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data. if (B.has_columnwise_data() && is_fp8_dtype(B.columnwise_data.dtype)) { ret.B = B.columnwise_data.dptr; @@ -231,7 +233,7 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla } else { NVTE_CHECK(!is_fp8_dtype(ret.Btype), "Input B is missing column-wise usage"); } - } else if (nvte_is_non_tn_fp8_gemm_supported() && !B.has_data()) { + } else if (is_nvte_non_tn_fp8_gemm_supported && !B.has_data()) { // Blackwell supports any GEMM layout for FP8, so we can use column-wise/transposed // data with the mirrored transpose-flag if we don't have row-wise data. NVTE_CHECK(B.has_columnwise_data() && is_fp8_dtype(B.columnwise_data.dtype), diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index 19cb646be2..26a07c707a 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -518,6 +518,66 @@ NVTEShape nvte_get_grouped_tensor_logical_shape(const NVTEGroupedTensor tensor); */ namespace transformer_engine { +/*! \class NVTEShapeWrapper + * \brief C++ wrapper for NVTEShape with container-like interface. 
+ */ +class NVTEShapeWrapper { + private: + NVTEShape data; + + public: + // Default constructor + NVTEShapeWrapper() { data.ndim = 0; } + + // Constructor from NVTEShape (direct assignment by reference) + NVTEShapeWrapper(const NVTEShape &shape) { data = shape; } + + // Constructor from vector (creates a copy) + template + NVTEShapeWrapper(const std::vector &shape_vec) { + data.ndim = shape_vec.size(); + for (size_t i = 0; i < data.ndim; ++i) { + data.data[i] = static_cast(shape_vec[i]); + } + } + + operator NVTEShape &() { return data; } + operator const NVTEShape &() const { return data; } + + // Iterator support + size_t *begin() { return data.data; } + const size_t *begin() const { return data.data; } + size_t *end() { return data.data + data.ndim; } + const size_t *end() const { return data.data + data.ndim; } + + // Index access + size_t &operator[](size_t idx) { return data.data[idx]; } + const size_t &operator[](size_t idx) const { return data.data[idx]; } + + // Back access + size_t &back() { return data.data[data.ndim - 1]; } + const size_t &back() const { return data.data[data.ndim - 1]; } + + // Size access + size_t size() const { return data.ndim; } + bool empty() const { return data.ndim == 0; } + + // Container operations + void push_back(size_t value) { + if (data.ndim < 15) { + data.data[data.ndim++] = value; + } + } + + void clear() { data.ndim = 0; } + + void resize(size_t new_size) { + if (new_size <= 15) { + data.ndim = new_size; + } + } +}; + /*! \enum DType * \brief TE datatype. */ diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index 4a140b4376..0c4c9456c6 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -961,7 +961,7 @@ void nvte_destroy_quantization_config(NVTEQuantizationConfig config) { } int nvte_is_non_tn_fp8_gemm_supported() { - int num_devices = transformer_engine::cuda::num_devices(); + static int num_devices = transformer_engine::cuda::num_devices(); static std::vector cache(num_devices, -1); static std::vector flags(num_devices); int device_id = transformer_engine::cuda::current_device(); diff --git a/transformer_engine/pytorch/csrc/common.cpp b/transformer_engine/pytorch/csrc/common.cpp index e054424dd4..3306073277 100644 --- a/transformer_engine/pytorch/csrc/common.cpp +++ b/transformer_engine/pytorch/csrc/common.cpp @@ -26,7 +26,9 @@ std::vector convert_shape_back_from_fp4(const std::vector& shape return ret; } -std::vector getTensorShape(const at::Tensor& t) { +NVTEShape getTensorShape(const at::Tensor& t) { return convertTorchShape(t.sizes()); } + +std::vector getTensorShapeVector(const at::Tensor& t) { std::vector shape; for (auto s : t.sizes()) { shape.push_back(s); @@ -112,17 +114,9 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor( return transformer_engine::TensorWrapper(data_ptr, shape, type); } -transformer_engine::TensorWrapper makeTransformerEngineTensor( - void* data_ptr, const std::vector& shape, const transformer_engine::DType type) { - return transformer_engine::TensorWrapper(data_ptr, shape, type); -} - transformer_engine::TensorWrapper makeTransformerEngineTensor(at::Tensor tensor) { transformer_engine::DType dtype = GetTransformerEngineDType(tensor.scalar_type()); - std::vector shape; - for (auto s : tensor.sizes()) { - shape.push_back(s); - } + NVTEShape shape = getTensorShape(tensor); return makeTransformerEngineTensor(tensor.data_ptr(), shape, dtype); } @@ -164,12 +158,15 @@ 
makeTransformerEngineTensorList(std::vector> at_tensor_l } transformer_engine::TensorWrapper makeTransformerEngineTensor( - void* data_ptr, const std::vector& shape, const transformer_engine::DType type, - void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, std::vector scale_inv_shape, + void* data_ptr, const NVTEShape& shape, const transformer_engine::DType type, void* amax_ptr, + void* scale_ptr, void* scale_inv_ptr, const NVTEShape& scale_inv_shape, NVTEScalingMode scaling_mode) { TensorWrapper ret(scaling_mode); ret.set_rowwise_data(data_ptr, type, shape); - const std::vector meta_shape{1}; + const size_t meta_shape_data[1] = {1}; + NVTEShape meta_shape; + meta_shape.ndim = 1; + meta_shape.data[0] = 1; ret.set_amax(amax_ptr, DType::kFloat32, meta_shape); ret.set_scale(scale_ptr, DType::kFloat32, meta_shape); auto scale_inv_dtype = @@ -179,15 +176,17 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor( } transformer_engine::TensorWrapper makeTransformerEngineTensor( - void* data_ptr, void* columnwise_data_ptr, const std::vector& shape, - const std::vector& columnwise_shape, const transformer_engine::DType type, - void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, void* columnwise_scale_inv_ptr, - const std::vector& scale_inv_shape, - const std::vector& columnwise_scale_inv_shape, NVTEScalingMode scaling_mode) { + void* data_ptr, void* columnwise_data_ptr, const NVTEShape& shape, + const NVTEShape& columnwise_shape, const transformer_engine::DType type, void* amax_ptr, + void* scale_ptr, void* scale_inv_ptr, void* columnwise_scale_inv_ptr, + const NVTEShape& scale_inv_shape, const NVTEShape& columnwise_scale_inv_shape, + NVTEScalingMode scaling_mode) { TensorWrapper ret(scaling_mode); ret.set_rowwise_data(data_ptr, type, shape); ret.set_columnwise_data(columnwise_data_ptr, type, columnwise_shape); - const std::vector meta_shape{1}; + NVTEShape meta_shape; + meta_shape.ndim = 1; + meta_shape.data[0] = 1; ret.set_amax(amax_ptr, DType::kFloat32, meta_shape); ret.set_scale(scale_ptr, DType::kFloat32, meta_shape); auto scale_inv_dtype = (scaling_mode == NVTE_MXFP8_1D_SCALING) ? DType::kFloat8E8M0 @@ -230,6 +229,9 @@ template size_t product(const std::vector& shape); template int64_t product(const std::vector& shape); size_t product(const NVTEShape& shape, size_t begin, size_t end) { + if (end == -1) { + end = shape.ndim; + } NVTE_CHECK(begin <= end && end <= shape.ndim, "Attempted to access entries ", begin, " to ", end, " in a shape with ", shape.ndim, " entries"); size_t ret = 1; diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 978bee52dc..39b73c9644 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -102,7 +102,8 @@ class Quantizer { /*! @brief Construct a tensor with uninitialized data */ virtual std::pair create_tensor(const std::vector& shape, DType dtype) const = 0; - + virtual std::pair create_tensor(const NVTEShape& shape, + DType dtype) const = 0; /*! @brief Convert a PyTorch tensor into a Transformer Engine C++ tensor * * The PyTorch tensor's attributes are modified to match the @@ -141,6 +142,13 @@ class NoneQuantizer : public Quantizer { std::pair create_tensor(const std::vector& shape, DType dtype, at::Tensor data) const; + std::pair create_tensor(const NVTEShape& shape, + DType dtype) const override; + + /*! 
@brief Construct a tensor with pre-initialized data */ + std::pair create_tensor(const NVTEShape& shape, DType dtype, + at::Tensor data) const; + std::pair convert_and_update_tensor(py::object tensor) const override; void quantize(const TensorWrapper& input, TensorWrapper& out, @@ -162,7 +170,8 @@ class Float8Quantizer : public Quantizer { std::pair create_tensor(const std::vector& shape, DType dtype) const override; - + std::pair create_tensor(const NVTEShape& shape, + DType dtype) const override; /*! @brief Construct a tensor with pre-initialized data */ std::pair create_tensor(const std::vector& shape, DType dtype, std::optional data, @@ -194,7 +203,8 @@ class Float8CurrentScalingQuantizer : public Quantizer { std::pair create_tensor(const std::vector& shape, DType dtype) const override; - + std::pair create_tensor(const NVTEShape& shape, + DType dtype) const override; /*! @brief Construct an unquantized tensor that shares the quantizer's amax pointer. * * The amax is zeroed out. Most TE kernels that output amax expect @@ -253,6 +263,8 @@ class Float8BlockQuantizer : public Quantizer { // and optionally columnwise usage. std::pair create_tensor(const std::vector& shape, DType dtype) const override; + std::pair create_tensor(const NVTEShape& shape, + DType dtype) const override; std::pair convert_and_update_tensor(py::object shape) const override; @@ -260,6 +272,11 @@ class Float8BlockQuantizer : public Quantizer { const std::optional& noop_flag = std::nullopt) override; std::vector get_scale_shape(const std::vector& shape, bool columnwise) const; + NVTEShapeWrapper get_scale_shape(const NVTEShapeWrapper& shape, bool columnwise) const; + + private: + template + ShapeT get_scale_shape_impl(const ShapeT& shape, bool columnwise) const; }; class MXFP8Quantizer : public Quantizer { @@ -274,6 +291,8 @@ class MXFP8Quantizer : public Quantizer { std::pair create_tensor(const std::vector& shape, DType dtype) const override; + std::pair create_tensor(const NVTEShape& shape, + DType dtype) const override; std::pair convert_and_update_tensor(py::object shape) const override; @@ -281,6 +300,11 @@ class MXFP8Quantizer : public Quantizer { const std::optional& noop_flag = std::nullopt) override; std::vector get_scale_shape(const std::vector& shape, bool columnwise) const; + NVTEShapeWrapper get_scale_shape(const NVTEShapeWrapper& shape, bool columnwise) const; + + private: + template + ShapeT get_scale_shape_impl(const ShapeT& shape, bool columnwise) const; }; class NVFP4Quantizer : public Quantizer { @@ -308,7 +332,8 @@ class NVFP4Quantizer : public Quantizer { std::pair create_tensor(const std::vector& shape, DType dtype) const override; - + std::pair create_tensor(const NVTEShape& shape, + DType dtype) const override; /*! @brief Construct an unquantized tensor that shares NVFP4 tensor's amax pointer * * The amax is zeroed out. 
Most TE kernels that output amax expect @@ -331,15 +356,20 @@ class NVFP4Quantizer : public Quantizer { void quantize_with_amax(TensorWrapper& input, TensorWrapper& out); std::vector get_scale_shape(const std::vector& shape, bool columnwise) const; + NVTEShapeWrapper get_scale_shape(const NVTEShapeWrapper& shape, bool columnwise) const; private: + template + ShapeT get_scale_shape_impl(const ShapeT& shape, bool columnwise) const; void quantize_impl(const TensorWrapper& input, TensorWrapper& out, const std::optional& noop_flag, bool compute_amax); }; std::unique_ptr convert_quantizer(py::handle quantizer); -std::vector getTensorShape(const at::Tensor& t); +NVTEShape getTensorShape(const at::Tensor& t); + +std::vector getTensorShapeVector(const at::Tensor& t); transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid, const std::string& fp8_recipe); @@ -423,21 +453,16 @@ inline transformer_engine::DType GetTransformerEngineDType(int DType_value) { return static_cast(DType_value); } -transformer_engine::TensorWrapper makeTransformerEngineTensor(void* data_ptr, - const std::vector& shape, - const transformer_engine::DType type); - transformer_engine::TensorWrapper makeTransformerEngineTensor( - void* data_ptr, const std::vector& shape, const transformer_engine::DType type, - void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, std::vector scale_inv_shape = {1}, + void* data_ptr, const NVTEShape& shape, const transformer_engine::DType type, void* amax_ptr, + void* scale_ptr, void* scale_inv_ptr, const NVTEShape& scale_inv_shape, NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING); transformer_engine::TensorWrapper makeTransformerEngineTensor( - void* data_ptr, void* columnwise_data_ptr, const std::vector& shape, - const std::vector& columnwise_shape, const transformer_engine::DType type, - void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, void* columnwise_scale_inv_ptr, - const std::vector& scale_inv_shape = {1}, - const std::vector& columnwise_scale_inv_shape = {1}, + void* data_ptr, void* columnwise_data_ptr, const NVTEShape& shape, + const NVTEShape& columnwise_shape, const transformer_engine::DType type, void* amax_ptr, + void* scale_ptr, void* scale_inv_ptr, void* columnwise_scale_inv_ptr, + const NVTEShape& scale_inv_shape, const NVTEShape& columnwise_scale_inv_shape, NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING); transformer_engine::TensorWrapper makeTransformerEngineTensor(void* data_ptr, @@ -459,7 +484,7 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor( template T product(const std::vector& shape); -size_t product(const NVTEShape& shape, size_t begin, size_t end); +size_t product(const NVTEShape& shape, size_t begin = 0, size_t end = -1); std::vector nvte_shape_to_vector(const NVTEShape& nvte_shape); diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index 2480d9aba9..bfa989c26c 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -163,52 +163,55 @@ std::vector fused_attn_fwd( } if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI) && (Bias.has_value())) { auto bias_sizes = Bias.value().sizes().vec(); - std::vector bias_shape{bias_sizes.begin(), bias_sizes.end()}; - te_Bias = makeTransformerEngineTensor(Bias.value().data_ptr(), bias_shape, DType::kFloat32); + NVTEShapeWrapper bias_shape{bias_sizes}; + te_Bias = makeTransformerEngineTensor(Bias.value().data_ptr(), 
+ static_cast(bias_shape), DType::kFloat32); } auto cu_seqlens_q_sizes = cu_seqlens_q.sizes().vec(); - std::vector cu_seqlens_q_shape{cu_seqlens_q_sizes.begin(), cu_seqlens_q_sizes.end()}; + NVTEShapeWrapper cu_seqlens_q_shape{cu_seqlens_q_sizes}; auto cu_seqlens_kv_sizes = cu_seqlens_kv.sizes().vec(); - std::vector cu_seqlens_kv_shape{cu_seqlens_kv_sizes.begin(), cu_seqlens_kv_sizes.end()}; - te_cu_seqlens_q = - makeTransformerEngineTensor(cu_seqlens_q.data_ptr(), cu_seqlens_q_shape, DType::kInt32); - te_cu_seqlens_kv = - makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), cu_seqlens_kv_shape, DType::kInt32); + NVTEShapeWrapper cu_seqlens_kv_shape{cu_seqlens_kv_sizes}; + te_cu_seqlens_q = makeTransformerEngineTensor( + cu_seqlens_q.data_ptr(), static_cast(cu_seqlens_q_shape), DType::kInt32); + te_cu_seqlens_kv = makeTransformerEngineTensor( + cu_seqlens_kv.data_ptr(), static_cast(cu_seqlens_kv_shape), DType::kInt32); if ((cu_seqlens_q_padded.has_value()) && (cu_seqlens_kv_padded.has_value())) { auto cu_seqlens_q_padded_sizes = cu_seqlens_q_padded.value().sizes().vec(); - std::vector cu_seqlens_q_padded_shape{cu_seqlens_q_padded_sizes.begin(), - cu_seqlens_q_padded_sizes.end()}; + NVTEShapeWrapper cu_seqlens_q_padded_shape{cu_seqlens_q_padded_sizes}; auto cu_seqlens_kv_padded_sizes = cu_seqlens_kv_padded.value().sizes().vec(); - std::vector cu_seqlens_kv_padded_shape{cu_seqlens_kv_padded_sizes.begin(), - cu_seqlens_kv_padded_sizes.end()}; - te_cu_seqlens_q_padded = makeTransformerEngineTensor(cu_seqlens_q_padded.value().data_ptr(), - cu_seqlens_q_padded_shape, DType::kInt32); + NVTEShapeWrapper cu_seqlens_kv_padded_shape{cu_seqlens_kv_padded_sizes}; + te_cu_seqlens_q_padded = makeTransformerEngineTensor( + cu_seqlens_q_padded.value().data_ptr(), static_cast(cu_seqlens_q_padded_shape), + DType::kInt32); te_cu_seqlens_kv_padded = makeTransformerEngineTensor( - cu_seqlens_kv_padded.value().data_ptr(), cu_seqlens_kv_padded_shape, DType::kInt32); + cu_seqlens_kv_padded.value().data_ptr(), + static_cast(cu_seqlens_kv_padded_shape), DType::kInt32); } - + NVTEShape default_scale_inv_shape; + default_scale_inv_shape.ndim = 1; + default_scale_inv_shape.data[0] = 1; if ((page_table_k.has_value()) && (page_table_v.has_value())) { auto page_table_k_sizes = page_table_k.value().sizes().vec(); - std::vector page_table_k_shape{page_table_k_sizes.begin(), page_table_k_sizes.end()}; + NVTEShapeWrapper page_table_k_shape{page_table_k_sizes}; auto page_table_v_sizes = page_table_v.value().sizes().vec(); - std::vector page_table_v_shape{page_table_v_sizes.begin(), page_table_v_sizes.end()}; - te_page_table_k = - makeTransformerEngineTensor(page_table_k.value().data_ptr(), page_table_k_shape, - DType::kInt32, nullptr, nullptr, nullptr); - te_page_table_v = - makeTransformerEngineTensor(page_table_v.value().data_ptr(), page_table_v_shape, - DType::kInt32, nullptr, nullptr, nullptr); + NVTEShapeWrapper page_table_v_shape{page_table_v_sizes}; + te_page_table_k = makeTransformerEngineTensor( + page_table_k.value().data_ptr(), static_cast(page_table_k_shape), + DType::kInt32, nullptr, nullptr, nullptr, default_scale_inv_shape); + te_page_table_v = makeTransformerEngineTensor( + page_table_v.value().data_ptr(), static_cast(page_table_v_shape), + DType::kInt32, nullptr, nullptr, nullptr, default_scale_inv_shape); } // softmax offset TensorWrapper te_SoftmaxOffset; if ((softmax_type != NVTE_VANILLA_SOFTMAX) && (SoftmaxOffset.has_value())) { auto SoftmaxOffset_sizes = SoftmaxOffset.value().sizes().vec(); - std::vector 
SoftmaxOffset_shape{SoftmaxOffset_sizes.begin(), SoftmaxOffset_sizes.end()}; - te_SoftmaxOffset = - makeTransformerEngineTensor(SoftmaxOffset.value().data_ptr(), SoftmaxOffset_shape, - DType::kFloat32, nullptr, nullptr, nullptr); + NVTEShapeWrapper SoftmaxOffset_shape{SoftmaxOffset_sizes}; + te_SoftmaxOffset = makeTransformerEngineTensor( + SoftmaxOffset.value().data_ptr(), static_cast(SoftmaxOffset_shape), + DType::kFloat32, nullptr, nullptr, nullptr, default_scale_inv_shape); } // extract rng seed and offset @@ -461,27 +464,32 @@ std::vector fused_attn_bwd( // create cu_seqlens tensorwrappers auto cu_seqlens_q_sizes = cu_seqlens_q.sizes().vec(); - std::vector cu_seqlens_q_shape{cu_seqlens_q_sizes.begin(), cu_seqlens_q_sizes.end()}; + NVTEShapeWrapper cu_seqlens_q_shape{cu_seqlens_q_sizes}; auto cu_seqlens_kv_sizes = cu_seqlens_kv.sizes().vec(); - std::vector cu_seqlens_kv_shape{cu_seqlens_kv_sizes.begin(), cu_seqlens_kv_sizes.end()}; + NVTEShapeWrapper cu_seqlens_kv_shape{cu_seqlens_kv_sizes}; + NVTEShape zero_scale_inv_shape; + zero_scale_inv_shape.ndim = 1; + zero_scale_inv_shape.data[0] = 0; TensorWrapper te_cu_seqlens_q, te_cu_seqlens_kv; - te_cu_seqlens_q = makeTransformerEngineTensor(cu_seqlens_q.data_ptr(), cu_seqlens_q_shape, - DType::kInt32, nullptr, nullptr, nullptr); - te_cu_seqlens_kv = makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), cu_seqlens_kv_shape, - DType::kInt32, nullptr, nullptr, nullptr); + te_cu_seqlens_q = makeTransformerEngineTensor( + cu_seqlens_q.data_ptr(), static_cast(cu_seqlens_q_shape), DType::kInt32, nullptr, + nullptr, nullptr, zero_scale_inv_shape); + te_cu_seqlens_kv = makeTransformerEngineTensor( + cu_seqlens_kv.data_ptr(), static_cast(cu_seqlens_kv_shape), DType::kInt32, + nullptr, nullptr, nullptr, zero_scale_inv_shape); TensorWrapper te_cu_seqlens_q_padded, te_cu_seqlens_kv_padded; if ((cu_seqlens_q_padded.has_value()) && (cu_seqlens_kv_padded.has_value())) { auto cu_seqlens_q_padded_sizes = cu_seqlens_q_padded.value().sizes().vec(); - std::vector cu_seqlens_q_padded_shape{cu_seqlens_q_padded_sizes.begin(), - cu_seqlens_q_padded_sizes.end()}; + NVTEShapeWrapper cu_seqlens_q_padded_shape{cu_seqlens_q_padded_sizes}; auto cu_seqlens_kv_padded_sizes = cu_seqlens_kv_padded.value().sizes().vec(); - std::vector cu_seqlens_kv_padded_shape{cu_seqlens_kv_padded_sizes.begin(), - cu_seqlens_kv_padded_sizes.end()}; - te_cu_seqlens_q_padded = makeTransformerEngineTensor(cu_seqlens_q_padded.value().data_ptr(), - cu_seqlens_q_padded_shape, DType::kInt32); + NVTEShapeWrapper cu_seqlens_kv_padded_shape{cu_seqlens_kv_padded_sizes}; + te_cu_seqlens_q_padded = makeTransformerEngineTensor( + cu_seqlens_q_padded.value().data_ptr(), static_cast(cu_seqlens_q_padded_shape), + DType::kInt32); te_cu_seqlens_kv_padded = makeTransformerEngineTensor( - cu_seqlens_kv_padded.value().data_ptr(), cu_seqlens_kv_padded_shape, DType::kInt32); + cu_seqlens_kv_padded.value().data_ptr(), + static_cast(cu_seqlens_kv_padded_shape), DType::kInt32); } // convert auxiliary tensors from forward to NVTETensors @@ -490,12 +498,12 @@ std::vector fused_attn_bwd( nvte_aux_tensor_pack.size = Aux_CTX_Tensors.size(); for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) { const std::vector &signed_shape = Aux_CTX_Tensors[i].sizes().vec(); - const std::vector tmp(signed_shape.begin(), signed_shape.end()); + NVTEShapeWrapper tmp(signed_shape); NVTEBasicTensor temp_data = { Aux_CTX_Tensors[i].data_ptr(), static_cast(GetTransformerEngineDType(Aux_CTX_Tensors[i].scalar_type())), - 
nvte_make_shape(tmp.data(), tmp.size())}; + static_cast(tmp)}; nvte_set_tensor_param(&nvte_aux_tensor_pack.tensors[i], kNVTERowwiseData, &temp_data); } diff --git a/transformer_engine/pytorch/csrc/extensions/bias.cpp b/transformer_engine/pytorch/csrc/extensions/bias.cpp index b0435d2723..c3e89ed085 100644 --- a/transformer_engine/pytorch/csrc/extensions/bias.cpp +++ b/transformer_engine/pytorch/csrc/extensions/bias.cpp @@ -26,7 +26,7 @@ std::vector bgrad_quantize(const at::Tensor &grad_output, py::handle // Grad output tensor auto grad_output_torch = grad_output.contiguous(); const TensorWrapper &grad_output_nvte = makeTransformerEngineTensor(grad_output_torch); - const auto shape = getTensorShape(grad_output_torch); + const auto shape = getTensorShapeVector(grad_output_torch); auto grad_output_dtype = GetTransformerEngineDType(grad_output_torch.scalar_type()); // Construct grad bias tensor @@ -116,11 +116,11 @@ std::vector dact_dbias( // Grad output and activation input tensors grad_output_torch = grad_output_torch.contiguous(); const TensorWrapper &grad_output_nvte = makeTransformerEngineTensor(grad_output_torch); - const auto output_shape = getTensorShape(grad_output_torch); + const auto output_shape = getTensorShapeVector(grad_output_torch); auto grad_output_dtype = GetTransformerEngineDType(grad_output_torch.scalar_type()); act_input_torch = act_input_torch.contiguous(); const TensorWrapper &act_input_nvte = makeTransformerEngineTensor(act_input_torch); - const auto input_shape = getTensorShape(act_input_torch); + const auto input_shape = getTensorShapeVector(act_input_torch); // Construct tensors auto quantizer_cpp = convert_quantizer(quantizer_py); diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index aa9d800c7b..95e06278df 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -56,9 +56,8 @@ py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::ob TensorWrapper output_cpp; py::object output_py; if (output.is_none()) { - const auto shape = get_tensor_shape(input_cpp); const auto fake_dtype = input_cpp.dtype(); - std::tie(output_cpp, output_py) = quantizer_cpp->create_tensor(shape, fake_dtype); + std::tie(output_cpp, output_py) = quantizer_cpp->create_tensor(input_cpp.shape(), fake_dtype); } else { std::tie(output_cpp, output_py) = quantizer_cpp->convert_and_update_tensor(output); } @@ -180,8 +179,8 @@ std::vector multi_tensor_quantize(const std::vector &ten const auto input_dtype = GetTransformerEngineDType(input_py.scalar_type()); // Construct output tensor - std::vector output_shape(input_shape.data, input_shape.data + input_shape.ndim); - auto [output_cpp, output_py] = quantizer_cpp_list[i]->create_tensor(output_shape, input_dtype); + // std::vector output_shape(input_shape.data, input_shape.data + input_shape.ndim); + auto [output_cpp, output_py] = quantizer_cpp_list[i]->create_tensor(input_shape, input_dtype); output_cpp_list.emplace_back(std::move(output_cpp)); output_py_list.emplace_back(std::move(output_py)); } @@ -195,7 +194,7 @@ std::vector multi_tensor_quantize(const std::vector &ten namespace { std::tuple, std::vector> bulk_allocate_fp8_blockwise_tensors( - std::vector> &shape_list, std::vector &quantizer_py_list, + std::vector &shape_list, std::vector &quantizer_py_list, std::vector &quantizer_cpp_list) { init_extension(); std::tuple, std::vector> retval; @@ -220,7 +219,7 @@ std::tuple, std::vector> 
bulk_allocate_fp // Helper function to construct tensor view // Note: Deleter holds a shared_ptr for the buffer, so the buffer // will survive until all views are deleted. - auto make_torch_view = [](std::shared_ptr &buffer, const std::vector &shape, + auto make_torch_view = [](std::shared_ptr &buffer, const NVTEShapeWrapper &shape, size_t offset, at::ScalarType dtype) -> at::Tensor { std::vector shape_int64(shape.begin(), shape.end()); bool is_empty_shape = product(shape) == 0; @@ -235,13 +234,13 @@ std::tuple, std::vector> bulk_allocate_fp // Allocate row-wise data std::vector rowwise_data_list, rowwise_scale_list; - std::vector> rowwise_data_shapes, rowwise_scale_shapes; + std::vector rowwise_data_shapes, rowwise_scale_shapes; if (rowwise_usage) { // Tensor sizes for (size_t i = 0; i < num_tensors; ++i) { rowwise_data_shapes.emplace_back(shape_list[i]); rowwise_scale_shapes.emplace_back( - quantizer_cpp_list[i]->get_scale_shape(shape_list[i], false)); + quantizer_cpp_list[i]->get_scale_shape(rowwise_data_shapes[i], false)); } // Offsets in full buffer @@ -273,7 +272,7 @@ std::tuple, std::vector> bulk_allocate_fp // Allocate column-wise data std::vector columnwise_data_list, columnwise_scale_list; - std::vector> columnwise_data_shapes, columnwise_scale_shapes; + std::vector columnwise_data_shapes, columnwise_scale_shapes; if (columnwise_usage) { // Tensor sizes for (size_t i = 0; i < num_tensors; ++i) { @@ -330,24 +329,27 @@ std::tuple, std::vector> bulk_allocate_fp tensor_py_list.emplace_back(Float8BlockwiseQTensorClass( rowwise_data, rowwise_scale, columnwise_data, columnwise_scale, fp8_dtype, quantizer_py_list[i], is_2D_scaled, Float8BlockScaleTensorFormat::GEMM_READY)); - + NVTEShape zero_shape; + zero_shape.ndim = 1; + zero_shape.data[0] = 0; // Construct C++ tensor tensor_cpp_list.emplace_back(makeTransformerEngineTensor( rowwise_usage ? rowwise_data_list[i].data_ptr() : nullptr, columnwise_usage ? columnwise_data_list[i].data_ptr() : nullptr, - rowwise_usage ? rowwise_data_shapes[i] : std::vector{0}, - columnwise_usage ? columnwise_data_shapes[i] : std::vector{0}, fp8_dtype, nullptr, - nullptr, rowwise_usage ? rowwise_scale_list[i].data_ptr() : nullptr, + rowwise_usage ? static_cast(rowwise_data_shapes[i]) : zero_shape, + columnwise_usage ? static_cast(columnwise_data_shapes[i]) : zero_shape, + fp8_dtype, nullptr, nullptr, rowwise_usage ? rowwise_scale_list[i].data_ptr() : nullptr, columnwise_usage ? columnwise_scale_list[i].data_ptr() : nullptr, - rowwise_usage ? rowwise_scale_shapes[i] : std::vector{0}, - columnwise_usage ? columnwise_scale_shapes[i] : std::vector{0}, scaling_mode)); + rowwise_usage ? static_cast(rowwise_scale_shapes[i]) : zero_shape, + columnwise_usage ? static_cast(columnwise_scale_shapes[i]) : zero_shape, + scaling_mode)); } return retval; } std::tuple, std::vector> bulk_allocate_mxfp8_tensors( - std::vector> &shape_list, std::vector &quantizer_py_list, + std::vector &shape_list, std::vector &quantizer_py_list, std::vector &quantizer_cpp_list) { init_extension(); std::tuple, std::vector> retval; @@ -371,7 +373,7 @@ std::tuple, std::vector> bulk_allocate_mx // Helper function to construct tensor view // Note: Deleter holds a shared_ptr for the buffer, so the buffer // will survive until all views are deleted. 
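// A minimal sketch of how the container-like NVTEShapeWrapper interface is meant to be
// used by these bulk-allocation helpers (assumes the NVTEShapeWrapper class and the
// defaulted product(const NVTEShape&, size_t, size_t) overload introduced elsewhere in
// this patch; the dimension values are illustrative only):
//
//   NVTEShapeWrapper shape(std::vector<size_t>{128, 64});
//   std::vector<int64_t> shape_int64(shape.begin(), shape.end());  // iterator support
//   const size_t numel = product(shape);  // resolves through the implicit conversion to const NVTEShape&
//
// The 1-D meta shapes built by hand in these helpers (e.g. zero_shape with ndim = 1,
// data[0] = 0) follow the same NVTEShape layout, just without going through the wrapper.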
- auto make_torch_view = [](std::shared_ptr &buffer, const std::vector &shape, + auto make_torch_view = [](std::shared_ptr &buffer, const NVTEShapeWrapper &shape, size_t offset, at::ScalarType dtype) -> at::Tensor { std::vector shape_int64(shape.begin(), shape.end()); bool is_empty_shape = product(shape) == 0; @@ -386,13 +388,13 @@ std::tuple, std::vector> bulk_allocate_mx // Allocate row-wise data std::vector rowwise_data_list, rowwise_scale_list; - std::vector> rowwise_data_shapes, rowwise_scale_shapes; + std::vector rowwise_data_shapes, rowwise_scale_shapes; if (rowwise_usage) { // Tensor sizes for (size_t i = 0; i < num_tensors; ++i) { rowwise_data_shapes.emplace_back(shape_list[i]); rowwise_scale_shapes.emplace_back( - quantizer_cpp_list[i]->get_scale_shape(shape_list[i], false)); + quantizer_cpp_list[i]->get_scale_shape(rowwise_data_shapes[i], false)); } // Offsets in full buffer @@ -424,7 +426,7 @@ std::tuple, std::vector> bulk_allocate_mx // Allocate column-wise data std::vector columnwise_data_list, columnwise_scale_list; - std::vector> columnwise_data_shapes, columnwise_scale_shapes; + std::vector columnwise_data_shapes, columnwise_scale_shapes; if (columnwise_usage) { // Tensor sizes for (size_t i = 0; i < num_tensors; ++i) { @@ -477,17 +479,20 @@ std::tuple, std::vector> bulk_allocate_mx tensor_py_list.emplace_back(MXFP8TensorClass(rowwise_data, rowwise_scale, columnwise_data, columnwise_scale, fp8_dtype, quantizer_py_list[i])); - + NVTEShape zero_shape; + zero_shape.ndim = 1; + zero_shape.data[0] = 0; // Construct C++ tensor tensor_cpp_list.emplace_back(makeTransformerEngineTensor( rowwise_usage ? rowwise_data_list[i].data_ptr() : nullptr, columnwise_usage ? columnwise_data_list[i].data_ptr() : nullptr, - rowwise_usage ? rowwise_data_shapes[i] : std::vector{0}, - columnwise_usage ? columnwise_data_shapes[i] : std::vector{0}, fp8_dtype, nullptr, - nullptr, rowwise_usage ? rowwise_scale_list[i].data_ptr() : nullptr, + rowwise_usage ? static_cast(rowwise_data_shapes[i]) : zero_shape, + columnwise_usage ? static_cast(columnwise_data_shapes[i]) : zero_shape, + fp8_dtype, nullptr, nullptr, rowwise_usage ? rowwise_scale_list[i].data_ptr() : nullptr, columnwise_usage ? columnwise_scale_list[i].data_ptr() : nullptr, - rowwise_usage ? rowwise_scale_shapes[i] : std::vector{0}, - columnwise_usage ? columnwise_scale_shapes[i] : std::vector{0}, scaling_mode)); + rowwise_usage ? static_cast(rowwise_scale_shapes[i]) : zero_shape, + columnwise_usage ? static_cast(columnwise_scale_shapes[i]) : zero_shape, + scaling_mode)); } return retval; @@ -497,7 +502,7 @@ std::tuple, std::vector> bulk_allocate_mx // layout: [fp4_data0, ..., fp4_dataN, fp8_scaling0, ..., fp8_scalingN, amax0, ..., amaxN] // amax buffer will be zeroed out by later amax kernels, so we can use empty to allocate std::tuple, std::vector, bool> bulk_allocate_nvfp4_tensors( - std::vector> &shape_list, std::vector &quantizer_py_list, + std::vector &shape_list, std::vector &quantizer_py_list, std::vector &quantizer_cpp_list) { init_extension(); std::tuple, std::vector, bool> retval; @@ -522,7 +527,7 @@ std::tuple, std::vector, bool> bulk_alloc // Helper function to construct tensor view // Note: Deleter holds a shared_ptr for the buffer, so the buffer // will survive until all views are deleted. 
- auto make_torch_view = [](std::shared_ptr &buffer, const std::vector &shape, + auto make_torch_view = [](std::shared_ptr &buffer, const NVTEShapeWrapper &shape, size_t offset, at::ScalarType dtype) -> at::Tensor { std::vector shape_int64(shape.begin(), shape.end()); bool is_empty_shape = product(shape) == 0; @@ -535,9 +540,9 @@ std::tuple, std::vector, bool> bulk_alloc at::device(at::kCUDA).dtype(dtype)); }; - // Lambda function for converting std::vector shape to NVFP4 shape (last dim divided by 2) - auto to_fp4_shape = [](const std::vector &shape) { - std::vector fp4_shape(shape.begin(), shape.end()); + // Lambda function for converting NVTEShapeWrapper shape to NVFP4 shape (last dim divided by 2) + auto to_fp4_shape = [](const NVTEShapeWrapper &shape) { + NVTEShapeWrapper fp4_shape(shape); if (!fp4_shape.empty()) { fp4_shape.back() /= 2; } @@ -546,13 +551,13 @@ std::tuple, std::vector, bool> bulk_alloc // Allocate row-wise data std::vector rowwise_data_list, rowwise_scale_list, amax_rowwise_list; - std::vector> rowwise_data_shapes, rowwise_scale_shapes; + std::vector rowwise_data_shapes, rowwise_scale_shapes; if (rowwise_usage) { // Tensor sizes for (size_t i = 0; i < num_tensors; ++i) { rowwise_data_shapes.emplace_back(shape_list[i]); rowwise_scale_shapes.emplace_back( - quantizer_cpp_list[i]->get_scale_shape(shape_list[i], false)); + quantizer_cpp_list[i]->get_scale_shape(rowwise_data_shapes[i], false)); } // Offsets in full buffer @@ -587,7 +592,9 @@ std::tuple, std::vector, bool> bulk_alloc // Allocate full buffer auto buffer = std::make_shared( at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8))); - + NVTEShape amax_shape; + amax_shape.ndim = 1; + amax_shape.data[0] = 1; // Construct tensor views for (size_t i = 0; i < num_tensors; ++i) { rowwise_data_list.emplace_back(make_torch_view(buffer, to_fp4_shape(rowwise_data_shapes[i]), @@ -595,13 +602,13 @@ std::tuple, std::vector, bool> bulk_alloc rowwise_scale_list.emplace_back( make_torch_view(buffer, rowwise_scale_shapes[i], scale_offsets[i], torch::kUInt8)); amax_rowwise_list.emplace_back( - make_torch_view(buffer, std::vector{1}, amax_offsets[i], torch::kFloat32)); + make_torch_view(buffer, amax_shape, amax_offsets[i], torch::kFloat32)); } } // Allocate column-wise data std::vector columnwise_data_list, columnwise_scale_list, amax_columnwise_list; - std::vector> columnwise_data_shapes, columnwise_scale_shapes; + std::vector columnwise_data_shapes, columnwise_scale_shapes; if (columnwise_usage) { // Tensor sizes for (size_t i = 0; i < num_tensors; ++i) { @@ -649,7 +656,9 @@ std::tuple, std::vector, bool> bulk_alloc // Allocate full buffer auto buffer = std::make_shared( at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8))); - + NVTEShape amax_shape; + amax_shape.ndim = 1; + amax_shape.data[0] = 1; // Construct tensor views for (size_t i = 0; i < num_tensors; ++i) { columnwise_data_list.emplace_back(make_torch_view( @@ -657,7 +666,7 @@ std::tuple, std::vector, bool> bulk_alloc columnwise_scale_list.emplace_back( make_torch_view(buffer, columnwise_scale_shapes[i], scale_offsets[i], torch::kUInt8)); amax_columnwise_list.emplace_back( - make_torch_view(buffer, std::vector{1}, amax_offsets[i], torch::kFloat32)); + make_torch_view(buffer, amax_shape, amax_offsets[i], torch::kFloat32)); } } @@ -682,26 +691,32 @@ std::tuple, std::vector, bool> bulk_alloc // Construct C++ tensor // Use a TensorWrapper variable to hold the output of makeTransformerEngineTensor, // then set the amax and 
amax_columnwise values. + NVTEShape zero_shape, amax_shape; + zero_shape.ndim = 1; + amax_shape.ndim = 1; + zero_shape.data[0] = 0; + amax_shape.data[0] = 1; { auto tensor_wrapper = makeTransformerEngineTensor( rowwise_usage ? rowwise_data_list[i].data_ptr() : nullptr, columnwise_usage ? columnwise_data_list[i].data_ptr() : nullptr, - rowwise_usage ? rowwise_data_shapes[i] : std::vector{0}, - columnwise_usage ? columnwise_data_shapes[i] : std::vector{0}, fp4_dtype, + rowwise_usage ? static_cast(rowwise_data_shapes[i]) : zero_shape, + columnwise_usage ? static_cast(columnwise_data_shapes[i]) : zero_shape, + fp4_dtype, /*amax_ptr=*/nullptr, /*scale_ptr=*/nullptr, rowwise_usage ? rowwise_scale_list[i].data_ptr() : nullptr, columnwise_usage ? columnwise_scale_list[i].data_ptr() : nullptr, - rowwise_usage ? rowwise_scale_shapes[i] : std::vector{0}, - columnwise_usage ? columnwise_scale_shapes[i] : std::vector{0}, scaling_mode); + rowwise_usage ? static_cast(rowwise_scale_shapes[i]) : zero_shape, + columnwise_usage ? static_cast(columnwise_scale_shapes[i]) : zero_shape, + scaling_mode); // Set the amax rowwise and amax columnwise if available if (rowwise_usage) { - tensor_wrapper.set_amax(amax_rowwise_list[i].data_ptr(), DType::kFloat32, - std::vector{1}); + tensor_wrapper.set_amax(amax_rowwise_list[i].data_ptr(), DType::kFloat32, amax_shape); } if (columnwise_usage) { tensor_wrapper.set_columnwise_amax(amax_columnwise_list[i].data_ptr(), DType::kFloat32, - std::vector{1}); + amax_shape); } tensor_cpp_list.emplace_back(std::move(tensor_wrapper)); } @@ -765,9 +780,11 @@ static StochasticRngStateResources setup_stochastic_rounding_rng_states_helper( at::PhiloxCudaState philox_args = init_philox_state(gen, rng_elts_per_thread); int64_t *rng_state_ptr = static_cast(res.rng_states_tensor.data_ptr()) + i * 2; philox_unpack(philox_args, rng_state_ptr); - - res.te_rng_state_list.push_back(makeTransformerEngineTensor( - static_cast(rng_state_ptr), std::vector{2}, DType::kInt64)); + NVTEShape rng_state_shape; + rng_state_shape.ndim = 1; + rng_state_shape.data[0] = 2; + res.te_rng_state_list.push_back(makeTransformerEngineTensor(static_cast(rng_state_ptr), + rng_state_shape, DType::kInt64)); quant_config_list_rowwise[i].set_rng_state(res.te_rng_state_list[i].data()); quant_config_list_rowwise[i].set_stochastic_rounding(true); @@ -781,7 +798,7 @@ static StochasticRngStateResources setup_stochastic_rounding_rng_states_helper( philox_unpack(philox_args_col, rng_state_ptr_colwise); res.te_rng_state_list_colwise.push_back(makeTransformerEngineTensor( - static_cast(rng_state_ptr_colwise), std::vector{2}, DType::kInt64)); + static_cast(rng_state_ptr_colwise), rng_state_shape, DType::kInt64)); quant_config_list_colwise[i].set_rng_state(res.te_rng_state_list_colwise[i].data()); quant_config_list_colwise[i].set_stochastic_rounding(true); } @@ -997,18 +1014,21 @@ void split_quantize_nvfp4_impl_helper(const TensorWrapper &input, // Note that the multi compute amax API expects rowwise amax pointer to be not null // So we need to set the pointer accordingly to make colwise-only quantization work std::vector orig_amax_ptr_list; + NVTEShape amax_shape; + amax_shape.ndim = 1; + amax_shape.data[0] = 1; for (size_t i = 0; i < num_tensors; i++) { auto rowwise_amax_ptr = output_list[i].get_amax().data_ptr; orig_amax_ptr_list.push_back(rowwise_amax_ptr); auto columnwise_amax_ptr = output_list[i].get_columnwise_amax().data_ptr; void *amax_ptr = rowwise_amax_ptr != nullptr ? 
rowwise_amax_ptr : columnwise_amax_ptr; NVTE_CHECK(amax_ptr != nullptr, "Could not find amax pointer"); - output_list[i].set_amax(amax_ptr, DType::kFloat32, std::vector{1}); + output_list[i].set_amax(amax_ptr, DType::kFloat32, amax_shape); } nvte_group_amax(input.data(), reinterpret_cast(nvte_tensor_output_list.data()), split_sections.data(), num_tensors, stream); for (size_t i = 0; i < num_tensors; i++) { - output_list[i].set_amax(orig_amax_ptr_list[i], DType::kFloat32, std::vector{1}); + output_list[i].set_amax(orig_amax_ptr_list[i], DType::kFloat32, amax_shape); } // Quantize tensors individually @@ -1104,7 +1124,7 @@ std::vector split_quantize(const at::Tensor &tensor, auto input_py = tensor.contiguous(); uint8_t *input_dptr = reinterpret_cast(input_py.data_ptr()); auto input_dtype = GetTransformerEngineDType(input_py.scalar_type()); - std::vector input_shape; + NVTEShapeWrapper input_shape; size_t input_size = 1; for (const auto &d : input_py.sizes()) { input_shape.push_back(d); @@ -1114,7 +1134,7 @@ std::vector split_quantize(const at::Tensor &tensor, // Split input tensor along dim 0 std::vector input_list; - std::vector> split_shapes; + std::vector split_shapes; size_t dim0_offset = 0; const size_t dim0_stride = input_shape[0] == 0 ? 0 : input_py.element_size() * input_size / input_shape[0]; @@ -1122,11 +1142,15 @@ std::vector split_quantize(const at::Tensor &tensor, NVTE_CHECK(dim0_offset + split_sections[i] <= input_shape[0], "Attempted to split tensor with shape=", input_shape, " along dim 0 with split_sections=", split_sections); - split_shapes.push_back(input_shape); + split_shapes.emplace_back(); auto &split_shape = split_shapes.back(); - split_shape[0] = split_sections[i]; + split_shape.push_back(split_sections[i]); + for (size_t j = 1; j < input_shape.size(); ++j) { + split_shape.push_back(input_shape[j]); + } void *split_dptr = static_cast(input_dptr + dim0_offset * dim0_stride); - input_list.emplace_back(makeTransformerEngineTensor(split_dptr, split_shape, input_dtype)); + input_list.emplace_back(makeTransformerEngineTensor( + split_dptr, static_cast(split_shape), input_dtype)); dim0_offset += split_sections[i]; } diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp index 335052296f..82c1ce1e7b 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp @@ -40,8 +40,8 @@ bool is_low_precision(const DType type) { return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2; } -std::vector getGemmOutputShape(const NVTEShape& A_shape, const bool transa, - const NVTEShape& B_shape, const bool transb) { +NVTEShape getGemmOutputShape(const NVTEShape& A_shape, const bool transa, const NVTEShape& B_shape, + const bool transb) { // Flatten outer dims to get 2D matrices const size_t A0 = A_shape.ndim > 0 ? product(A_shape, 0, A_shape.ndim - 1) : 1; const size_t A1 = A_shape.ndim > 0 ? 
A_shape.data[A_shape.ndim - 1] : 1; @@ -53,35 +53,37 @@ std::vector getGemmOutputShape(const NVTEShape& A_shape, const bool tran A0, ",", A1, "), transa=", transa, ", B=(", B0, ",", B1, "), transb=", transb, ")"); // Construct output dims - std::vector ret; + NVTEShape ret; + size_t idx = 0; if (transb) { - ret.emplace_back(B1); + ret.data[idx++] = B1; } else { // Unflatten B0 for (size_t i = 0; i < B_shape.ndim - 1; ++i) { - ret.emplace_back(B_shape.data[i]); + ret.data[idx++] = B_shape.data[i]; } } if (transa) { - ret.emplace_back(A0); + ret.data[idx++] = A0; } else { - ret.emplace_back(A1); + ret.data[idx++] = A1; } + ret.ndim = idx; return ret; } -bool checkGemmShape(const std::vector& expected, const NVTEShape& actual) { - if (expected.size() != actual.ndim) return false; - for (size_t i = 0; i < expected.size(); ++i) { - if (expected[i] != actual.data[i]) return false; +bool checkGemmShape(const NVTEShape& expected, const NVTEShape& actual) { + if (expected.ndim != actual.ndim) return false; + for (size_t i = 0; i < expected.ndim; ++i) { + if (expected.data[i] != actual.data[i]) return false; } return true; } } // namespace detail -std::pair createOutputTensor(const std::vector& shape, - DType dtype, py::handle quantizer) { +std::pair createOutputTensor(const NVTEShape& shape, DType dtype, + py::handle quantizer) { std::unique_ptr my_quantizer = convert_quantizer(quantizer); return my_quantizer->create_tensor(shape, dtype); } @@ -197,8 +199,8 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans auto dtype = GetATenDType(gelu_type); auto opts = torch::TensorOptions().dtype(dtype).device(torch::kCUDA); std::vector torch_shape; - for (auto v : D_shape) { - torch_shape.push_back(v); + for (size_t i = 0; i < D_shape.ndim; ++i) { + torch_shape.push_back(static_cast(D_shape.data[i])); } pre_gelu_out = at::empty(torch_shape, opts); } else { @@ -207,14 +209,22 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans } } } - const auto gelu_shape = gelu ? D_shape : std::vector{0}; + NVTEShape gelu_shape; + gelu_shape.ndim = 1; + gelu_shape.data[0] = 0; + if (gelu) { + gelu_shape = D_shape; + } auto te_pre_gelu_out = makeTransformerEngineTensor(get_data_ptr(pre_gelu_out), gelu_shape, gelu_type); // Workspace - auto te_workspace = makeTransformerEngineTensor(workspace.data_ptr(), - std::vector{workspaceSize}, DType::kByte); + NVTEShape workspace_shape; + workspace_shape.ndim = 1; + workspace_shape.data[0] = workspaceSize; + auto te_workspace = + makeTransformerEngineTensor(workspace.data_ptr(), workspace_shape, DType::kByte); // Set an external SM Margin to all the GEMMs. 
// This comes in handy when DP is overlapped with GEMMs @@ -263,8 +273,10 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans if (extra_output.has_value()) { extra_output_tensor = makeTransformerEngineTensor(*extra_output); } else { + NVTEShape extra_output_shape; + extra_output_shape.ndim = 0; extra_output_tensor = - makeTransformerEngineTensor(nullptr, std::vector{0}, DType::kByte); + makeTransformerEngineTensor(nullptr, extra_output_shape, DType::kByte); } // Direct GEMM call to the correct overlap @@ -365,32 +377,49 @@ void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, DType A_type, NVTEScalingMode nvte_scaling_modeA = NVTE_DELAYED_TENSOR_SCALING; NVTEScalingMode nvte_scaling_modeB = NVTE_DELAYED_TENSOR_SCALING; - auto te_A = makeTransformerEngineTensor( - A.data_ptr(), {static_cast(A.size(0)), static_cast(A.size(1))}, A_type, - nullptr, nullptr, A_scale_inverse.data_ptr(), getTensorShape(A_scale_inverse), - nvte_scaling_modeA); - auto te_B = makeTransformerEngineTensor( - B.data_ptr(), {static_cast(B.size(0)), static_cast(B.size(1))}, B_type, - nullptr, nullptr, B_scale_inverse.data_ptr(), getTensorShape(B_scale_inverse), - nvte_scaling_modeB); + const size_t A_shape_data[2] = {static_cast(A.size(0)), static_cast(A.size(1))}; + const NVTEShape A_shape = nvte_make_shape(A_shape_data, 2); + auto te_A = makeTransformerEngineTensor(A.data_ptr(), A_shape, A_type, nullptr, nullptr, + A_scale_inverse.data_ptr(), + getTensorShape(A_scale_inverse), nvte_scaling_modeA); + const size_t B_shape_data[2] = {static_cast(B.size(0)), static_cast(B.size(1))}; + const NVTEShape B_shape = nvte_make_shape(B_shape_data, 2); + auto te_B = makeTransformerEngineTensor(B.data_ptr(), B_shape, B_type, nullptr, nullptr, + B_scale_inverse.data_ptr(), + getTensorShape(B_scale_inverse), nvte_scaling_modeB); // TODO: D_scale_inv cannot be nullptr when D_type is FP8. - auto te_D = makeTransformerEngineTensor( - D.data_ptr(), - std::vector{static_cast(D.size(0)), static_cast(D.size(1))}, D_type, - D_amax.data_ptr(), D_scale.data_ptr(), nullptr); - auto te_bias = makeTransformerEngineTensor( - bias.data_ptr(), std::vector{static_cast(bias.size(0))}, bias_type); - auto te_counter = makeTransformerEngineTensor( - counter.data_ptr(), std::vector{static_cast(counter.size(0))}, DType::kInt32); - - const auto gelu_shape = pre_gelu_out.data_ptr() == nullptr - ? 
std::vector{static_cast(pre_gelu_out.size(0))} - : std::vector{static_cast(pre_gelu_out.size(0)), - static_cast(pre_gelu_out.size(1))}; + NVTEShape D_shape, D_scale_inv_shape; + D_shape.ndim = 2; + D_scale_inv_shape.ndim = 1; + D_scale_inv_shape.data[0] = 1; + D_shape.data[0] = static_cast(D.size(0)); + D_shape.data[1] = static_cast(D.size(1)); + auto te_D = makeTransformerEngineTensor(D.data_ptr(), D_shape, D_type, D_amax.data_ptr(), + D_scale.data_ptr(), nullptr, D_scale_inv_shape); + NVTEShape bias_shape; + bias_shape.ndim = 1; + bias_shape.data[0] = static_cast(bias.size(0)); + auto te_bias = makeTransformerEngineTensor(bias.data_ptr(), bias_shape, bias_type); + NVTEShape counter_shape; + counter_shape.ndim = 1; + counter_shape.data[0] = static_cast(counter.size(0)); + auto te_counter = makeTransformerEngineTensor(counter.data_ptr(), counter_shape, DType::kInt32); + + NVTEShape gelu_shape, workspace_shape; + if (pre_gelu_out.data_ptr() == nullptr) { + gelu_shape.ndim = 1; + gelu_shape.data[0] = static_cast(pre_gelu_out.size(0)); + } else { + gelu_shape.ndim = 2; + gelu_shape.data[0] = static_cast(pre_gelu_out.size(0)); + gelu_shape.data[1] = static_cast(pre_gelu_out.size(1)); + } + workspace_shape.ndim = 1; + workspace_shape.data[0] = workspaceSize; auto te_pre_gelu_out = makeTransformerEngineTensor( pre_gelu_out.data_ptr(), gelu_shape, GetTransformerEngineDType(pre_gelu_out.scalar_type())); - auto te_workspace = makeTransformerEngineTensor(workspace.data_ptr(), - std::vector{workspaceSize}, DType::kByte); + auto te_workspace = + makeTransformerEngineTensor(workspace.data_ptr(), workspace_shape, DType::kByte); NVTE_SCOPED_GIL_RELEASE({ nvte_cublas_atomic_gemm(te_A.data(), te_B.data(), te_D.data(), te_bias.data(), @@ -430,12 +459,13 @@ std::optional> te_general_grouped_gemm( // if there is single output at::Tensor out_tensor; - auto size_t_shape = + const NVTEShape nvte_D_shape = pytorch::detail::getGemmOutputShape(te_A.shape(), transa, te_B.shape(), transb); bool D_numel_is_zero = false; std::vector D_shape; - for (size_t t : size_t_shape) { - D_shape.push_back(t); + for (size_t j = 0; j < nvte_D_shape.ndim; ++j) { + const size_t t = nvte_D_shape.data[j]; + D_shape.push_back(static_cast(t)); if (t == 0) { D_numel_is_zero = true; } @@ -480,10 +510,14 @@ std::optional> te_general_grouped_gemm( auto te_bias = makeTransformerEngineTensor(bias[i]); auto te_pre_gelu_out = makeTransformerEngineTensor(pre_gelu_out[i]); - const auto gelu_shape = pre_gelu_out[i].data_ptr() == nullptr - ? 
std::vector{static_cast(te_pre_gelu_out.size(0))} - : std::vector{static_cast(te_pre_gelu_out.size(0)), - static_cast(te_pre_gelu_out.size(1))}; + NVTEShape gelu_shape; + gelu_shape.data[0] = te_pre_gelu_out.size(0); + if (pre_gelu_out[i].data_ptr() == nullptr) { + gelu_shape.ndim = 1; + } else { + gelu_shape.ndim = 2; + gelu_shape.data[1] = te_pre_gelu_out.size(1); + } DType gelu_type = bias_type; te_pre_gelu_out = @@ -550,9 +584,11 @@ std::optional> te_general_grouped_gemm( std::vector te_workspace_vector; std::vector te_workspace_wrappers; + NVTEShape workspace_shape; + workspace_shape.ndim = 1; + workspace_shape.data[0] = workspaceSize; for (size_t i = 0; i < workspace.size(); i++) { - auto wsp = makeTransformerEngineTensor(workspace[i].data_ptr(), - std::vector{workspaceSize}, DType::kByte); + auto wsp = makeTransformerEngineTensor(workspace[i].data_ptr(), workspace_shape, DType::kByte); te_workspace_vector.emplace_back(wsp.data()); te_workspace_wrappers.emplace_back(std::move(wsp)); } diff --git a/transformer_engine/pytorch/csrc/extensions/padding.cpp b/transformer_engine/pytorch/csrc/extensions/padding.cpp index d4b64a485c..cabb65233f 100644 --- a/transformer_engine/pytorch/csrc/extensions/padding.cpp +++ b/transformer_engine/pytorch/csrc/extensions/padding.cpp @@ -20,7 +20,7 @@ void fused_multi_row_padding(at::Tensor input, at::Tensor output, const auto num_tensors = input_row_list.size(); // Extract properties from PyTorch tensors std::vector input_dptr_list, output_dptr_list; - std::vector> input_shape_list, output_shape_list; + std::vector input_shape_list, output_shape_list; std::vector input_type_list; void* d_input_ptr = reinterpret_cast(input.data_ptr()); void* d_output_ptr = reinterpret_cast(output.data_ptr()); @@ -34,8 +34,11 @@ void fused_multi_row_padding(at::Tensor input, at::Tensor output, input_row_list[tensor_id] * input.size(1) * input.element_size(); input_char_ptr += input_dptr_offset; d_input_ptr = reinterpret_cast(input_char_ptr); - - input_shape_list.push_back({input_row_list[tensor_id], static_cast(input.size(1))}); + NVTEShape input_shape; + input_shape.ndim = 2; + input_shape.data[0] = input_row_list[tensor_id]; + input_shape.data[1] = static_cast(input.size(1)); + input_shape_list.push_back(input_shape); input_type_list.push_back(GetTransformerEngineDType(input.scalar_type())); // Move the output pointer to the next split. 
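// The 2-D shapes above are filled field-by-field; the gemm.cpp changes in this patch build
// equivalent shapes with nvte_make_shape. A sketch of that alternative for the same data,
// assuming nvte_make_shape is visible in this translation unit (variable names here are
// illustrative, not part of the patch):
const size_t input_dims[2] = {input_row_list[tensor_id],
                              static_cast<size_t>(input.size(1))};
const NVTEShape input_shape_alt = nvte_make_shape(input_dims, 2);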
@@ -45,14 +48,17 @@ void fused_multi_row_padding(at::Tensor input, at::Tensor output, output_char_ptr += output_dptr_offset; d_output_ptr = reinterpret_cast(output_char_ptr); - output_shape_list.push_back( - {padded_input_row_list[tensor_id], static_cast(output.size(1))}); + NVTEShape output_shape; + output_shape.ndim = 2; + output_shape.data[0] = padded_input_row_list[tensor_id]; + output_shape.data[1] = static_cast(output.size(1)); + output_shape_list.push_back(output_shape); } // Construct TE tensors std::vector nvte_input_list, nvte_output_list; std::vector tensor_wrappers; - auto make_tensor = [&tensor_wrappers](void* dptr, const std::vector& shape, + auto make_tensor = [&tensor_wrappers](void* dptr, const NVTEShape& shape, DType dtype) -> NVTETensor { tensor_wrappers.emplace_back(makeTransformerEngineTensor(dptr, shape, dtype)); return tensor_wrappers.back().data(); @@ -95,7 +101,7 @@ void fused_multi_row_unpadding(at::Tensor input, at::Tensor output, const auto num_tensors = input_row_list.size(); // Extract properties from PyTorch tensors std::vector input_dptr_list, output_dptr_list; - std::vector> input_shape_list, output_shape_list; + std::vector input_shape_list, output_shape_list; std::vector input_type_list; void* d_input_ptr = reinterpret_cast(input.data_ptr()); void* d_output_ptr = reinterpret_cast(output.data_ptr()); @@ -109,8 +115,11 @@ void fused_multi_row_unpadding(at::Tensor input, at::Tensor output, input_row_list[tensor_id] * input.size(1) * input.element_size(); input_char_ptr += input_dptr_offset; d_input_ptr = reinterpret_cast(input_char_ptr); - - input_shape_list.push_back({input_row_list[tensor_id], static_cast(input.size(1))}); + NVTEShape input_shape; + input_shape.ndim = 2; + input_shape.data[0] = input_row_list[tensor_id]; + input_shape.data[1] = static_cast(input.size(1)); + input_shape_list.push_back(input_shape); input_type_list.push_back(GetTransformerEngineDType(input.scalar_type())); // Move the output pointer to the next split. 
@@ -120,14 +129,17 @@ void fused_multi_row_unpadding(at::Tensor input, at::Tensor output, output_char_ptr += output_dptr_offset; d_output_ptr = reinterpret_cast(output_char_ptr); - output_shape_list.push_back( - {unpadded_input_row_list[tensor_id], static_cast(output.size(1))}); + NVTEShape output_shape; + output_shape.ndim = 2; + output_shape.data[0] = unpadded_input_row_list[tensor_id]; + output_shape.data[1] = static_cast(output.size(1)); + output_shape_list.push_back(output_shape); } // Construct TE tensors std::vector nvte_input_list, nvte_output_list; std::vector tensor_wrappers; - auto make_tensor = [&tensor_wrappers](void* dptr, const std::vector& shape, + auto make_tensor = [&tensor_wrappers](void* dptr, const NVTEShape& shape, transformer_engine::DType dtype) -> NVTETensor { tensor_wrappers.emplace_back(makeTransformerEngineTensor(dptr, shape, dtype)); return tensor_wrappers.back().data(); diff --git a/transformer_engine/pytorch/csrc/extensions/permutation.cpp b/transformer_engine/pytorch/csrc/extensions/permutation.cpp index 97cf400851..b0654c326e 100644 --- a/transformer_engine/pytorch/csrc/extensions/permutation.cpp +++ b/transformer_engine/pytorch/csrc/extensions/permutation.cpp @@ -60,19 +60,20 @@ std::tuple> moe_permute_fwd( {num_tokens * topK}, torch::dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false)); auto stream = at::cuda::getCurrentCUDAStream().stream(); - - auto input_cu = makeTransformerEngineTensor( - input.data_ptr(), - std::vector{static_cast(input.size(0)), static_cast(num_cols)}, - dtype); + NVTEShape input_shape, permuted_output_shape, sorted_row_id_cu_shape; + input_shape.ndim = 2; + permuted_output_shape.ndim = 2; + sorted_row_id_cu_shape.ndim = 1; + input_shape.data[0] = static_cast(input.size(0)); + input_shape.data[1] = static_cast(input.size(1)); + permuted_output_shape.data[0] = static_cast(permuted_output.size(0)); + permuted_output_shape.data[1] = static_cast(permuted_output.size(1)); + sorted_row_id_cu_shape.data[0] = static_cast(num_tokens * topK); + auto input_cu = makeTransformerEngineTensor(input.data_ptr(), input_shape, dtype); auto permuted_output_cu = - makeTransformerEngineTensor(permuted_output.data_ptr(), - std::vector{static_cast(permuted_output.size(0)), - static_cast(num_cols)}, - dtype); - auto sorted_row_id_cu = makeTransformerEngineTensor( - sorted_row_id_ptr, std::vector{static_cast(num_tokens * topK)}, - DType::kInt32); + makeTransformerEngineTensor(permuted_output.data_ptr(), permuted_output_shape, dtype); + auto sorted_row_id_cu = + makeTransformerEngineTensor(sorted_row_id_ptr, sorted_row_id_cu_shape, DType::kInt32); auto row_id_map_cu = makeTransformerEngineTensor(row_id_map); nvte_permute(input_cu.data(), permuted_output_cu.data(), sorted_row_id_cu.data(), @@ -97,16 +98,16 @@ at::Tensor moe_unpermute_fwd(at::Tensor input, const DType dtype, at::Tensor row torch::dtype(input.scalar_type()).device(torch::kCUDA).requires_grad(false)); auto stream = at::cuda::getCurrentCUDAStream().stream(); - - auto input_cu = makeTransformerEngineTensor( - input.data_ptr(), - std::vector{static_cast(input.size(0)), static_cast(num_cols)}, - dtype); - auto unpermuted_output_cu = makeTransformerEngineTensor( - unpermuted_output.data_ptr(), - std::vector{static_cast(unpermuted_output.size(0)), - static_cast(num_cols)}, - dtype); + NVTEShape input_shape, unpermuted_output_shape; + input_shape.ndim = 2; + unpermuted_output_shape.ndim = 2; + input_shape.data[0] = static_cast(input.size(0)); + input_shape.data[1] = 
+  unpermuted_output_shape.data[0] = static_cast<size_t>(unpermuted_output.size(0));
+  unpermuted_output_shape.data[1] = static_cast<size_t>(num_cols);
+  auto input_cu = makeTransformerEngineTensor(input.data_ptr(), input_shape, dtype);
+  auto unpermuted_output_cu =
+      makeTransformerEngineTensor(unpermuted_output.data_ptr(), unpermuted_output_shape, dtype);
   auto row_id_map_cu = makeTransformerEngineTensor(row_id_map);
   auto prob_cu = makeTransformerEngineTensor(prob);
 
@@ -131,19 +132,19 @@ std::tuple<at::Tensor, at::Tensor> moe_unpermute_bwd(at::Tensor input_bwd, at::T
       {num_tokens, topK}, torch::dtype(torch::kFloat32).device(torch::kCUDA).requires_grad(false));
 
   auto stream = at::cuda::getCurrentCUDAStream().stream();
-
-  auto input_bwd_cu = makeTransformerEngineTensor(
-      input_bwd.data_ptr(),
-      std::vector<size_t>{static_cast<size_t>(input_bwd.size(0)), static_cast<size_t>(num_cols)},
-      dtype);
-  auto act_grad_cu = makeTransformerEngineTensor(
-      act_grad.data_ptr(),
-      std::vector<size_t>{static_cast<size_t>(act_grad.size(0)), static_cast<size_t>(num_cols)},
-      dtype);
-  auto input_fwd_cu = makeTransformerEngineTensor(
-      input_fwd.data_ptr(),
-      std::vector<size_t>{static_cast<size_t>(input_fwd.size(0)), static_cast<size_t>(num_cols)},
-      dtype);
+  NVTEShape input_bwd_shape, act_grad_shape, input_fwd_shape;
+  input_bwd_shape.ndim = 2;
+  act_grad_shape.ndim = 2;
+  input_fwd_shape.ndim = 2;
+  input_bwd_shape.data[0] = static_cast<size_t>(input_bwd.size(0));
+  input_bwd_shape.data[1] = static_cast<size_t>(num_cols);
+  act_grad_shape.data[0] = static_cast<size_t>(act_grad.size(0));
+  act_grad_shape.data[1] = static_cast<size_t>(num_cols);
+  input_fwd_shape.data[0] = static_cast<size_t>(input_fwd.size(0));
+  input_fwd_shape.data[1] = static_cast<size_t>(num_cols);
+  auto input_bwd_cu = makeTransformerEngineTensor(input_bwd.data_ptr(), input_bwd_shape, dtype);
+  auto act_grad_cu = makeTransformerEngineTensor(act_grad.data_ptr(), act_grad_shape, dtype);
+  auto input_fwd_cu = makeTransformerEngineTensor(input_fwd.data_ptr(), input_fwd_shape, dtype);
   auto row_id_map_cu = makeTransformerEngineTensor(row_id_map);
   auto prob_cu = makeTransformerEngineTensor(prob);
   auto prob_grad_cu = makeTransformerEngineTensor(prob_grad);
diff --git a/transformer_engine/pytorch/csrc/extensions/transpose.cpp b/transformer_engine/pytorch/csrc/extensions/transpose.cpp
index 7dfdf99547..6c1ff313c5 100644
--- a/transformer_engine/pytorch/csrc/extensions/transpose.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/transpose.cpp
@@ -19,7 +19,8 @@ at::Tensor fp8_transpose(at::Tensor input, DType otype, std::optional<at::Tenso
   std::vector<int64_t> transpose_shape_int64;
   if (shape.size() > 0) {
     transpose_shape_int64.push_back(shape.back());
@@ -44,9 +45,16 @@ at::Tensor fp8_transpose(at::Tensor input, DType otype, std::optional<at::Tenso
+  NVTEShape input_shape, output_shape;
+  input_shape.ndim = 2;
+  output_shape.ndim = 2;
+  input_shape.data[0] = M;
+  input_shape.data[1] = N;
+  output_shape.data[0] = N;
+  output_shape.data[1] = M;
 
-  auto input_cu = makeTransformerEngineTensor(input.data_ptr(), std::vector<size_t>{M, N}, otype);
-  auto output_cu = makeTransformerEngineTensor(out.data_ptr(), std::vector<size_t>{N, M}, otype);
+  auto input_cu = makeTransformerEngineTensor(input.data_ptr(), input_shape, otype);
+  auto output_cu = makeTransformerEngineTensor(out.data_ptr(), output_shape, otype);
 
   nvte_transpose(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream());
 
   return out;
@@ -60,7 +68,7 @@ at::Tensor swap_first_dims(at::Tensor tensor, std::optional<at::Tensor> out) {
   // Allocate output tensor if needed
   if (!out) {
-    auto in_shape = getTensorShape(input);
+    const auto in_shape = getTensorShapeVector(input);
     NVTE_CHECK(in_shape.size() >= 2, "Invalid input tensor dimensions (shape=", in_shape, ")");
     std::vector<int64_t> out_shape_int64(in_shape.begin(), in_shape.end());
     out_shape_int64[0] = static_cast<int64_t>(in_shape[1]);
diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp
index fd748d1b21..1520541ade 100644
--- a/transformer_engine/pytorch/csrc/quantizer.cpp
+++ b/transformer_engine/pytorch/csrc/quantizer.cpp
@@ -77,6 +77,16 @@ std::pair<TensorWrapper, py::object> NoneQuantizer::create_tensor(const std::vec
   return create_tensor(shape, dtype, at::empty(shape_int64, opts));
 }
 
+std::pair<TensorWrapper, py::object> NoneQuantizer::create_tensor(const NVTEShape& shape,
+                                                                  DType dtype) const {
+  std::vector<int64_t> shape_int64;
+  for (size_t i = 0; i < shape.ndim; ++i) {
+    shape_int64.push_back(static_cast<int64_t>(shape.data[i]));
+  }
+  const auto opts = at::TensorOptions().dtype(GetATenDType(dtype)).device(torch::kCUDA);
+  return create_tensor(shape, dtype, at::empty(shape_int64, opts));
+}
+
 std::pair<TensorWrapper, py::object> NoneQuantizer::create_tensor(const std::vector<size_t>& shape,
                                                                   DType dtype,
                                                                   at::Tensor data) const {
@@ -86,6 +96,15 @@ std::pair<TensorWrapper, py::object> NoneQuantizer::create_tensor(const std::vec
   return {std::move(out_cpp), py::cast(data)};
 }
 
+std::pair<TensorWrapper, py::object> NoneQuantizer::create_tensor(const NVTEShape& shape,
+                                                                  DType dtype,
+                                                                  at::Tensor data) const {
+  TensorWrapper out_cpp;
+  out_cpp.set_rowwise_data(data.data_ptr(), dtype, shape);
+  set_quantization_params(&out_cpp);
+  return {std::move(out_cpp), py::cast(data)};
+}
+
 std::pair<TensorWrapper, py::object> NoneQuantizer::convert_and_update_tensor(
     py::object tensor) const {
   auto tensor_pyt = tensor.cast<at::Tensor>();
@@ -116,14 +135,22 @@ std::pair<TensorWrapper, py::object> Float8Quantizer::create_tensor(
   at::Tensor scale_inv = at::empty(std::vector<int64_t>{1}, opts);
   return create_tensor(shape, dtype, std::nullopt, std::nullopt, std::move(scale_inv));
 }
+std::pair<TensorWrapper, py::object> Float8Quantizer::create_tensor(const NVTEShape& shape,
+                                                                    DType dtype) const {
+  std::vector<size_t> shape_vec;
+  for (size_t i = 0; i < shape.ndim; ++i) {
+    shape_vec.push_back(shape.data[i]);
+  }
+  return create_tensor(shape_vec, dtype);
+}
 
 std::pair<TensorWrapper, py::object> Float8Quantizer::create_tensor(
     const std::vector<size_t>& shape, DType dtype, std::optional<at::Tensor> data,
     std::optional<at::Tensor> transpose, std::optional<at::Tensor> scale_inv) const {
   using namespace pybind11::literals;
-
+  int is_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported();
   // Initialize data tensor
-  const bool with_data = rowwise_usage || nvte_is_non_tn_fp8_gemm_supported();
+  const bool with_data = rowwise_usage || is_non_tn_fp8_gemm_supported;
   if (with_data && !data) {
     const std::vector<int64_t> shape_int64(shape.begin(), shape.end());
     const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
@@ -134,7 +161,7 @@ std::pair<TensorWrapper, py::object> Float8Quantizer::create_tensor(
   py::object data_py = with_data ? py::cast(*data) : py::none();
 
   // Initialize transpose tensor
-  const bool with_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported();
+  const bool with_transpose = columnwise_usage && !is_non_tn_fp8_gemm_supported;
   if (with_transpose && !transpose) {
     const auto transpose_shape = make_transpose_shape(shape);
     const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
@@ -187,8 +214,9 @@ std::pair<TensorWrapper, py::object> Float8Quantizer::convert_and_update_tensor(
   NVTE_CHECK(detail::IsFloat8Tensor(tensor.ptr()), "Float8Quantizer must output to Float8Tensor.");
 
   // Expected buffers
-  const bool need_data = rowwise_usage || nvte_is_non_tn_fp8_gemm_supported();
-  const bool need_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported();
+  int is_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported();
+  const bool need_data = rowwise_usage || is_non_tn_fp8_gemm_supported;
+  const bool need_transpose = columnwise_usage && !is_non_tn_fp8_gemm_supported;
   NVTE_CHECK(need_data || need_transpose, "Invalid usages for Float8Quantizer.");
 
   // Extract buffers from Python tensor
@@ -209,7 +237,7 @@ std::pair<TensorWrapper, py::object> Float8Quantizer::convert_and_update_tensor(
   // Tensor dimensions
   std::vector<size_t> shape;
   if (has_transpose) {
-    const auto transpose_shape = getTensorShape(*transpose_tensor);
+    const auto transpose_shape = getTensorShapeVector(*transpose_tensor);
     if (transpose_shape.size() > 0) {
       for (size_t i = 1; i < transpose_shape.size(); ++i) {
         shape.push_back(transpose_shape[i]);
@@ -217,12 +245,12 @@ std::pair<TensorWrapper, py::object> Float8Quantizer::convert_and_update_tensor(
       shape.push_back(transpose_shape.front());
     }
     if (has_data) {
-      auto expected_shape = getTensorShape(*data_tensor);
+      const auto expected_shape = getTensorShapeVector(*data_tensor);
       NVTE_CHECK(shape == expected_shape, "FP8 data (shape=", expected_shape,
                  ") and transpose (shape=", transpose_shape, ") do not match");
     }
   } else {  // Already checked has_data == true
-    shape = getTensorShape(*data_tensor);
+    shape = getTensorShapeVector(*data_tensor);
   }
 
   // Coerce data tensor
@@ -321,14 +349,22 @@ void Float8CurrentScalingQuantizer::set_quantization_params(TensorWrapper* tenso
   tensor->set_amax(amax.data_ptr(), GetTransformerEngineDType(amax.scalar_type()),
                    getTensorShape(amax));
 }
-
+std::pair<TensorWrapper, py::object> Float8CurrentScalingQuantizer::create_tensor(
+    const NVTEShape& shape, DType dtype) const {
+  std::vector<size_t> shape_vec;
+  for (size_t i = 0; i < shape.ndim; ++i) {
+    shape_vec.push_back(shape.data[i]);
+  }
+  return create_tensor(shape_vec, dtype);
+}
 std::pair<TensorWrapper, py::object> Float8CurrentScalingQuantizer::create_tensor(
     const std::vector<size_t>& shape, DType dtype) const {
   using namespace pybind11::literals;
 
   // Initialize data tensor
   at::Tensor data_tensor;
-  const bool with_data = rowwise_usage || nvte_is_non_tn_fp8_gemm_supported();
+  int is_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported();
+  const bool with_data = rowwise_usage || is_non_tn_fp8_gemm_supported;
   if (with_data) {
     const std::vector<int64_t> shape_int64(shape.begin(), shape.end());
     const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
@@ -337,7 +373,7 @@ std::pair<TensorWrapper, py::object> Float8CurrentScalingQuantizer::create_tenso
 
   // Initialize transpose tensor
   at::Tensor transpose_tensor;
-  const bool with_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported();
+  const bool with_transpose = columnwise_usage && !is_non_tn_fp8_gemm_supported;
   if (with_transpose) {
     const auto transpose_shape = make_transpose_shape(shape);
     const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
@@ -408,8 +444,9 @@ std::pair<TensorWrapper, py::object> Float8CurrentScalingQuantizer::convert_and_
              "Float8CurrentScalingQuantizer must output to Float8Tensor.");
 
   // Expected buffers
-  const bool need_data = rowwise_usage || nvte_is_non_tn_fp8_gemm_supported();
-  const bool need_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported();
+  int is_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported();
+  const bool need_data = rowwise_usage || is_non_tn_fp8_gemm_supported;
+  const bool need_transpose = columnwise_usage && !is_non_tn_fp8_gemm_supported;
   NVTE_CHECK(need_data || need_transpose, "Invalid quantizer usages.");
 
   // Extract buffers from Python tensor
@@ -430,7 +467,7 @@ std::pair<TensorWrapper, py::object> Float8CurrentScalingQuantizer::convert_and_
   // Tensor dimensions
   std::vector<size_t> shape;
   if (has_transpose) {
-    const auto transpose_shape = getTensorShape(*transpose_tensor);
+    const auto transpose_shape = getTensorShapeVector(*transpose_tensor);
     if (transpose_shape.size() > 0) {
       for (size_t i = 1; i < transpose_shape.size(); ++i) {
         shape.push_back(transpose_shape[i]);
@@ -438,12 +475,12 @@ std::pair<TensorWrapper, py::object> Float8CurrentScalingQuantizer::convert_and_
       shape.push_back(transpose_shape.front());
     }
     if (has_data) {
-      auto expected_shape = getTensorShape(*data_tensor);
+      const auto expected_shape = getTensorShapeVector(*data_tensor);
       NVTE_CHECK(shape == expected_shape, "FP8 data (shape=", expected_shape,
                  ") and transpose (shape=", transpose_shape, ") do not match");
     }
   } else {  // Already checked has_data == true
-    shape = getTensorShape(*data_tensor);
+    shape = getTensorShapeVector(*data_tensor);
   }
 
   // Coerce data tensor in Python tensor
@@ -560,6 +597,15 @@ Float8BlockQuantizer::Float8BlockQuantizer(const py::handle& quantizer) : Quanti
 
 void Float8BlockQuantizer::set_quantization_params(TensorWrapper* tensor) const {}
 
+std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(const NVTEShape& shape,
+                                                                         DType dtype) const {
+  std::vector<size_t> shape_vec;
+  for (size_t i = 0; i < shape.ndim; ++i) {
+    shape_vec.push_back(shape.data[i]);
+  }
+  return create_tensor(shape_vec, dtype);
+}
+
 std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
     const std::vector<size_t>& shape, DType dtype) const {
   using namespace pybind11::literals;
@@ -680,9 +726,9 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::convert_and_update_te
       return std::vector<size_t>();
     }
     if (all_gather_usage) {
-      return getTensorShape(*columnwise_data);
+      return getTensorShapeVector(*columnwise_data);
     }
-    std::vector<size_t> shape = getTensorShape(*columnwise_data);
+    std::vector<size_t> shape = getTensorShapeVector(*columnwise_data);
     std::vector<size_t> shape_transposed(shape.size());
     for (size_t i = 0; i + 1 < shape.size(); ++i) {
       shape_transposed[i] = shape[i + 1];
@@ -694,7 +740,7 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::convert_and_update_te
   };
   std::vector<size_t> shape;
   if (rowwise_data) {
-    shape = getTensorShape(*rowwise_data);
+    shape = getTensorShapeVector(*rowwise_data);
     if (columnwise_data) {
       auto expected_shape = get_columnwise_shape(all_gather_usage);
       NVTE_CHECK(shape == expected_shape, "BlockwiseFP8 row-wise data (shape=", shape,
@@ -823,12 +869,24 @@ void Float8BlockQuantizer::quantize(const TensorWrapper& input, TensorWrapper& o
 
 std::vector<size_t> Float8BlockQuantizer::get_scale_shape(const std::vector<size_t>& shape,
                                                           bool columnwise) const {
+  return get_scale_shape_impl(shape, columnwise);
+}
+
+NVTEShapeWrapper Float8BlockQuantizer::get_scale_shape(const NVTEShapeWrapper& shape,
+                                                       bool columnwise) const {
+  return get_scale_shape_impl(shape, columnwise);
+}
+
+template <typename ShapeT>
+ShapeT Float8BlockQuantizer::get_scale_shape_impl(const ShapeT& shape, bool columnwise) const {
   size_t numel = 1;
+  size_t k_dim;
+
   for (auto s : shape) {
     numel *= s;
   }
+  k_dim = shape.size() == 0 ? 1u : shape.back();
 
-  size_t k_dim = shape.size() == 0 ? 1u : shape.back();
   size_t m_dim = numel / k_dim;
 
   constexpr size_t kBlockLen = 128;
@@ -836,27 +894,20 @@ std::vector<size_t> Float8BlockQuantizer::get_scale_shape(const std::vector<size
-  std::vector<size_t> scale_shape;
-
+  size_t sinv0 = 0;
+  size_t sinv1 = 0;
   bool rowwise_usage = !columnwise;
 
   if (rowwise_usage) {
-    // rowwise scaling factor shape
-    size_t sinv0 = 0;
-    size_t sinv1 = 0;
     if (block_scaling_dim == 2) {
-      // 2D scaling is always GEMM_READY for now
       NVTE_CHECK(data_format == Float8BlockScaleTensorFormat::GEMM_READY,
                  "2D scaling is always GEMM_READY for now.");
       sinv0 = (m_dim + kBlockLen - 1) / kBlockLen;
       sinv1 = roundup((k_dim + kBlockLen - 1) / kBlockLen, 4);
     } else if (block_scaling_dim == 1) {
-      // 1D scaling can be GEMM_READY or COMPACT
       bool rowwise_compact = data_format == Float8BlockScaleTensorFormat::COMPACT;
-      // default rowwise scaling factor shape already transpose the scaling factor so it's GEMM_READY
      sinv0 = (k_dim + kBlockLen - 1) / kBlockLen;
      sinv1 = rowwise_compact ? m_dim : roundup(m_dim, 4);
-      // if the rowwise format is compact, the scaling factor is not be transposed
       if (rowwise_compact) {
         std::swap(sinv0, sinv1);
       }
@@ -866,13 +917,8 @@ std::vector<size_t> Float8BlockQuantizer::get_scale_shape(const std::vector<size
 
+std::pair<TensorWrapper, py::object> MXFP8Quantizer::create_tensor(const NVTEShape& shape,
+                                                                   DType dtype) const {
+  std::vector<size_t> shape_vec;
+  for (size_t i = 0; i < shape.ndim; ++i) {
+    shape_vec.push_back(shape.data[i]);
+  }
+  return create_tensor(shape_vec, dtype);
+}
+
 std::pair<TensorWrapper, py::object> MXFP8Quantizer::create_tensor(const std::vector<size_t>& shape,
                                                                    DType dtype) const {
   using namespace pybind11::literals;
@@ -1004,14 +1062,14 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::convert_and_update_tensor(
   // Tensor dimensions
   std::vector<size_t> shape;
   if (columnwise_data) {
-    shape = getTensorShape(*columnwise_data);
+    shape = getTensorShapeVector(*columnwise_data);
     if (rowwise_data) {
-      auto expected_shape = getTensorShape(*rowwise_data);
+      const auto expected_shape = getTensorShapeVector(*rowwise_data);
       NVTE_CHECK(shape == expected_shape, "MXFP8 row-wise data (shape=", expected_shape,
                  ") and column-wise data (shape=", shape, ") do not match");
     }
   } else {  // Already checked columnwise_data_tensor == true
-    shape = getTensorShape(*rowwise_data);
+    shape = getTensorShapeVector(*rowwise_data);
   }
 
   // Coerce row-wise data
@@ -1104,33 +1162,45 @@ void MXFP8Quantizer::quantize(const TensorWrapper& input, TensorWrapper& out,
 
 std::vector<size_t> MXFP8Quantizer::get_scale_shape(const std::vector<size_t>& shape,
                                                     bool columnwise) const {
+  return get_scale_shape_impl(shape, columnwise);
+}
+
+NVTEShapeWrapper MXFP8Quantizer::get_scale_shape(const NVTEShapeWrapper& shape,
+                                                 bool columnwise) const {
+  return get_scale_shape_impl(shape, columnwise);
+}
+
+template <typename ShapeT>
+ShapeT MXFP8Quantizer::get_scale_shape_impl(const ShapeT& shape, bool columnwise) const {
   size_t numel = 1;
+  size_t last_dim;
+
   for (auto s : shape) {
     numel *= s;
   }
-
-  auto last_dim = shape.back();
+  last_dim = shape.back();
   NVTE_CHECK(last_dim % MXFP8_BLOCK_SIZE == 0 && (numel / last_dim) % MXFP8_BLOCK_SIZE == 0,
              "MXFP8 requires tensor dims that are divisible by ", MXFP8_BLOCK_SIZE,
             " (got shape=", shape, ")");
 
-  std::vector<size_t> scale_shape;
-
+  size_t sinv0 = 0;
+  size_t sinv1 = 0;
   bool rowwise_usage = !columnwise;
 
   if (rowwise_usage) {
-    // rowwise scaling factor shape
-    size_t sinv0 = roundup(numel / last_dim, 128);
-    size_t sinv1 = roundup(last_dim / MXFP8_BLOCK_SIZE, 4);
-    scale_shape = {sinv0, sinv1};
+    sinv0 = roundup(numel / last_dim, 128);
+    sinv1 = roundup(last_dim / MXFP8_BLOCK_SIZE, 4);
   } else {
-    // columnwise scaling factor shape
-    size_t sinv0 = roundup(numel / (last_dim * MXFP8_BLOCK_SIZE), 4);
-    size_t sinv1 = roundup(last_dim, 128);
-    scale_shape = {sinv0, sinv1};
+    sinv0 = roundup(numel / (last_dim * MXFP8_BLOCK_SIZE), 4);
+    sinv1 = roundup(last_dim, 128);
   }
-  return scale_shape;
+
+  ShapeT result;
+  result.resize(2);
+  result[0] = sinv0;
+  result[1] = sinv1;
+  return result;
 }
 
 NVFP4Quantizer::NVFP4Quantizer(const py::handle& quantizer) : Quantizer(quantizer) {
@@ -1168,7 +1238,14 @@ void NVFP4Quantizer::set_quantization_params(TensorWrapper* tensor) const {
   tensor->set_columnwise_data(columnwise_data.data_ptr, static_cast<DType>(columnwise_data.dtype),
                               columnwise_data.shape);
 }
-
+std::pair<TensorWrapper, py::object> NVFP4Quantizer::create_tensor(const NVTEShape& shape,
+                                                                   DType dtype) const {
+  std::vector<size_t> shape_vec;
+  for (size_t i = 0; i < shape.ndim; ++i) {
+    shape_vec.push_back(shape.data[i]);
+  }
+  return create_tensor(shape_vec, dtype);
+}
 std::pair<TensorWrapper, py::object> NVFP4Quantizer::create_tensor(const std::vector<size_t>& shape,
                                                                    DType dtype) const {
   using namespace pybind11::literals;
@@ -1320,14 +1397,14 @@ std::pair<TensorWrapper, py::object> NVFP4Quantizer::convert_and_update_tensor(
   // Tensor dimensions, shape means original shape
   std::vector<size_t> shape;
   if (columnwise_data) {
-    shape = convert_shape_back_from_fp4(getTensorShape(*columnwise_data), true);
+    shape = convert_shape_back_from_fp4(getTensorShapeVector(*columnwise_data), true);
     if (rowwise_data) {
-      auto expected_shape = convert_shape_back_from_fp4(getTensorShape(*rowwise_data), false);
+      auto expected_shape = convert_shape_back_from_fp4(getTensorShapeVector(*rowwise_data), false);
       NVTE_CHECK(shape == expected_shape, "NVFP4 row-wise data (shape=", expected_shape,
                  ") and column-wise data (shape=", shape, ") do not match");
     }
   } else {  // Already checked columnwise_data_tensor == true
-    shape = convert_shape_back_from_fp4(getTensorShape(*rowwise_data), false);
+    shape = convert_shape_back_from_fp4(getTensorShapeVector(*rowwise_data), false);
   }
 
   size_t flat_first_dim = 1;
@@ -1704,12 +1781,24 @@ void NVFP4Quantizer::quantize_with_amax(TensorWrapper& input, TensorWrapper& out
 
 std::vector<size_t> NVFP4Quantizer::get_scale_shape(const std::vector<size_t>& shape,
                                                     bool columnwise) const {
+  return get_scale_shape_impl(shape, columnwise);
+}
+
+NVTEShapeWrapper NVFP4Quantizer::get_scale_shape(const NVTEShapeWrapper& shape,
+                                                 bool columnwise) const {
+  return get_scale_shape_impl(shape, columnwise);
+}
+
+template <typename ShapeT>
+ShapeT NVFP4Quantizer::get_scale_shape_impl(const ShapeT& shape, bool columnwise) const {
   size_t numel = 1;
+  size_t last_dim;
+
   for (auto s : shape) {
     numel *= s;
   }
+  last_dim = shape.back();
 
-  auto last_dim = shape.back();
   auto flat_first_dim = numel / last_dim;
 
   NVTE_CHECK(last_dim % NVFP4_BLOCK_SIZE == 0, "Last dim for NVFP4 must be divisible by ",
@@ -1718,22 +1807,23 @@ std::vector<size_t> NVFP4Quantizer::get_scale_shape(const std::vector<size_t>& s
              "NVFP4 requires tensor dims that are divisible by ", NVFP4_BLOCK_SIZE,
             " (got shape=", shape, ")");
 
-  std::vector<size_t> scale_shape;
-
+  size_t sinv0 = 0;
+  size_t sinv1 = 0;
   bool rowwise_usage = !columnwise;
 
   if (rowwise_usage) {
-    // rowwise scaling factor shape
-    size_t sinv0 = roundup(flat_first_dim, 128);
-    size_t sinv1 = roundup(last_dim / NVFP4_BLOCK_SIZE, 4);
-    scale_shape = {sinv0, sinv1};
+    sinv0 = roundup(flat_first_dim, 128);
+    sinv1 = roundup(last_dim / NVFP4_BLOCK_SIZE, 4);
   } else {
-    // columnwise scaling factor shape
-    size_t sinv0 = roundup(last_dim, 128);
-    size_t sinv1 = roundup(flat_first_dim / NVFP4_BLOCK_SIZE, 4);
-    scale_shape = {sinv0, sinv1};
+    sinv0 = roundup(last_dim, 128);
+    sinv1 = roundup(flat_first_dim / NVFP4_BLOCK_SIZE, 4);
   }
-  return scale_shape;
+
+  ShapeT result;
+  result.resize(2);
+  result[0] = sinv0;
+  result[1] = sinv1;
+  return result;
 }
 
 }  // namespace transformer_engine::pytorch
diff --git a/transformer_engine/pytorch/csrc/type_converters.cpp b/transformer_engine/pytorch/csrc/type_converters.cpp
index 368e9dcdfa..48e9f06cc4 100644
--- a/transformer_engine/pytorch/csrc/type_converters.cpp
+++ b/transformer_engine/pytorch/csrc/type_converters.cpp
@@ -132,7 +132,7 @@ TensorWrapper NVTETensorFromNVFP4Tensor(py::handle tensor, Quantizer *quantizer)
   const auto &scale_inv = tensor.attr("_rowwise_scale_inv").cast<at::Tensor>();
   const auto &amax_rowwise = tensor.attr("_amax_rowwise").cast<at::Tensor>();
   ret.set_rowwise_data(data.data_ptr(), dtype,
-                       convert_shape_back_from_fp4(getTensorShape(data), false));
+                       convert_shape_back_from_fp4(getTensorShapeVector(data), false));
   ret.set_rowwise_scale_inv(scale_inv.data_ptr(), DType::kFloat8E4M3, getTensorShape(scale_inv));
   ret.set_amax(amax_rowwise.data_ptr(), DType::kFloat32, getTensorShape(amax_rowwise));
 }
@@ -143,7 +143,7 @@ TensorWrapper NVTETensorFromNVFP4Tensor(py::handle tensor, Quantizer *quantizer)
   const auto &scale_inv = tensor.attr("_columnwise_scale_inv").cast<at::Tensor>();
   const auto &amax_columnwise = tensor.attr("_amax_columnwise").cast<at::Tensor>();
   ret.set_columnwise_data(data.data_ptr(), DType::kFloat4E2M1,
-                          convert_shape_back_from_fp4(getTensorShape(data), false));
+                          convert_shape_back_from_fp4(getTensorShapeVector(data), false));
   ret.set_columnwise_scale_inv(scale_inv.data_ptr(), DType::kFloat8E4M3,
                                getTensorShape(scale_inv));
   ret.set_columnwise_amax(amax_columnwise.data_ptr(), DType::kFloat32,
diff --git a/transformer_engine/pytorch/csrc/util.cpp b/transformer_engine/pytorch/csrc/util.cpp
index ce547d302e..be4c34b75a 100644
--- a/transformer_engine/pytorch/csrc/util.cpp
+++ b/transformer_engine/pytorch/csrc/util.cpp
@@ -15,7 +15,7 @@ std::optional<at::Tensor> swizzle_scaling_factors(transformer_engine::TensorWrap
 
   if (input.scaling_mode() == NVTE_INVALID_SCALING) {
     NVTE_ERROR("Invalid scaling mode for swizzle.");
-  } else if (input.scaling_mode() != NVTE_MXFP8_1D_SCALING &&
+  } else if (input.scaling_mode() != NVTE_MXFP8_1D_SCALING ||
              input.scaling_mode() != NVTE_NVFP4_1D_SCALING) {
     return std::nullopt;
   }
@@ -59,24 +59,24 @@ std::optional<at::Tensor> swizzle_scaling_factors(transformer_engine::TensorWrap
       (nvfp4) ? transformer_engine::DType::kFloat8E4M3 : transformer_engine::DType::kFloat8E8M0;
 
   if (rowwise) {
-    input_cu.set_rowwise_data(input.dptr(), input_dtype, input_shape);
-    input_cu.set_rowwise_scale_inv(scale_inv_dptr, scale_inv_dtype, scale_inv_shape);
-    output_cu.set_rowwise_data(input.dptr(), input_dtype, input_shape);
-    output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, scale_inv_shape);
+    input_cu.set_rowwise_data(input.dptr(), input_dtype, nvte_input_shape);
+    input_cu.set_rowwise_scale_inv(scale_inv_dptr, scale_inv_dtype, scale_inv.shape);
+    output_cu.set_rowwise_data(input.dptr(), input_dtype, nvte_input_shape);
+    output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, scale_inv.shape);
   } else {
-    input_cu.set_columnwise_data(input.columnwise_dptr(), input_dtype, input_shape);
-    input_cu.set_columnwise_scale_inv(scale_inv_dptr, scale_inv_dtype, scale_inv_shape);
-    output_cu.set_columnwise_data(input.columnwise_dptr(), input_dtype, input_shape);
-    output_cu.set_columnwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, scale_inv_shape);
+    input_cu.set_columnwise_data(input.columnwise_dptr(), input_dtype, nvte_input_shape);
+    input_cu.set_columnwise_scale_inv(scale_inv_dptr, scale_inv_dtype, scale_inv.shape);
+    output_cu.set_columnwise_data(input.columnwise_dptr(), input_dtype, nvte_input_shape);
+    output_cu.set_columnwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, scale_inv.shape);
   }
 
   // Launch kernel
   nvte_swizzle_scaling_factors(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream());
 
   if (rowwise) {
-    input.set_rowwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, scale_inv_shape);
+    input.set_rowwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, scale_inv.shape);
   } else {
-    input.set_columnwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, scale_inv_shape);
+    input.set_columnwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, scale_inv.shape);
   }
 
   return swizzled_scale_inv;
diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py
index b65f7005eb..7e93f2d086 100644
--- a/transformer_engine/pytorch/module/linear.py
+++ b/transformer_engine/pytorch/module/linear.py
@@ -1343,7 +1343,7 @@ def reset_parameters(self, defer_init=False):
                 elif self.parallel_mode == "column":
                     set_tensor_model_parallel_attributes(getattr(self, bias), True, 0, 1)
 
-    @no_torch_dynamo()
+    @no_torch_dynamo(recursive=False)
     def forward(
         self,
         inp: torch.Tensor,
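Editorial addendum (not part of the patch): a minimal C++ sketch of the NVTEShape construction pattern that the hunks above repeat at each call site. It assumes the Transformer Engine PyTorch extension headers are available and uses only the surface visible in this diff (an NVTEShape with ndim/data members, transformer_engine::DType, and the makeTransformerEngineTensor(void*, const NVTEShape&, DType) overload); the helper names make_shape_2d and wrap_2d, and the "common.h" include, are illustrative assumptions rather than library API.

// Sketch only; anything not shown in the diff above is an assumption.
#include <cstddef>

#include <ATen/ATen.h>

#include "common.h"  // assumed: declares NVTEShape, DType, TensorWrapper, makeTransformerEngineTensor

namespace {

// Build a 2-D NVTEShape directly, mirroring the pattern the hunks above use
// in place of a temporary std::vector<size_t>.
NVTEShape make_shape_2d(size_t rows, size_t cols) {
  NVTEShape shape;
  shape.ndim = 2;
  shape.data[0] = rows;
  shape.data[1] = cols;
  return shape;
}

// Wrap a 2-D torch tensor the way the permutation/padding call sites do after this change.
transformer_engine::TensorWrapper wrap_2d(at::Tensor t, transformer_engine::DType dtype) {
  const NVTEShape shape =
      make_shape_2d(static_cast<size_t>(t.size(0)), static_cast<size_t>(t.size(1)));
  return makeTransformerEngineTensor(t.data_ptr(), shape, dtype);
}

}  // namespace

Filling the fixed-capacity NVTEShape in place avoids allocating a temporary std::vector<size_t> for every wrapped tensor, which appears to be the motivation behind the mechanical shape-argument changes in this diff.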