diff --git a/setup.py b/setup.py
index 9d2d7bce1c..1a7f0e781e 100644
--- a/setup.py
+++ b/setup.py
@@ -710,13 +710,17 @@ def get_extensions():
         print("Building mxfp8_cuda extension")
         ext_modules.append(
             CUDAExtension(
-                name="torchao.prototype.mxfp8_cuda",
+                name="torchao._C_mxfp8",
                 sources=mxfp8_sources,
                 include_dirs=[
                     mxfp8_extension_dir,  # For mxfp8_quantize.cuh, mxfp8_extension.cpp, and mxfp8_cuda.cu
                 ],
                 extra_compile_args={
-                    "cxx": ["-std=c++17", "-O3"],
+                    "cxx": [
+                        f"-DPy_LIMITED_API={min_supported_cpython_hexcode}",
+                        "-std=c++17",
+                        "-O3",
+                    ],
                     "nvcc": nvcc_args
                     + [
                         "-gencode=arch=compute_100,code=sm_100",
diff --git a/test/prototype/moe_training/test_kernels.py b/test/prototype/moe_training/test_kernels.py
index ecd4cefe6a..0387cc28e0 100644
--- a/test/prototype/moe_training/test_kernels.py
+++ b/test/prototype/moe_training/test_kernels.py
@@ -21,6 +21,7 @@
     triton_fp8_per_group_rowwise_scales,
 )
 from torchao.prototype.moe_training.kernels.mxfp8 import (
+    mxfp8_quantize_cuda_3d,
     torch_to_blocked_2d_K_groups,
     torch_to_blocked_2d_M_groups,
     torch_to_blocked_per_group_3d,
@@ -317,8 +318,6 @@ def test_triton_mx_block_rearrange_2d_K_groups(
 @pytest.mark.parametrize("input_dtype", (torch.bfloat16,))
 @pytest.mark.parametrize("scaling_mode", (ScaleCalculationMode.FLOOR,))
 def test_cuda_mx_dim1_3d_numerics(E, N, K, input_dtype, scaling_mode):
-    from torchao.prototype import mxfp8_cuda
-
     scaling_mode_str = (
         "floor" if scaling_mode == ScaleCalculationMode.FLOOR else "rceil"
     )
@@ -344,9 +343,8 @@ def test_cuda_mx_dim1_3d_numerics(E, N, K, input_dtype, scaling_mode):
     y_d1_ref = y_d1_ref.transpose(-2, -1)
     s_d1_ref = s_d1_ref.transpose(-2, -1)

-    # CUDA implementation (should work with any stride pattern)
-    y_d1, s_d1 = mxfp8_cuda.quantize_3d(
-        x, scale_dim_n=block_size, scaling_mode=scaling_mode_str
+    y_d1, s_d1 = mxfp8_quantize_cuda_3d(
+        x, block_size=block_size, scaling_mode=scaling_mode_str
     )
     # Check scales
     torch.testing.assert_close(s_d1, s_d1_ref, rtol=0, atol=0)
diff --git a/test/prototype/mx_formats/test_kernels.py b/test/prototype/mx_formats/test_kernels.py
index 254b749767..a4839f7c61 100644
--- a/test/prototype/mx_formats/test_kernels.py
+++ b/test/prototype/mx_formats/test_kernels.py
@@ -33,6 +33,7 @@
     f32_to_f6_e2m3_unpacked,
     f32_to_f6_e3m2_unpacked,
     get_bits,
+    mxfp8_quantize_cuda,
     pack_uint4,
     triton_mxfp8_dequant_dim0,
     triton_to_mxfp8_dim0,
@@ -541,8 +542,6 @@ def test_rearrange(shape):
     "scaling_mode", (ScaleCalculationMode.FLOOR, ScaleCalculationMode.RCEIL)
 )
 def test_cuda_mx_dim1_numerics(M, K, input_dtype, scaling_mode):
-    from torchao.prototype import mxfp8_cuda
-
     scaling_mode_str = (
         "floor" if scaling_mode == ScaleCalculationMode.FLOOR else "rceil"
     )
@@ -561,13 +560,11 @@ def test_cuda_mx_dim1_numerics(M, K, input_dtype, scaling_mode):
         scaling_mode=scaling_mode,
     )

-    _, y_d1, _, s_d1 = mxfp8_cuda.quantize(
+    _, y_d1, _, s_d1 = mxfp8_quantize_cuda(
         x,
         rowwise=False,
         colwise=True,
         scaling_mode=scaling_mode_str,
-        scale_dim_x=1,
-        scale_dim_y=block_size,
     )

     # check scales
@@ -587,48 +584,15 @@
     reason="CUDA version >= 12.8 required for MXFP8 CUDA kernels",
 )
 def test_cuda_mx_dim0_not_supported():
-    from torchao.prototype import mxfp8_cuda
-
     M, K = 64, 64
-    block_size = 32
     x = (
         torch.arange(0, M * K, dtype=torch.bfloat16, device="cuda")
         .reshape(M, K)
         .contiguous()
     )
     with pytest.raises(RuntimeError):
-        _, y_d1, _, s_d1 = mxfp8_cuda.quantize(
+        _, y_d1, _, s_d1 = mxfp8_quantize_cuda(
             x,
             rowwise=True,
             colwise=False,
-            scale_dim_x=block_size,
-            scale_dim_y=1,
-        )
-
-
-@pytest.mark.skipif(
-    not is_sm_at_least_100(),
-    reason="MXFP8 requires CUDA capability 10.0 or greater",
-)
-@pytest.mark.skipif(
-    not is_cuda_version_at_least(12, 8),
-    reason="CUDA version >= 12.8 required for MXFP8 CUDA kernels",
-)
-def test_cuda_mx_dim1_invalid_block_size():
-    from torchao.prototype import mxfp8_cuda
-
-    M, K = 64, 64
-    x = (
-        torch.arange(0, M * K, dtype=torch.bfloat16, device="cuda")
-        .reshape(M, K)
-        .contiguous()
-    )
-    invalid_block_size = 4
-    with pytest.raises(RuntimeError):
-        _, y_d1, _, s_d1 = mxfp8_cuda.quantize(
-            x,
-            rowwise=False,
-            colwise=True,
-            scale_dim_x=1,
-            scale_dim_y=invalid_block_size,
         )
diff --git a/torchao/csrc/cuda/mx_kernels/mxfp8_extension.cpp b/torchao/csrc/cuda/mx_kernels/mxfp8_extension.cpp
index d445fcad4d..21869c3b8f 100644
--- a/torchao/csrc/cuda/mx_kernels/mxfp8_extension.cpp
+++ b/torchao/csrc/cuda/mx_kernels/mxfp8_extension.cpp
@@ -1,32 +1,33 @@
-// PyBind wrapping for the mxfp8 extension
+// MXFP8 extension using TORCH_LIBRARY (CPython ABI agnostic)
+#include
+#include
 #include
 #include
 #include
 #include
-#include

 namespace mxfp8 {

 // Forward declarations
-void mxfp8_quantize_cuda(const torch::Tensor &input,
-                         torch::Tensor &output_rowwise,
-                         torch::Tensor &output_columnwise,
-                         torch::Tensor &scales_rowwise,
-                         torch::Tensor &scales_colwise,
+void mxfp8_quantize_cuda(const at::Tensor &input,
+                         at::Tensor &output_rowwise,
+                         at::Tensor &output_columnwise,
+                         at::Tensor &scales_rowwise,
+                         at::Tensor &scales_colwise,
                          int64_t scale_dim_x, int64_t scale_dim_y,
                          const std::string &fp8_format,
                          const std::string &scaling_mode);

-void mxfp8_quantize_3d_cuda(const torch::Tensor &input,
-                            torch::Tensor &output_colwise,
-                            torch::Tensor &scales_colwise,
+void mxfp8_quantize_3d_cuda(const at::Tensor &input,
+                            at::Tensor &output_colwise,
+                            at::Tensor &scales_colwise,
                             int64_t scale_dim_n,
                             const std::string &fp8_format,
                             const std::string &scaling_mode);

 // Helper for tensor validation
-void check_cuda_tensor(const torch::Tensor &t, const char *name) {
+void check_cuda_tensor(const at::Tensor &t, const char *name) {
   TORCH_CHECK(t.is_cuda(), name, " must be a CUDA tensor");
   TORCH_CHECK(t.is_contiguous(), name, " must be contiguous");
 }
@@ -46,8 +47,8 @@ void validate_scale_dimensions(int64_t scale_dim_x, int64_t scale_dim_y) {
 }

 // Main quantization function
-std::tuple
-mxfp8_quantize(torch::Tensor input, bool rowwise, bool colwise,
+std::tuple
+mxfp8_quantize(const at::Tensor& input, bool rowwise, bool colwise,
                int64_t scale_dim_x, int64_t scale_dim_y,
                const std::string &fp8_format,
                const std::string &scaling_mode) {
@@ -57,9 +58,9 @@ mxfp8_quantize(torch::Tensor input, bool rowwise, bool colwise,
   TORCH_CHECK(input.is_cuda(), "input must be a CUDA tensor");
   TORCH_CHECK(input.is_contiguous(), "input must be contiguous");
   TORCH_CHECK(input.dim() == 2, "input must be 2D");
-  TORCH_CHECK(input.scalar_type() == torch::kFloat32 ||
-                  input.scalar_type() == torch::kFloat16 ||
-                  input.scalar_type() == torch::kBFloat16,
+  TORCH_CHECK(input.scalar_type() == at::kFloat ||
+                  input.scalar_type() == at::kHalf ||
+                  input.scalar_type() == at::kBFloat16,
               "Input must be float32, float16, or bfloat16");
   TORCH_CHECK(rowwise || colwise,
               "At least one of rowwise or colwise must be true");
@@ -75,40 +76,40 @@ mxfp8_quantize(torch::Tensor input, bool rowwise, bool colwise,
   c10::cuda::CUDAGuard device_guard(input.device());

   // Create tensor options
-  const auto options_fp8 = torch::TensorOptions()
-                               .dtype(torch::kFloat8_e4m3fn) // FP8 stored as uint8
+  const auto options_fp8 = at::TensorOptions()
+                               .dtype(at::kFloat8_e4m3fn)
                                .device(input.device());

-  const auto options_scale = torch::TensorOptions()
-                                 .dtype(torch::kFloat8_e8m0fnu) // E8M0 stored as uint8
+  const auto options_scale = at::TensorOptions()
+                                 .dtype(at::kFloat8_e8m0fnu)
                                  .device(input.device());

   // Allocate output tensors
-  torch::Tensor output_rowwise, output_colwise;
-  torch::Tensor scales_rowwise, scales_colwise;
+  at::Tensor output_rowwise, output_colwise;
+  at::Tensor scales_rowwise, scales_colwise;

   if (rowwise) {
     const int64_t num_col_blocks = (cols + scale_dim_x - 1) / scale_dim_x;
-    output_rowwise = torch::empty({rows, cols}, options_fp8);
-    scales_rowwise = torch::empty({rows, num_col_blocks}, options_scale);
+    output_rowwise = at::empty({rows, cols}, options_fp8);
+    scales_rowwise = at::empty({rows, num_col_blocks}, options_scale);
   } else {
-    output_rowwise = torch::empty({0}, options_fp8);
-    scales_rowwise = torch::empty({0}, options_scale);
+    output_rowwise = at::empty({0}, options_fp8);
+    scales_rowwise = at::empty({0}, options_scale);
   }

   if (colwise) {
     const int64_t num_row_blocks = (rows + scale_dim_y - 1) / scale_dim_y;
-    output_colwise = torch::empty_strided({rows, cols}, {1, rows}, options_fp8);
+    output_colwise = at::empty_strided({rows, cols}, {1, rows}, options_fp8);

     // Need scales_colwise to be this shape so the 'col' dim stride is 1,
     // for colwise scaling, we can avoid uncoalesced writes to global memory.
     // This is because each of the 32 threads in a warp will be computing
     // a scale for a different column of 32 input data values, then each writing
     // that scale to global memory - so the stride along this `col` dim should be 1
     // so writes can be coalesced into a single transaction.
-    scales_colwise = torch::empty_strided({cols, num_row_blocks}, {1, cols}, options_scale);
+    scales_colwise = at::empty_strided({cols, num_row_blocks}, {1, cols}, options_scale);
   } else {
-    output_colwise = torch::empty({0}, options_fp8);
-    scales_colwise = torch::empty({0}, options_scale);
+    output_colwise = at::empty({0}, options_fp8);
+    scales_colwise = at::empty({0}, options_scale);
   }
   // Call CUDA kernels
@@ -124,8 +125,8 @@ mxfp8_quantize(torch::Tensor input, bool rowwise, bool colwise,
 }

 // 3D tensor quantization function
-std::tuple
-mxfp8_quantize_3d(torch::Tensor input, int64_t scale_dim_n,
+std::tuple
+mxfp8_quantize_3d(const at::Tensor& input, int64_t scale_dim_n,
                   const std::string &fp8_format,
                   const std::string &scaling_mode) {

@@ -134,9 +135,9 @@ mxfp8_quantize_3d(torch::Tensor input, int64_t scale_dim_n,
   TORCH_CHECK(input.is_contiguous(), "input must be contiguous");
   // Note: We don't check contiguous for 3D as it may have column major strides
   TORCH_CHECK(input.dim() == 3, "input must be 3D");
-  TORCH_CHECK(input.scalar_type() == torch::kFloat32 ||
-                  input.scalar_type() == torch::kFloat16 ||
-                  input.scalar_type() == torch::kBFloat16,
+  TORCH_CHECK(input.scalar_type() == at::kFloat ||
+                  input.scalar_type() == at::kHalf ||
+                  input.scalar_type() == at::kBFloat16,
               "Input must be float32, float16, or bfloat16");
   TORCH_CHECK(scale_dim_n == 32, "scale_dim_n must be 32 for now");
@@ -154,21 +155,21 @@ mxfp8_quantize_3d(torch::Tensor input, int64_t scale_dim_n,
   c10::cuda::CUDAGuard device_guard(input.device());

   // Create tensor options
-  const auto options_fp8 = torch::TensorOptions()
-                               .dtype(torch::kFloat8_e4m3fn)
+  const auto options_fp8 = at::TensorOptions()
+                               .dtype(at::kFloat8_e4m3fn)
                                .device(input.device());

-  const auto options_scale = torch::TensorOptions()
-                                 .dtype(torch::kFloat8_e8m0fnu)
+  const auto options_scale = at::TensorOptions()
+                                 .dtype(at::kFloat8_e8m0fnu)
                                  .device(input.device());

   // Create output tensor with column major layout (required for downstream ops)
-  torch::Tensor output_colwise = torch::empty_strided(
+  at::Tensor output_colwise = at::empty_strided(
       {E, N, K}, {N * K, 1, N}, options_fp8);

   // Create scales tensor with shape (E, num_n_blocks, K)
   const int64_t num_n_blocks = (N + scale_dim_n - 1) / scale_dim_n;
-  torch::Tensor scales_colwise = torch::empty({E, num_n_blocks, K}, options_scale);
+  at::Tensor scales_colwise = at::empty({E, num_n_blocks, K}, options_scale);

   // Call CUDA kernel
   mxfp8_quantize_3d_cuda(input, output_colwise, scales_colwise,
@@ -179,17 +180,8 @@

 } // namespace mxfp8

-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.doc() = "MXFP8 Quantization PyTorch Extension";
-
-  m.def("quantize", &mxfp8::mxfp8_quantize, "MXFP8 quantization",
-        py::arg("input"), py::arg("rowwise") = true, py::arg("colwise") = false,
-        py::arg("scale_dim_x") = 32, py::arg("scale_dim_y") = 32,
-        py::arg("fp8_format") = "e4m3",
-        py::arg("scaling_mode") = "floor");
-
-  m.def("quantize_3d", &mxfp8::mxfp8_quantize_3d, "MXFP8 3D quantization",
-        py::arg("input"), py::arg("scale_dim_n") = 32,
-        py::arg("fp8_format") = "e4m3",
-        py::arg("scaling_mode") = "floor");
+// Register CUDA implementations
+TORCH_LIBRARY_IMPL(torchao, CUDA, m) {
+  m.impl("mxfp8_quantize", &mxfp8::mxfp8_quantize);
+  m.impl("mxfp8_quantize_3d", &mxfp8::mxfp8_quantize_3d);
 }
diff --git a/torchao/prototype/moe_training/kernels/mxfp8/quant.py b/torchao/prototype/moe_training/kernels/mxfp8/quant.py
index 24915d6359..8b023694c7 100644
--- a/torchao/prototype/moe_training/kernels/mxfp8/quant.py
+++ b/torchao/prototype/moe_training/kernels/mxfp8/quant.py
@@ -1,4 +1,3 @@
-import logging
 from typing import Tuple

 import torch
@@ -10,6 +9,7 @@
 from torchao.prototype.mx_formats.utils import to_blocked
 from torchao.utils import (
     ceil_div,
+    is_cuda_version_at_least,
     is_sm_at_least_100,
 )
@@ -673,24 +673,18 @@ def _blocked_group_start_idx(
     return group_start_idx


-mxfp8_cuda_extension_available = False
-if is_sm_at_least_100():
-    try:
-        # MXFP8 CUDA kernel is only built on SM100+. Furthermore,
-        # currently our CI runners are not SM100+, so the user needs to build
-        # from source.
-        # TODO(#2932): improve this
-        from torchao.prototype import mxfp8_cuda
+mxfp8_cuda_extension_available = is_sm_at_least_100() and is_cuda_version_at_least(
+    12, 8
+)

-        mxfp8_cuda_extension_available = True
-    except ImportError:
-        logging.debug("Skipping import of torchao.prototype.mxfp8_cuda")
+lib = torch.library.Library("torchao", "FRAGMENT")
+lib.define(
+    "mxfp8_quantize_3d(Tensor input, int scale_dim_n, str fp8_format, str scaling_mode) -> (Tensor, Tensor)",
+    tags=[torch._C.Tag.needs_fixed_stride_order],
+)

 if mxfp8_cuda_extension_available:
-    # TODO: Make `scaling_mode` a choice (enum-like) rather than arbitrary string.
-    # Currently we have to use an arbitrary string because custom ops don't support enum
-    # params.
-    @torch.library.custom_op("torchao::mxfp8_quantize_cuda_3d", mutates_args=())
+
     def mxfp8_quantize_cuda_3d(
         x: torch.Tensor,
         block_size: int = 32,
@@ -699,40 +693,42 @@ def mxfp8_quantize_cuda_3d(
         """
         Quantizes a 3D tensor of shape (E,N,K) to MXFP8 format, scaling along N.

+        This is a high-level wrapper that calls the underlying CUDA kernel via
+        torch.ops.torchao.mxfp8_quantize_3d.
+
         Args:
             x (torch.Tensor): Input tensor to be quantized.
             block_size (int, optional): Block size for quantization. Defaults to 32.
             scaling_mode (str, optional): Scaling mode for quantization. Defaults to "floor".

         Returns:
-            torch.Tensor: quantized tensor
+            torch.Tensor: quantized tensor in column-major layout
             torch.Tensor: scales tensor
         """
         assert x.ndim == 3, "Input tensor must be 3D"
-        assert x.dtype in (torch.float32, torch.bfloat16), (
-            "Input tensor must be float32 or bfloat16"
-        )
-        q_data, scales = mxfp8_cuda.quantize_3d(
-            x, scale_dim_n=block_size, scaling_mode=scaling_mode
+        assert x.dtype in (
+            torch.float32,
+            torch.bfloat16,
+        ), "Input tensor must be float32 or bfloat16"
+        return torch.ops.torchao.mxfp8_quantize_3d.default(
+            x, block_size, "e4m3", scaling_mode
         )
-        return q_data, scales

-    @mxfp8_quantize_cuda_3d.register_fake
-    def _fake_mxfp8_quantize_cuda_3d(
+    @torch.library.register_fake("torchao::mxfp8_quantize_3d")
+    def _fake_mxfp8_quantize_3d(
         x: torch.Tensor,
-        block_size: int = 32,
-        scaling_mode: str = "floor",
+        scale_dim_n: int,
+        fp8_format: str,
+        scaling_mode: str,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Fake/meta implementation for mxfp8_quantize_3d."""
         assert x.ndim == 3, "Input tensor must be 3D"
-        assert x.dtype in (torch.float32, torch.bfloat16), (
-            "Input tensor must be float32 or bfloat16"
-        )
         E, N, K = x.shape
-        # Quantized tensor is in column major layouts
+        # Quantized tensor is in column major layout
         q_data = x.new_empty(x.shape, dtype=torch.float8_e4m3fn).as_strided(
             x.shape, (N * K, 1, N)
         )
-        scales = x.new_empty((E, N // block_size, K), dtype=torch.float8_e8m0fnu)
+        scales = x.new_empty((E, N // scale_dim_n, K), dtype=torch.float8_e8m0fnu)
         return q_data, scales

 else:
diff --git a/torchao/prototype/mx_formats/kernels.py b/torchao/prototype/mx_formats/kernels.py
index b4cd192244..903c3ac464 100644
--- a/torchao/prototype/mx_formats/kernels.py
+++ b/torchao/prototype/mx_formats/kernels.py
@@ -18,6 +18,7 @@
     _floatx_unpacked_to_f32,
 )
 from torchao.utils import (
+    is_cuda_version_at_least,
     is_sm_at_least_100,
     torch_version_at_least,
 )
@@ -626,9 +627,10 @@ def triton_mxfp8_dequant_dim0(
     scale_block_size: int = 32,
 ) -> torch.Tensor:
     assert scale_block_size == 32, "scale_block_size must be 32 for now"
-    assert out_dtype in (torch.bfloat16, torch.float32), (
-        "out_dtype must be bf16 or fp32"
-    )
+    assert out_dtype in (
+        torch.bfloat16,
+        torch.float32,
+    ), "out_dtype must be bf16 or fp32"

     # Input shape must be 2D.
     orig_shape = e4m3_data.shape
@@ -1055,6 +1057,7 @@ def _(scale_tensor):
         padded_cols = n_col_blocks * 4

         return scale_tensor.new_empty((padded_rows, padded_cols))
+
 else:

     def triton_to_mxfp8_dim0(
@@ -1091,30 +1094,39 @@ def triton_mxfp8_dequant_dim0(
         raise AssertionError("needs torch version 2.8+ and triton")


-mxfp8_cuda_extension_available = False
-if is_sm_at_least_100():
-    try:
-        # MXFP8 CUDA kernel is only built on SM100+. Furthermore,
-        # currently our CI runners are not SM100+, so the user needs to build
-        # from source.
-        # TODO(#2932): improve this
-        from torchao.prototype import mxfp8_cuda
+mxfp8_cuda_extension_available = is_sm_at_least_100() and is_cuda_version_at_least(
+    12, 8
+)

-        mxfp8_cuda_extension_available = True
-    except ImportError:
-        logging.debug("Skipping import of torchao.prototype.mxfp8_cuda")
+lib = torch.library.Library("torchao", "FRAGMENT")
+lib.define(
+    "mxfp8_quantize(Tensor input, bool rowwise, bool colwise, int scale_dim_x, int scale_dim_y, str fp8_format, str scaling_mode) -> (Tensor, Tensor, Tensor, Tensor)",
+    tags=[torch._C.Tag.needs_fixed_stride_order],
+)

 if mxfp8_cuda_extension_available:
-    # TODO: Make `scaling_mode` a choice (enum-like) rather than arbitrary string.
-    # Currently we have to use an arbitrary string because custom ops don't support enum
-    # params.
-    @torch.library.custom_op("torchao::mxfp8_quantize_cuda", mutates_args=())
+
     def mxfp8_quantize_cuda(
         x: torch.Tensor,
         rowwise: bool = False,
         colwise: bool = True,
         scaling_mode: str = "floor",
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        """
+        Quantizes a 2D tensor to MXFP8 format using CUDA kernels.
+
+        This is a high-level wrapper that calls the underlying CUDA kernel via
+        torch.ops.torchao.mxfp8_quantize.
+
+        Args:
+            x: Input tensor to be quantized. Must be 2D with shape (rows, cols).
+            rowwise: If True, compute rowwise scales.
+            colwise: If True, compute colwise scales.
+            scaling_mode: Scaling mode for quantization. Defaults to "floor".
+
+        Returns:
+            Tuple of (output_rowwise, output_colwise, scales_rowwise, scales_colwise)
+        """
         # Input shape must be 2D.
         assert x.ndim == 2
         rows, cols = x.shape
@@ -1124,29 +1136,32 @@ def mxfp8_quantize_cuda(
         assert rows % block_size == 0, "rows must be a multiple of 32"
         assert cols % block_size == 0, "cols must be a multiple of 32"

-        # Convert scaling mode to expected string format and call into kernel.
         output_rowwise, output_colwise, scales_rowwise, scales_colwise = (
-            mxfp8_cuda.quantize(
+            torch.ops.torchao.mxfp8_quantize.default(
                 x,
-                rowwise=rowwise,
-                colwise=colwise,
-                scaling_mode=scaling_mode,
+                rowwise,
+                colwise,
+                1,  # scale_dim_x
+                block_size,  # scale_dim_y
+                "e4m3",  # fp8_format
+                scaling_mode,
             )
         )
         return output_rowwise, output_colwise, scales_rowwise, scales_colwise

-    @mxfp8_quantize_cuda.register_fake
-    def _(
+    @torch.library.register_fake("torchao::mxfp8_quantize")
+    def _fake_mxfp8_quantize(
         x: torch.Tensor,
-        rowwise: bool = False,
-        colwise: bool = True,
-        scaling_mode: str = "floor",
+        rowwise: bool,
+        colwise: bool,
+        scale_dim_x: int,
+        scale_dim_y: int,
+        fp8_format: str,
+        scaling_mode: str,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Fake/meta implementation for mxfp8_quantize."""
         assert x.ndim == 2
         rows, cols = x.shape
-        block_size = 32
-        assert rows % block_size == 0, "rows must be a multiple of 32"
-        assert cols % block_size == 0, "cols must be a multiple of 32"

         num_row_blocks = rows // 32
         num_col_blocks = cols // 32
@@ -1180,7 +1195,7 @@ def _(

         return output_rowwise, output_colwise, scales_rowwise, scales_colwise

-    @register_sharding(torch.ops.torchao.mxfp8_quantize_cuda.default)
+    @register_sharding(torch.ops.torchao.mxfp8_quantize.default)
     def custom_mxfp8_quantize_cuda_dim1_sharding(
         x: torch.Tensor,
         rowwise: bool = False,
@@ -1216,6 +1231,7 @@ def custom_mxfp8_quantize_cuda_dim1_sharding(
             rule_for_input_sharded_dim1,
         ]
         return acceptable_shardings
+
 else:

     def mxfp8_quantize_cuda(
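Not part of the patch: a minimal usage sketch of the two Python wrappers the patch re-registers as torch custom ops, mirroring the updated tests. It assumes an SM100+ GPU, CUDA >= 12.8, and a torchao build with the MXFP8 extension compiled in; the shapes and values below are illustrative only.

import torch

from torchao.prototype.moe_training.kernels.mxfp8 import mxfp8_quantize_cuda_3d
from torchao.prototype.mx_formats.kernels import mxfp8_quantize_cuda

# 2D colwise quantization; dispatches to torch.ops.torchao.mxfp8_quantize.
# Rows and cols must be multiples of the 32-element scaling block.
x = torch.randn(128, 256, dtype=torch.bfloat16, device="cuda")
_, y_colwise, _, scales_colwise = mxfp8_quantize_cuda(
    x, rowwise=False, colwise=True, scaling_mode="floor"
)

# 3D (E, N, K) quantization, scaled along N in blocks of 32; dispatches to
# torch.ops.torchao.mxfp8_quantize_3d and returns column-major quantized data.
x_3d = torch.randn(2, 256, 512, dtype=torch.bfloat16, device="cuda")
y_3d, scales_3d = mxfp8_quantize_cuda_3d(x_3d, block_size=32, scaling_mode="floor")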