10 changes: 6 additions & 4 deletions transformer_engine/common/gemm/cublaslt_gemm.cu
@@ -129,7 +129,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
ret.Atype = A.data.dtype;
ret.A_scale_inv = A.scale_inv.dptr;
ret.lda = is_A_transposed ? k : m;
if (!nvte_is_non_tn_fp8_gemm_supported() && !is_A_transposed) {
int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported();
if (!is_nvte_non_tn_fp8_gemm_supported && !is_A_transposed) {
// Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data.
if (A.has_columnwise_data() && is_fp8_dtype(A.columnwise_data.dtype)) {
ret.A = A.columnwise_data.dptr;
@@ -140,7 +141,7 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
} else {
NVTE_CHECK(!is_fp8_dtype(ret.Atype), "Input A is missing column-wise usage");
}
} else if (nvte_is_non_tn_fp8_gemm_supported() && !A.has_data()) {
} else if (is_nvte_non_tn_fp8_gemm_supported && !A.has_data()) {
// Blackwell supports any GEMM layout for FP8, so we can use column-wise/transposed
// data with the mirrored transpose-flag if we don't have row-wise data.
NVTE_CHECK(A.has_columnwise_data() && is_fp8_dtype(A.columnwise_data.dtype),
@@ -220,7 +221,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
ret.Btype = B.data.dtype;
ret.B_scale_inv = B.scale_inv.dptr;
ret.ldb = is_B_transposed ? n : k;
if (!nvte_is_non_tn_fp8_gemm_supported() && is_B_transposed) {
int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported();
if (!is_nvte_non_tn_fp8_gemm_supported && is_B_transposed) {
// Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data.
if (B.has_columnwise_data() && is_fp8_dtype(B.columnwise_data.dtype)) {
ret.B = B.columnwise_data.dptr;
@@ -231,7 +233,7 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
} else {
NVTE_CHECK(!is_fp8_dtype(ret.Btype), "Input B is missing column-wise usage");
}
} else if (nvte_is_non_tn_fp8_gemm_supported() && !B.has_data()) {
} else if (is_nvte_non_tn_fp8_gemm_supported && !B.has_data()) {
// Blackwell supports any GEMM layout for FP8, so we can use column-wise/transposed
// data with the mirrored transpose-flag if we don't have row-wise data.
NVTE_CHECK(B.has_columnwise_data() && is_fp8_dtype(B.columnwise_data.dtype),
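Context for the change above: both branches now test one cached result of nvte_is_non_tn_fp8_gemm_supported() instead of querying it twice per operand. Below is a minimal sketch of the Hopper fallback these branches implement, with an illustrative Operand struct and function name (not the library code), assuming FP8 GEMMs on pre-Blackwell parts must be TN:

struct Operand {
  void* rowwise = nullptr;     // data in its stored layout
  void* columnwise = nullptr;  // pre-transposed copy, if the framework kept one
};

// If non-TN FP8 GEMMs are unsupported (Hopper) and A was requested non-transposed,
// substitute the column-wise (already transposed) copy and flip the transpose flag.
// non_tn_fp8_supported is the cached result of nvte_is_non_tn_fp8_gemm_supported().
void canonicalize_A(Operand& A, bool& is_A_transposed, int non_tn_fp8_supported) {
  if (!non_tn_fp8_supported && !is_A_transposed && A.columnwise != nullptr) {
    A.rowwise = A.columnwise;
    is_A_transposed = true;
  }
}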
2 changes: 1 addition & 1 deletion transformer_engine/common/transformer_engine.cpp
@@ -961,7 +961,7 @@ void nvte_destroy_quantization_config(NVTEQuantizationConfig config) {
}

int nvte_is_non_tn_fp8_gemm_supported() {
int num_devices = transformer_engine::cuda::num_devices();
static int num_devices = transformer_engine::cuda::num_devices();
static std::vector<int> cache(num_devices, -1);
static std::vector<std::once_flag> flags(num_devices);
int device_id = transformer_engine::cuda::current_device();
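Making num_devices static here means the CUDA device count is queried only once, consistent with the static cache and flags vectors it sizes. A self-contained sketch of the once-per-device memoization pattern used by this function (illustrative only; query_capability() is a placeholder for the real device-attribute check):

#include <mutex>
#include <vector>

namespace {
int query_capability(int /*device_id*/) { return 1; }  // placeholder for the real check
}  // namespace

int is_feature_supported_cached(int device_id) {
  static const int num_devices = 16;  // the real code asks the CUDA runtime once
  static std::vector<int> cache(num_devices, -1);
  static std::vector<std::once_flag> flags(num_devices);
  std::call_once(flags[device_id],
                 [&] { cache[device_id] = query_capability(device_id); });
  return cache[device_id];
}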
67 changes: 62 additions & 5 deletions transformer_engine/pytorch/csrc/common.cpp
@@ -26,7 +26,9 @@ std::vector<size_t> convert_shape_back_from_fp4(const std::vector<size_t>& shape
return ret;
}

std::vector<size_t> getTensorShape(const at::Tensor& t) {
NVTEShape getTensorShape(const at::Tensor& t) { return convertTorchShape(t.sizes()); }

std::vector<size_t> getTensorShapeVector(const at::Tensor& t) {
std::vector<size_t> shape;
for (auto s : t.sizes()) {
shape.push_back(s);
@@ -119,10 +121,7 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor(

transformer_engine::TensorWrapper makeTransformerEngineTensor(at::Tensor tensor) {
transformer_engine::DType dtype = GetTransformerEngineDType(tensor.scalar_type());
std::vector<size_t> shape;
for (auto s : tensor.sizes()) {
shape.push_back(s);
}
NVTEShape shape = getTensorShape(tensor);
return makeTransformerEngineTensor(tensor.data_ptr(), shape, dtype);
}

@@ -178,6 +177,41 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor(
return ret;
}

transformer_engine::TensorWrapper makeTransformerEngineTensor(
void* data_ptr, const NVTEShape& shape, const transformer_engine::DType type, void* amax_ptr,
void* scale_ptr, void* scale_inv_ptr, const NVTEShape& scale_inv_shape,
NVTEScalingMode scaling_mode) {
TensorWrapper ret(scaling_mode);
ret.set_rowwise_data(data_ptr, type, shape);
const size_t meta_shape_data[1] = {1};
NVTEShape meta_shape;
meta_shape.ndim = 1;
meta_shape.data[0] = 1;
ret.set_amax(amax_ptr, DType::kFloat32, meta_shape);
ret.set_scale(scale_ptr, DType::kFloat32, meta_shape);
auto scale_inv_dtype =
(scaling_mode == NVTE_MXFP8_1D_SCALING) ? DType::kFloat8E8M0 : DType::kFloat32;
ret.set_rowwise_scale_inv(scale_inv_ptr, scale_inv_dtype, scale_inv_shape);
return ret;
}

transformer_engine::TensorWrapper makeTransformerEngineTensor(
void* data_ptr, const std::vector<size_t>& shape, const transformer_engine::DType type,
void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, const NVTEShape& scale_inv_shape,
NVTEScalingMode scaling_mode) {
TensorWrapper ret(scaling_mode);
ret.set_rowwise_data(data_ptr, type, shape);
NVTEShape meta_shape;
meta_shape.ndim = 1;
meta_shape.data[0] = 1;
ret.set_amax(amax_ptr, DType::kFloat32, meta_shape);
ret.set_scale(scale_ptr, DType::kFloat32, meta_shape);
auto scale_inv_dtype =
(scaling_mode == NVTE_MXFP8_1D_SCALING) ? DType::kFloat8E8M0 : DType::kFloat32;
ret.set_rowwise_scale_inv(scale_inv_ptr, scale_inv_dtype, scale_inv_shape);
return ret;
}

transformer_engine::TensorWrapper makeTransformerEngineTensor(
void* data_ptr, void* columnwise_data_ptr, const std::vector<size_t>& shape,
const std::vector<size_t>& columnwise_shape, const transformer_engine::DType type,
@@ -199,6 +233,29 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor(
return ret;
}

transformer_engine::TensorWrapper makeTransformerEngineTensor(
void* data_ptr, void* columnwise_data_ptr, const NVTEShape& shape,
const NVTEShape& columnwise_shape, const transformer_engine::DType type, void* amax_ptr,
void* scale_ptr, void* scale_inv_ptr, void* columnwise_scale_inv_ptr,
const NVTEShape& scale_inv_shape, const NVTEShape& columnwise_scale_inv_shape,
NVTEScalingMode scaling_mode) {
TensorWrapper ret(scaling_mode);
ret.set_rowwise_data(data_ptr, type, shape);
ret.set_columnwise_data(columnwise_data_ptr, type, columnwise_shape);
NVTEShape meta_shape;
meta_shape.ndim = 1;
meta_shape.data[0] = 1;
ret.set_amax(amax_ptr, DType::kFloat32, meta_shape);
ret.set_scale(scale_ptr, DType::kFloat32, meta_shape);
auto scale_inv_dtype = (scaling_mode == NVTE_MXFP8_1D_SCALING) ? DType::kFloat8E8M0
: (scaling_mode == NVTE_NVFP4_1D_SCALING) ? DType::kFloat8E4M3
: DType::kFloat32;
ret.set_rowwise_scale_inv(scale_inv_ptr, scale_inv_dtype, scale_inv_shape);
ret.set_columnwise_scale_inv(columnwise_scale_inv_ptr, scale_inv_dtype,
columnwise_scale_inv_shape);
return ret;
}

transformer_engine::TensorWrapper makeTransformerEngineTensor(at::Tensor tensor, at::Tensor amax,
const at::Tensor scale,
at::Tensor scale_inv,
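A possible call site for the new NVTEShape overloads (a usage sketch under assumptions, not code from this PR): wrapping an FP8 torch tensor plus per-tensor scaling metadata without materializing a std::vector<size_t> for the shape. It assumes the declarations above are visible and that the tensor holds E4M3 data.

#include <torch/torch.h>
// plus the common.h declarations shown above

transformer_engine::TensorWrapper wrap_fp8_rowwise(const at::Tensor& data, at::Tensor& amax,
                                                   at::Tensor& scale, at::Tensor& scale_inv) {
  const NVTEShape shape = getTensorShape(data);  // now returns NVTEShape directly
  const size_t one = 1;
  const NVTEShape scale_inv_shape = nvte_make_shape(&one, 1);  // per-tensor scaling: shape {1}
  return makeTransformerEngineTensor(data.data_ptr(), shape,
                                     transformer_engine::DType::kFloat8E4M3, amax.data_ptr(),
                                     scale.data_ptr(), scale_inv.data_ptr(), scale_inv_shape,
                                     NVTE_DELAYED_TENSOR_SCALING);
}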
27 changes: 26 additions & 1 deletion transformer_engine/pytorch/csrc/common.h
@@ -141,6 +141,12 @@ class NoneQuantizer : public Quantizer {
std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape, DType dtype,
at::Tensor data) const;

std::pair<TensorWrapper, py::object> create_tensor(const NVTEShape& shape, DType dtype) const;

/*! @brief Construct a tensor with pre-initialized data */
std::pair<TensorWrapper, py::object> create_tensor(const NVTEShape& shape, DType dtype,
at::Tensor data) const;

std::pair<TensorWrapper, py::object> convert_and_update_tensor(py::object tensor) const override;

void quantize(const TensorWrapper& input, TensorWrapper& out,
@@ -339,7 +345,9 @@ class NVFP4Quantizer : public Quantizer {

std::unique_ptr<Quantizer> convert_quantizer(py::handle quantizer);

std::vector<size_t> getTensorShape(const at::Tensor& t);
NVTEShape getTensorShape(const at::Tensor& t);

std::vector<size_t> getTensorShapeVector(const at::Tensor& t);

transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid,
const std::string& fp8_recipe);
@@ -432,6 +440,16 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor(
void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, std::vector<size_t> scale_inv_shape = {1},
NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING);

transformer_engine::TensorWrapper makeTransformerEngineTensor(
void* data_ptr, const NVTEShape& shape, const transformer_engine::DType type, void* amax_ptr,
void* scale_ptr, void* scale_inv_ptr, const NVTEShape& scale_inv_shape,
NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING);

transformer_engine::TensorWrapper makeTransformerEngineTensor(
void* data_ptr, const std::vector<size_t>& shape, const transformer_engine::DType type,
void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, const NVTEShape& scale_inv_shape,
NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING);

transformer_engine::TensorWrapper makeTransformerEngineTensor(
void* data_ptr, void* columnwise_data_ptr, const std::vector<size_t>& shape,
const std::vector<size_t>& columnwise_shape, const transformer_engine::DType type,
@@ -440,6 +458,13 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor(
const std::vector<size_t>& columnwise_scale_inv_shape = {1},
NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING);

transformer_engine::TensorWrapper makeTransformerEngineTensor(
void* data_ptr, void* columnwise_data_ptr, const NVTEShape& shape,
const NVTEShape& columnwise_shape, const transformer_engine::DType type, void* amax_ptr,
void* scale_ptr, void* scale_inv_ptr, void* columnwise_scale_inv_ptr,
const NVTEShape& scale_inv_shape, const NVTEShape& columnwise_scale_inv_shape,
NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING);

transformer_engine::TensorWrapper makeTransformerEngineTensor(void* data_ptr,
const NVTEShape& shape,
const transformer_engine::DType type);
10 changes: 7 additions & 3 deletions transformer_engine/pytorch/csrc/extensions/attention.cpp
@@ -478,10 +478,14 @@ std::vector<py::object> fused_attn_bwd(
auto cu_seqlens_kv_padded_sizes = cu_seqlens_kv_padded.value().sizes().vec();
std::vector<size_t> cu_seqlens_kv_padded_shape{cu_seqlens_kv_padded_sizes.begin(),
cu_seqlens_kv_padded_sizes.end()};
te_cu_seqlens_q_padded = makeTransformerEngineTensor(cu_seqlens_q_padded.value().data_ptr(),
cu_seqlens_q_padded_shape, DType::kInt32);
te_cu_seqlens_q_padded = makeTransformerEngineTensor(
cu_seqlens_q_padded.value().data_ptr(),
nvte_make_shape(cu_seqlens_q_padded_shape.data(), cu_seqlens_q_padded_shape.size()),
DType::kInt32);
te_cu_seqlens_kv_padded = makeTransformerEngineTensor(
cu_seqlens_kv_padded.value().data_ptr(), cu_seqlens_kv_padded_shape, DType::kInt32);
cu_seqlens_kv_padded.value().data_ptr(),
nvte_make_shape(cu_seqlens_kv_padded_shape.data(), cu_seqlens_kv_padded_shape.size()),
DType::kInt32);
}

// convert auxiliary tensors from forward to NVTETensors
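The repeated nvte_make_shape(vec.data(), vec.size()) pattern above could be factored into a one-line helper; a small sketch, not part of the PR:

// Hypothetical convenience wrapper around nvte_make_shape for std::vector<size_t> shapes.
inline NVTEShape to_nvte_shape(const std::vector<size_t>& shape) {
  return nvte_make_shape(shape.data(), shape.size());
}

// e.g. makeTransformerEngineTensor(ptr, to_nvte_shape(cu_seqlens_q_padded_shape), DType::kInt32);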
6 changes: 3 additions & 3 deletions transformer_engine/pytorch/csrc/extensions/bias.cpp
@@ -26,7 +26,7 @@ std::vector<py::object> bgrad_quantize(const at::Tensor &grad_output, py::handle
// Grad output tensor
auto grad_output_torch = grad_output.contiguous();
const TensorWrapper &grad_output_nvte = makeTransformerEngineTensor(grad_output_torch);
const auto shape = getTensorShape(grad_output_torch);
const auto shape = getTensorShapeVector(grad_output_torch);
auto grad_output_dtype = GetTransformerEngineDType(grad_output_torch.scalar_type());

// Construct grad bias tensor
@@ -116,11 +116,11 @@ std::vector<py::object> dact_dbias(
// Grad output and activation input tensors
grad_output_torch = grad_output_torch.contiguous();
const TensorWrapper &grad_output_nvte = makeTransformerEngineTensor(grad_output_torch);
const auto output_shape = getTensorShape(grad_output_torch);
const auto output_shape = getTensorShapeVector(grad_output_torch);
auto grad_output_dtype = GetTransformerEngineDType(grad_output_torch.scalar_type());
act_input_torch = act_input_torch.contiguous();
const TensorWrapper &act_input_nvte = makeTransformerEngineTensor(act_input_torch);
const auto input_shape = getTensorShape(act_input_torch);
const auto input_shape = getTensorShapeVector(act_input_torch);

// Construct tensors
auto quantizer_cpp = convert_quantizer(quantizer_py);