From 44db71c0772e5ef5758c38d0e4e8ad9995946c80 Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Tue, 25 Nov 2025 09:14:49 -0800 Subject: [PATCH 01/14] implement additional cvcuda infra for all branches to avoid duplicate setup --- torchvision/transforms/v2/_transform.py | 4 ++-- torchvision/transforms/v2/_utils.py | 3 ++- .../transforms/v2/functional/__init__.py | 2 +- .../transforms/v2/functional/_augment.py | 11 ++++++++++- .../transforms/v2/functional/_color.py | 12 +++++++++++- .../transforms/v2/functional/_geometry.py | 19 +++++++++++++++++-- torchvision/transforms/v2/functional/_misc.py | 11 +++++++++-- .../transforms/v2/functional/_utils.py | 16 ++++++++++++++++ 8 files changed, 68 insertions(+), 10 deletions(-) diff --git a/torchvision/transforms/v2/_transform.py b/torchvision/transforms/v2/_transform.py index ac84fcb6c82..bec9ffcf714 100644 --- a/torchvision/transforms/v2/_transform.py +++ b/torchvision/transforms/v2/_transform.py @@ -11,7 +11,7 @@ from torchvision.transforms.v2._utils import check_type, has_any, is_pure_tensor from torchvision.utils import _log_api_usage_once -from .functional._utils import _get_kernel +from .functional._utils import _get_kernel, is_cvcuda_tensor class Transform(nn.Module): @@ -23,7 +23,7 @@ class Transform(nn.Module): # Class attribute defining transformed types. Other types are passed-through without any transformation # We support both Types and callables that are able to do further checks on the type of the input. - _transformed_types: tuple[type | Callable[[Any], bool], ...] = (torch.Tensor, PIL.Image.Image) + _transformed_types: tuple[type | Callable[[Any], bool], ...] = (torch.Tensor, PIL.Image.Image, is_cvcuda_tensor) def __init__(self) -> None: super().__init__() diff --git a/torchvision/transforms/v2/_utils.py b/torchvision/transforms/v2/_utils.py index bb6051b4e61..765a772fe41 100644 --- a/torchvision/transforms/v2/_utils.py +++ b/torchvision/transforms/v2/_utils.py @@ -15,7 +15,7 @@ from torchvision._utils import sequence_to_str from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size # noqa: F401 -from torchvision.transforms.v2.functional import get_dimensions, get_size, is_pure_tensor +from torchvision.transforms.v2.functional import get_dimensions, get_size, is_cvcuda_tensor, is_pure_tensor from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT @@ -207,6 +207,7 @@ def query_size(flat_inputs: list[Any]) -> tuple[int, int]: tv_tensors.Mask, tv_tensors.BoundingBoxes, tv_tensors.KeyPoints, + is_cvcuda_tensor, ), ) } diff --git a/torchvision/transforms/v2/functional/__init__.py b/torchvision/transforms/v2/functional/__init__.py index 032a993b1f0..52181e4624b 100644 --- a/torchvision/transforms/v2/functional/__init__.py +++ b/torchvision/transforms/v2/functional/__init__.py @@ -1,6 +1,6 @@ from torchvision.transforms import InterpolationMode # usort: skip -from ._utils import is_pure_tensor, register_kernel # usort: skip +from ._utils import is_pure_tensor, register_kernel, is_cvcuda_tensor # usort: skip from ._meta import ( clamp_bounding_boxes, diff --git a/torchvision/transforms/v2/functional/_augment.py b/torchvision/transforms/v2/functional/_augment.py index a904d8d7cbd..7ce5bdc7b7e 100644 --- a/torchvision/transforms/v2/functional/_augment.py +++ b/torchvision/transforms/v2/functional/_augment.py @@ -1,4 +1,5 @@ import io +from typing import TYPE_CHECKING import PIL.Image @@ -8,7 +9,15 @@ from torchvision.transforms.functional import pil_to_tensor, to_pil_image from torchvision.utils import _log_api_usage_once -from ._utils import _get_kernel, _register_kernel_internal +from ._utils import _get_kernel, _import_cvcuda, _is_cvcuda_available, _register_kernel_internal + + +CVCUDA_AVAILABLE = _is_cvcuda_available() + +if TYPE_CHECKING: + import cvcuda # type: ignore[import-not-found] +if CVCUDA_AVAILABLE: + cvcuda = _import_cvcuda() # noqa: F811 def erase( diff --git a/torchvision/transforms/v2/functional/_color.py b/torchvision/transforms/v2/functional/_color.py index be254c0d63a..5be9c62902a 100644 --- a/torchvision/transforms/v2/functional/_color.py +++ b/torchvision/transforms/v2/functional/_color.py @@ -1,3 +1,5 @@ +from typing import TYPE_CHECKING + import PIL.Image import torch from torch.nn.functional import conv2d @@ -9,7 +11,15 @@ from ._misc import _num_value_bits, to_dtype_image from ._type_conversion import pil_to_tensor, to_pil_image -from ._utils import _get_kernel, _register_kernel_internal +from ._utils import _get_kernel, _import_cvcuda, _is_cvcuda_available, _register_kernel_internal + + +CVCUDA_AVAILABLE = _is_cvcuda_available() + +if TYPE_CHECKING: + import cvcuda # type: ignore[import-not-found] +if CVCUDA_AVAILABLE: + cvcuda = _import_cvcuda() # noqa: F811 def rgb_to_grayscale(inpt: torch.Tensor, num_output_channels: int = 1) -> torch.Tensor: diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py index 4fcb7fabe0d..c029488001c 100644 --- a/torchvision/transforms/v2/functional/_geometry.py +++ b/torchvision/transforms/v2/functional/_geometry.py @@ -2,7 +2,7 @@ import numbers import warnings from collections.abc import Sequence -from typing import Any, Optional, Union +from typing import Any, Optional, TYPE_CHECKING, Union import PIL.Image import torch @@ -26,7 +26,22 @@ from ._meta import _get_size_image_pil, clamp_bounding_boxes, convert_bounding_box_format -from ._utils import _FillTypeJIT, _get_kernel, _register_five_ten_crop_kernel_internal, _register_kernel_internal +from ._utils import ( + _FillTypeJIT, + _get_kernel, + _import_cvcuda, + _is_cvcuda_available, + _register_five_ten_crop_kernel_internal, + _register_kernel_internal, +) + + +CVCUDA_AVAILABLE = _is_cvcuda_available() + +if TYPE_CHECKING: + import cvcuda # type: ignore[import-not-found] +if CVCUDA_AVAILABLE: + cvcuda = _import_cvcuda() # noqa: F811 def _check_interpolation(interpolation: Union[InterpolationMode, int]) -> InterpolationMode: diff --git a/torchvision/transforms/v2/functional/_misc.py b/torchvision/transforms/v2/functional/_misc.py index daf263df046..0fa05a2113c 100644 --- a/torchvision/transforms/v2/functional/_misc.py +++ b/torchvision/transforms/v2/functional/_misc.py @@ -1,5 +1,5 @@ import math -from typing import Optional +from typing import Optional, TYPE_CHECKING import PIL.Image import torch @@ -13,7 +13,14 @@ from ._meta import _convert_bounding_box_format -from ._utils import _get_kernel, _register_kernel_internal, is_pure_tensor +from ._utils import _get_kernel, _import_cvcuda, _is_cvcuda_available, _register_kernel_internal, is_pure_tensor + +CVCUDA_AVAILABLE = _is_cvcuda_available() + +if TYPE_CHECKING: + import cvcuda # type: ignore[import-not-found] +if CVCUDA_AVAILABLE: + cvcuda = _import_cvcuda() # noqa: F811 def normalize( diff --git a/torchvision/transforms/v2/functional/_utils.py b/torchvision/transforms/v2/functional/_utils.py index ad1eddd258b..73fafaf7425 100644 --- a/torchvision/transforms/v2/functional/_utils.py +++ b/torchvision/transforms/v2/functional/_utils.py @@ -169,3 +169,19 @@ def _is_cvcuda_available(): return True except ImportError: return False + + +def is_cvcuda_tensor(inpt: Any) -> bool: + """ + Check if the input is a CVCUDA tensor. + + Args: + inpt: The input to check. + + Returns: + True if the input is a CV-CUDA tensor, False otherwise. + """ + if _is_cvcuda_available(): + cvcuda = _import_cvcuda() + return isinstance(inpt, cvcuda.Tensor) + return False From e3dd70022fa1c87aca7a9a98068b6e13e802a375 Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Tue, 25 Nov 2025 09:26:19 -0800 Subject: [PATCH 02/14] update make_image_cvcuda to have default batch dim --- test/common_utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/test/common_utils.py b/test/common_utils.py index 8c3c9dd58a8..e7bae60c41b 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -400,8 +400,9 @@ def make_image_pil(*args, **kwargs): return to_pil_image(make_image(*args, **kwargs)) -def make_image_cvcuda(*args, **kwargs): - return to_cvcuda_tensor(make_image(*args, **kwargs)) +def make_image_cvcuda(*args, batch_dims=(1,), **kwargs): + # explicitly default batch_dims to (1,) since to_cvcuda_tensor requires a batch dimension (ndims == 4) + return to_cvcuda_tensor(make_image(*args, batch_dims=batch_dims, **kwargs)) def make_keypoints(canvas_size=DEFAULT_SIZE, *, num_points=4, dtype=None, device="cpu"): From c035df1c6eaebcad25604f8c298a7d9eaf86864b Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Mon, 1 Dec 2025 18:16:27 -0800 Subject: [PATCH 03/14] add stanardized setup to main for easier updating of PRs and branches --- test/common_utils.py | 21 ++++++++++++++-- test/test_transforms_v2.py | 2 +- torchvision/transforms/v2/_utils.py | 2 +- torchvision/transforms/v2/functional/_meta.py | 24 +++++++++++++++++-- 4 files changed, 43 insertions(+), 6 deletions(-) diff --git a/test/common_utils.py b/test/common_utils.py index e7bae60c41b..3b889e93d2e 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -20,13 +20,15 @@ from torch.testing._comparison import BooleanPair, NonePair, not_close_error_metas, NumberPair, TensorLikePair from torchvision import io, tv_tensors from torchvision.transforms._functional_tensor import _max_value as get_max_value -from torchvision.transforms.v2.functional import to_cvcuda_tensor, to_image, to_pil_image +from torchvision.transforms.v2.functional import cvcuda_to_tensor, to_cvcuda_tensor, to_image, to_pil_image +from torchvision.transforms.v2.functional._utils import _import_cvcuda, _is_cvcuda_available from torchvision.utils import _Image_fromarray IN_OSS_CI = any(os.getenv(var) == "true" for var in ["CIRCLECI", "GITHUB_ACTIONS"]) IN_RE_WORKER = os.environ.get("INSIDE_RE_WORKER") is not None IN_FBCODE = os.environ.get("IN_FBCODE_TORCHVISION") == "1" +CVCUDA_AVAILABLE = _is_cvcuda_available() CUDA_NOT_AVAILABLE_MSG = "CUDA device not available" MPS_NOT_AVAILABLE_MSG = "MPS device not available" OSS_CI_GPU_NO_CUDA_MSG = "We're in an OSS GPU machine, and this test doesn't need cuda." @@ -275,6 +277,17 @@ def combinations_grid(**kwargs): return [dict(zip(kwargs.keys(), values)) for values in itertools.product(*kwargs.values())] +def cvcuda_to_pil_compatible_tensor(tensor: "cvcuda.Tensor") -> torch.Tensor: + tensor = cvcuda_to_tensor(tensor) + if tensor.ndim != 4: + raise ValueError(f"CV-CUDA Tensor should be 4 dimensional. Got {tensor.ndim} dimensions.") + if tensor.shape[0] != 1: + raise ValueError( + f"CV-CUDA Tensor should have batch dimension 1 for comparison with PIL.Image.Image. Got {tensor.shape[0]}." + ) + return tensor.squeeze(0).cpu() + + class ImagePair(TensorLikePair): def __init__( self, @@ -287,6 +300,11 @@ def __init__( if all(isinstance(input, PIL.Image.Image) for input in [actual, expected]): actual, expected = (to_image(input) for input in [actual, expected]) + # handle check for CV-CUDA Tensors + if CVCUDA_AVAILABLE and isinstance(actual, _import_cvcuda().Tensor): + # Use the PIL compatible tensor, so we can always compare with PIL.Image.Image + actual = cvcuda_to_pil_compatible_tensor(actual) + super().__init__(actual, expected, **other_parameters) self.mae = mae @@ -401,7 +419,6 @@ def make_image_pil(*args, **kwargs): def make_image_cvcuda(*args, batch_dims=(1,), **kwargs): - # explicitly default batch_dims to (1,) since to_cvcuda_tensor requires a batch dimension (ndims == 4) return to_cvcuda_tensor(make_image(*args, batch_dims=batch_dims, **kwargs)) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 670a9d00ffb..7eba65550da 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -21,6 +21,7 @@ import torchvision.transforms.v2 as transforms from common_utils import ( + assert_close, assert_equal, cache, cpu_and_cuda, @@ -41,7 +42,6 @@ ) from torch import nn -from torch.testing import assert_close from torch.utils._pytree import tree_flatten, tree_map from torch.utils.data import DataLoader, default_collate from torchvision import tv_tensors diff --git a/torchvision/transforms/v2/_utils.py b/torchvision/transforms/v2/_utils.py index 765a772fe41..3fc33ce5964 100644 --- a/torchvision/transforms/v2/_utils.py +++ b/torchvision/transforms/v2/_utils.py @@ -182,7 +182,7 @@ def query_chw(flat_inputs: list[Any]) -> tuple[int, int, int]: chws = { tuple(get_dimensions(inpt)) for inpt in flat_inputs - if check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video)) + if check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video, is_cvcuda_tensor)) } if not chws: raise TypeError("No image or video was found in the sample") diff --git a/torchvision/transforms/v2/functional/_meta.py b/torchvision/transforms/v2/functional/_meta.py index 6b8f19f12f4..ee562cb2aee 100644 --- a/torchvision/transforms/v2/functional/_meta.py +++ b/torchvision/transforms/v2/functional/_meta.py @@ -51,6 +51,16 @@ def get_dimensions_video(video: torch.Tensor) -> list[int]: return get_dimensions_image(video) +def _get_dimensions_cvcuda(image: "cvcuda.Tensor") -> list[int]: + # CV-CUDA tensor is always in NHWC layout + # get_dimensions is CHW + return [image.shape[3], image.shape[1], image.shape[2]] + + +if CVCUDA_AVAILABLE: + _register_kernel_internal(get_dimensions, cvcuda.Tensor)(_get_dimensions_cvcuda) + + def get_num_channels(inpt: torch.Tensor) -> int: if torch.jit.is_scripting(): return get_num_channels_image(inpt) @@ -87,6 +97,16 @@ def get_num_channels_video(video: torch.Tensor) -> int: get_image_num_channels = get_num_channels +def _get_num_channels_cvcuda(image: "cvcuda.Tensor") -> int: + # CV-CUDA tensor is always in NHWC layout + # get_num_channels is C + return image.shape[3] + + +if CVCUDA_AVAILABLE: + _register_kernel_internal(get_num_channels, cvcuda.Tensor)(_get_num_channels_cvcuda) + + def get_size(inpt: torch.Tensor) -> list[int]: if torch.jit.is_scripting(): return get_size_image(inpt) @@ -114,7 +134,7 @@ def _get_size_image_pil(image: PIL.Image.Image) -> list[int]: return [height, width] -def get_size_image_cvcuda(image: "cvcuda.Tensor") -> list[int]: +def _get_size_cvcuda(image: "cvcuda.Tensor") -> list[int]: """Get size of `cvcuda.Tensor` with NHWC layout.""" hw = list(image.shape[-3:-1]) ndims = len(hw) @@ -125,7 +145,7 @@ def get_size_image_cvcuda(image: "cvcuda.Tensor") -> list[int]: if CVCUDA_AVAILABLE: - _get_size_image_cvcuda = _register_kernel_internal(get_size, cvcuda.Tensor)(get_size_image_cvcuda) + _register_kernel_internal(get_size, cvcuda.Tensor)(_get_size_cvcuda) @_register_kernel_internal(get_size, tv_tensors.Video, tv_tensor_wrapper=False) From 98d7dfb2059eaf2c10c3f549ea45f1d27875134c Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Mon, 1 Dec 2025 18:25:09 -0800 Subject: [PATCH 04/14] update is_cvcuda_tensor --- torchvision/transforms/v2/functional/_utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torchvision/transforms/v2/functional/_utils.py b/torchvision/transforms/v2/functional/_utils.py index 73fafaf7425..44b2edeaf2d 100644 --- a/torchvision/transforms/v2/functional/_utils.py +++ b/torchvision/transforms/v2/functional/_utils.py @@ -181,7 +181,8 @@ def is_cvcuda_tensor(inpt: Any) -> bool: Returns: True if the input is a CV-CUDA tensor, False otherwise. """ - if _is_cvcuda_available(): + try: cvcuda = _import_cvcuda() return isinstance(inpt, cvcuda.Tensor) - return False + except ImportError: + return False From ddc116d13febdae1d53507bcde9f103a4c14eba7 Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Tue, 2 Dec 2025 12:37:03 -0800 Subject: [PATCH 05/14] add cvcuda to pil compatible to transforms by default --- test/test_transforms_v2.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 7eba65550da..87166477669 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -25,6 +25,7 @@ assert_equal, cache, cpu_and_cuda, + cvcuda_to_pil_compatible_tensor, freeze_rng_state, ignore_jit_no_profile_information_warning, make_bounding_boxes, From e51dc7eabd254261347245f4492892fd0944aae5 Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Tue, 2 Dec 2025 12:46:23 -0800 Subject: [PATCH 06/14] remove cvcuda from transform class --- torchvision/transforms/v2/_transform.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/transforms/v2/_transform.py b/torchvision/transforms/v2/_transform.py index bec9ffcf714..ac84fcb6c82 100644 --- a/torchvision/transforms/v2/_transform.py +++ b/torchvision/transforms/v2/_transform.py @@ -11,7 +11,7 @@ from torchvision.transforms.v2._utils import check_type, has_any, is_pure_tensor from torchvision.utils import _log_api_usage_once -from .functional._utils import _get_kernel, is_cvcuda_tensor +from .functional._utils import _get_kernel class Transform(nn.Module): @@ -23,7 +23,7 @@ class Transform(nn.Module): # Class attribute defining transformed types. Other types are passed-through without any transformation # We support both Types and callables that are able to do further checks on the type of the input. - _transformed_types: tuple[type | Callable[[Any], bool], ...] = (torch.Tensor, PIL.Image.Image, is_cvcuda_tensor) + _transformed_types: tuple[type | Callable[[Any], bool], ...] = (torch.Tensor, PIL.Image.Image) def __init__(self) -> None: super().__init__() From 4939355a2c7421eeba95d7f155fe7953066aec6d Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Thu, 4 Dec 2025 11:07:08 -0800 Subject: [PATCH 07/14] resolve more formatting naming --- torchvision/transforms/v2/functional/__init__.py | 2 +- torchvision/transforms/v2/functional/_meta.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/torchvision/transforms/v2/functional/__init__.py b/torchvision/transforms/v2/functional/__init__.py index 52181e4624b..032a993b1f0 100644 --- a/torchvision/transforms/v2/functional/__init__.py +++ b/torchvision/transforms/v2/functional/__init__.py @@ -1,6 +1,6 @@ from torchvision.transforms import InterpolationMode # usort: skip -from ._utils import is_pure_tensor, register_kernel, is_cvcuda_tensor # usort: skip +from ._utils import is_pure_tensor, register_kernel # usort: skip from ._meta import ( clamp_bounding_boxes, diff --git a/torchvision/transforms/v2/functional/_meta.py b/torchvision/transforms/v2/functional/_meta.py index e8630f788ca..af03ad018d4 100644 --- a/torchvision/transforms/v2/functional/_meta.py +++ b/torchvision/transforms/v2/functional/_meta.py @@ -51,14 +51,14 @@ def get_dimensions_video(video: torch.Tensor) -> list[int]: return get_dimensions_image(video) -def _get_dimensions_cvcuda(image: "cvcuda.Tensor") -> list[int]: +def get_dimensions_image_cvcuda(image: "cvcuda.Tensor") -> list[int]: # CV-CUDA tensor is always in NHWC layout # get_dimensions is CHW return [image.shape[3], image.shape[1], image.shape[2]] if CVCUDA_AVAILABLE: - _register_kernel_internal(get_dimensions, cvcuda.Tensor)(_get_dimensions_cvcuda) + _register_kernel_internal(get_dimensions, cvcuda.Tensor)(get_dimensions_image_cvcuda) def get_num_channels(inpt: torch.Tensor) -> int: @@ -97,14 +97,14 @@ def get_num_channels_video(video: torch.Tensor) -> int: get_image_num_channels = get_num_channels -def _get_num_channels_cvcuda(image: "cvcuda.Tensor") -> int: +def get_num_channels_image_cvcuda(image: "cvcuda.Tensor") -> int: # CV-CUDA tensor is always in NHWC layout # get_num_channels is C return image.shape[3] if CVCUDA_AVAILABLE: - _register_kernel_internal(get_num_channels, cvcuda.Tensor)(_get_num_channels_cvcuda) + _register_kernel_internal(get_num_channels, cvcuda.Tensor)(get_num_channels_image_cvcuda) def get_size(inpt: torch.Tensor) -> list[int]: From fbea584365311ae6b56be7e4f6bbff1f834dd31a Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Thu, 4 Dec 2025 11:15:49 -0800 Subject: [PATCH 08/14] update is cvcuda tensor impl --- torchvision/transforms/v2/_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torchvision/transforms/v2/_utils.py b/torchvision/transforms/v2/_utils.py index 3fc33ce5964..e803aa49c60 100644 --- a/torchvision/transforms/v2/_utils.py +++ b/torchvision/transforms/v2/_utils.py @@ -15,8 +15,8 @@ from torchvision._utils import sequence_to_str from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size # noqa: F401 -from torchvision.transforms.v2.functional import get_dimensions, get_size, is_cvcuda_tensor, is_pure_tensor -from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT +from torchvision.transforms.v2.functional import get_dimensions, get_size, is_pure_tensor +from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT, _is_cvcuda_tensor def _setup_number_or_seq(arg: int | float | Sequence[int | float], name: str) -> Sequence[float]: @@ -182,7 +182,7 @@ def query_chw(flat_inputs: list[Any]) -> tuple[int, int, int]: chws = { tuple(get_dimensions(inpt)) for inpt in flat_inputs - if check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video, is_cvcuda_tensor)) + if check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video, _is_cvcuda_tensor)) } if not chws: raise TypeError("No image or video was found in the sample") @@ -207,7 +207,7 @@ def query_size(flat_inputs: list[Any]) -> tuple[int, int]: tv_tensors.Mask, tv_tensors.BoundingBoxes, tv_tensors.KeyPoints, - is_cvcuda_tensor, + _is_cvcuda_tensor, ), ) } From 9eb3cdf1425deb9aa2cddc59768bb1a04be00e6c Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Tue, 2 Dec 2025 18:02:35 -0800 Subject: [PATCH 09/14] first pass at erase --- test/test_transforms_v2.py | 92 +++++++++++++++---- torchvision/transforms/v2/_augment.py | 7 ++ .../transforms/v2/functional/_augment.py | 48 ++++++++++ 3 files changed, 129 insertions(+), 18 deletions(-) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index f767e211125..965beb616f5 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -3779,17 +3779,17 @@ def test_kernel_image(self, dtype, device): @pytest.mark.parametrize("dtype", [torch.float32, torch.uint8]) @pytest.mark.parametrize("device", cpu_and_cuda()) def test_kernel_image_inplace(self, dtype, device): - input = make_image(self.INPUT_SIZE, dtype=dtype, device=device) - input_version = input._version + inpt = make_image(self.INPUT_SIZE, dtype=dtype, device=device) + input_version = inpt._version - output_out_of_place = F.erase_image(input, **self.FUNCTIONAL_KWARGS) - assert output_out_of_place.data_ptr() != input.data_ptr() - assert output_out_of_place is not input + output_out_of_place = F.erase_image(inpt, **self.FUNCTIONAL_KWARGS) + assert output_out_of_place.data_ptr() != inpt.data_ptr() + assert output_out_of_place is not inpt - output_inplace = F.erase_image(input, **self.FUNCTIONAL_KWARGS, inplace=True) - assert output_inplace.data_ptr() == input.data_ptr() + output_inplace = F.erase_image(inpt, **self.FUNCTIONAL_KWARGS, inplace=True) + assert output_inplace.data_ptr() == inpt.data_ptr() assert output_inplace._version > input_version - assert output_inplace is input + assert output_inplace is inpt assert_equal(output_inplace, output_out_of_place) @@ -3798,7 +3798,15 @@ def test_kernel_video(self): @pytest.mark.parametrize( "make_input", - [make_image_tensor, make_image_pil, make_image, make_video], + [ + make_image_tensor, + make_image_pil, + make_image, + make_video, + pytest.param( + make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available") + ), + ], ) def test_functional(self, make_input): check_functional(F.erase, make_input(), **self.FUNCTIONAL_KWARGS) @@ -3810,25 +3818,48 @@ def test_functional(self, make_input): (F._augment._erase_image_pil, PIL.Image.Image), (F.erase_image, tv_tensors.Image), (F.erase_video, tv_tensors.Video), + pytest.param( + F._augment._erase_cvcuda, + "cvcuda.Tensor", + marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available"), + ), ], ) def test_functional_signature(self, kernel, input_type): + if input_type == "cvcuda.Tensor": + input_type = _import_cvcuda().Tensor check_functional_kernel_signature_match(F.erase, kernel=kernel, input_type=input_type) @pytest.mark.parametrize( "make_input", - [make_image_tensor, make_image_pil, make_image, make_video], + [ + make_image_tensor, + make_image_pil, + make_image, + make_video, + pytest.param( + make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available") + ), + ], ) @pytest.mark.parametrize("device", cpu_and_cuda()) def test_transform(self, make_input, device): - input = make_input(device=device) + inpt = make_input(device=device) - with pytest.warns(UserWarning, match="currently passing through inputs of type"): + # shouldn't get a warning for cvcuda + if make_input is make_image_cvcuda: check_transform( transforms.RandomErasing(p=1), - input, - check_v1_compatibility=not isinstance(input, PIL.Image.Image), + inpt, + check_v1_compatibility=False, ) + else: + with pytest.warns(UserWarning, match="currently passing through inputs of type"): + check_transform( + transforms.RandomErasing(p=1), + inpt, + check_v1_compatibility=not isinstance(inpt, PIL.Image.Image), + ) def _reference_erase_image(self, image, *, i, j, h, w, v): mask = torch.zeros_like(image, dtype=torch.bool) @@ -3843,16 +3874,38 @@ def _reference_erase_image(self, image, *, i, j, h, w, v): return erased_image + @pytest.mark.parametrize( + "make_input", + [ + make_image, + pytest.param( + make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available") + ), + ], + ) @pytest.mark.parametrize("dtype", [torch.float32, torch.uint8]) @pytest.mark.parametrize("device", cpu_and_cuda()) - def test_functional_image_correctness(self, dtype, device): - image = make_image(dtype=dtype, device=device) + def test_functional_image_correctness(self, make_input, dtype, device): + image = make_input(dtype=dtype, device=device) actual = F.erase(image, **self.FUNCTIONAL_KWARGS) + + if make_input is make_image_cvcuda: + image = cvcuda_to_pil_compatible_tensor(image) + expected = self._reference_erase_image(image, **self.FUNCTIONAL_KWARGS) assert_equal(actual, expected) + @pytest.mark.parametrize( + "make_input", + [ + make_image, + pytest.param( + make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available") + ), + ], + ) @param_value_parametrization( scale=[(0.1, 0.2), [0.0, 1.0]], ratio=[(0.3, 0.7), [0.1, 5.0]], @@ -3861,10 +3914,10 @@ def test_functional_image_correctness(self, dtype, device): @pytest.mark.parametrize("dtype", [torch.float32, torch.uint8]) @pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("seed", list(range(5))) - def test_transform_image_correctness(self, param, value, dtype, device, seed): + def test_transform_image_correctness(self, make_input, param, value, dtype, device, seed): transform = transforms.RandomErasing(**{param: value}, p=1) - image = make_image(dtype=dtype, device=device) + image = make_input(dtype=dtype, device=device) with freeze_rng_state(): torch.manual_seed(seed) @@ -3875,6 +3928,9 @@ def test_transform_image_correctness(self, param, value, dtype, device, seed): torch.manual_seed(seed) actual = transform(image) + if make_input is make_image_cvcuda: + image = cvcuda_to_pil_compatible_tensor(image) + expected = self._reference_erase_image(image, **params) assert_equal(actual, expected) diff --git a/torchvision/transforms/v2/_augment.py b/torchvision/transforms/v2/_augment.py index c6da9aba98b..cfa4e3fd2e1 100644 --- a/torchvision/transforms/v2/_augment.py +++ b/torchvision/transforms/v2/_augment.py @@ -10,11 +10,15 @@ from torch.utils._pytree import tree_flatten, tree_unflatten from torchvision import transforms as _transforms, tv_tensors from torchvision.transforms.v2 import functional as F +from torchvision.transforms.v2.functional._utils import _is_cvcuda_available, is_cvcuda_tensor from ._transform import _RandomApplyTransform, Transform from ._utils import _check_sequence_input, _parse_labels_getter, has_any, is_pure_tensor, query_chw, query_size +CVCUDA_AVAILABLE = _is_cvcuda_available() + + class RandomErasing(_RandomApplyTransform): """Randomly select a rectangle region in the input image or video and erase its pixels. @@ -48,6 +52,9 @@ class RandomErasing(_RandomApplyTransform): _v1_transform_cls = _transforms.RandomErasing + if CVCUDA_AVAILABLE: + _transformed_types = _RandomApplyTransform._transformed_types + (is_cvcuda_tensor,) + def _extract_params_for_v1_transform(self) -> dict[str, Any]: return dict( super()._extract_params_for_v1_transform(), diff --git a/torchvision/transforms/v2/functional/_augment.py b/torchvision/transforms/v2/functional/_augment.py index 7ce5bdc7b7e..512017eebde 100644 --- a/torchvision/transforms/v2/functional/_augment.py +++ b/torchvision/transforms/v2/functional/_augment.py @@ -1,6 +1,8 @@ import io from typing import TYPE_CHECKING +import numpy as np + import PIL.Image import torch @@ -67,6 +69,52 @@ def erase_video( return erase_image(video, i=i, j=j, h=h, w=w, v=v, inplace=inplace) +def _erase_cvcuda( + image: "cvcuda.Tensor", + i: int, + j: int, + h: int, + w: int, + v: torch.Tensor, + inplace: bool = False, +) -> "cvcuda.Tensor": + if inplace: + raise ValueError("inplace is not supported for cvcuda.Tensor") + + anchor = torch.tensor(np.array([j, i]), dtype=torch.int32, device="cuda") + cv_anchor = cvcuda.as_tensor(anchor, "NC").reshape((2,), "N") + erasing = torch.tensor(np.array([w, h, 7]), dtype=torch.int32, device="cuda") + cv_erasing = cvcuda.as_tensor(erasing, "NC").reshape((3,), "N") + imgIdx = torch.tensor(np.array([0]), dtype=torch.int32, device="cuda") + cv_imgIdx = cvcuda.as_tensor(imgIdx, "N").reshape((1,), "N") + + num_channels = image.shape[3] + # Flatten v and expand to match the number of channels if it's a single value + # CV-CUDA erase expects values as float32 + # Use repeat instead of expand to create a new tensor (avoids cvcuda state pollution) + v_dup = v.clone() + v_flat = v_dup.flatten().to(dtype=torch.float32, device="cuda") + if v_flat.numel() == 1: + v_flat = v_flat.repeat(num_channels) + cv_values = cvcuda.as_tensor(v_flat, "NC").reshape((num_channels,), "N") + + result = cvcuda.erase( + src=image, + anchor=cv_anchor, + erasing=cv_erasing, + values=cv_values, + imgIdx=cv_imgIdx, + random=False, + seed=0, + ) + + return result + + +if CVCUDA_AVAILABLE: + _register_kernel_internal(erase, _import_cvcuda().Tensor)(_erase_cvcuda) + + def jpeg(image: torch.Tensor, quality: int) -> torch.Tensor: """See :class:`~torchvision.transforms.v2.JPEG` for details.""" if torch.jit.is_scripting(): From 3468d72bcee8ec7188a26d205bef48d260cfb93c Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Tue, 2 Dec 2025 18:06:35 -0800 Subject: [PATCH 10/14] update comments --- torchvision/transforms/v2/functional/_augment.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchvision/transforms/v2/functional/_augment.py b/torchvision/transforms/v2/functional/_augment.py index 512017eebde..6f70db83a65 100644 --- a/torchvision/transforms/v2/functional/_augment.py +++ b/torchvision/transforms/v2/functional/_augment.py @@ -91,7 +91,6 @@ def _erase_cvcuda( num_channels = image.shape[3] # Flatten v and expand to match the number of channels if it's a single value # CV-CUDA erase expects values as float32 - # Use repeat instead of expand to create a new tensor (avoids cvcuda state pollution) v_dup = v.clone() v_flat = v_dup.flatten().to(dtype=torch.float32, device="cuda") if v_flat.numel() == 1: From 6eb7d4031189ff74aff9c3055fa4a2a62cc47a85 Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Fri, 5 Dec 2025 10:56:06 -0800 Subject: [PATCH 11/14] begin updating erase --- test/common_utils.py | 4 ++-- test/test_transforms_v2.py | 3 +-- torchvision/transforms/v2/functional/_augment.py | 8 ++++---- 3 files changed, 7 insertions(+), 8 deletions(-) diff --git a/test/common_utils.py b/test/common_utils.py index e3fa464b5ea..54587cd6fc8 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -20,8 +20,8 @@ from torch.testing._comparison import BooleanPair, NonePair, not_close_error_metas, NumberPair, TensorLikePair from torchvision import io, tv_tensors from torchvision.transforms._functional_tensor import _max_value as get_max_value -from torchvision.transforms.v2.functional import cvcuda_to_tensor, to_cvcuda_tensor, to_image, to_pil_image -from torchvision.transforms.v2.functional._utils import _is_cvcuda_available, _is_cvcuda_tensor +from torchvision.transforms.v2.functional import to_cvcuda_tensor, to_image, to_pil_image +from torchvision.transforms.v2.functional._utils import _import_cvcuda, _is_cvcuda_available from torchvision.utils import _Image_fromarray diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 965beb616f5..92174113dd8 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -25,7 +25,6 @@ assert_equal, cache, cpu_and_cuda, - cvcuda_to_pil_compatible_tensor, freeze_rng_state, ignore_jit_no_profile_information_warning, make_bounding_boxes, @@ -3891,7 +3890,7 @@ def test_functional_image_correctness(self, make_input, dtype, device): actual = F.erase(image, **self.FUNCTIONAL_KWARGS) if make_input is make_image_cvcuda: - image = cvcuda_to_pil_compatible_tensor(image) + image = F.cvcuda_to_tensor(image)[0].cpu() expected = self._reference_erase_image(image, **self.FUNCTIONAL_KWARGS) diff --git a/torchvision/transforms/v2/functional/_augment.py b/torchvision/transforms/v2/functional/_augment.py index 6f70db83a65..c4adebfcf3e 100644 --- a/torchvision/transforms/v2/functional/_augment.py +++ b/torchvision/transforms/v2/functional/_augment.py @@ -18,8 +18,6 @@ if TYPE_CHECKING: import cvcuda # type: ignore[import-not-found] -if CVCUDA_AVAILABLE: - cvcuda = _import_cvcuda() # noqa: F811 def erase( @@ -69,7 +67,7 @@ def erase_video( return erase_image(video, i=i, j=j, h=h, w=w, v=v, inplace=inplace) -def _erase_cvcuda( +def _erase_image_cvcuda( image: "cvcuda.Tensor", i: int, j: int, @@ -78,6 +76,8 @@ def _erase_cvcuda( v: torch.Tensor, inplace: bool = False, ) -> "cvcuda.Tensor": + cvcuda = _import_cvcuda() + if inplace: raise ValueError("inplace is not supported for cvcuda.Tensor") @@ -111,7 +111,7 @@ def _erase_cvcuda( if CVCUDA_AVAILABLE: - _register_kernel_internal(erase, _import_cvcuda().Tensor)(_erase_cvcuda) + _register_kernel_internal(erase, _import_cvcuda().Tensor)(_erase_image_cvcuda) def jpeg(image: torch.Tensor, quality: int) -> torch.Tensor: From 3a2a9fc4d22674a91ea11cd8bd9b7a455c9e17d9 Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Fri, 5 Dec 2025 11:55:18 -0800 Subject: [PATCH 12/14] erase complete and verified for everything but random fill, but uses memcpy hack, need to figure out --- test/common_utils.py | 4 +- test/test_transforms_v2.py | 14 ++-- torchvision/transforms/v2/_augment.py | 4 +- .../transforms/v2/functional/_augment.py | 81 +++++++++++++++---- 4 files changed, 78 insertions(+), 25 deletions(-) diff --git a/test/common_utils.py b/test/common_utils.py index 54587cd6fc8..e3fa464b5ea 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -20,8 +20,8 @@ from torch.testing._comparison import BooleanPair, NonePair, not_close_error_metas, NumberPair, TensorLikePair from torchvision import io, tv_tensors from torchvision.transforms._functional_tensor import _max_value as get_max_value -from torchvision.transforms.v2.functional import to_cvcuda_tensor, to_image, to_pil_image -from torchvision.transforms.v2.functional._utils import _import_cvcuda, _is_cvcuda_available +from torchvision.transforms.v2.functional import cvcuda_to_tensor, to_cvcuda_tensor, to_image, to_pil_image +from torchvision.transforms.v2.functional._utils import _is_cvcuda_available, _is_cvcuda_tensor from torchvision.utils import _Image_fromarray diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 92174113dd8..ae956de73c9 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -3818,14 +3818,14 @@ def test_functional(self, make_input): (F.erase_image, tv_tensors.Image), (F.erase_video, tv_tensors.Video), pytest.param( - F._augment._erase_cvcuda, - "cvcuda.Tensor", + F._augment._erase_image_cvcuda, + None, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available"), ), ], ) def test_functional_signature(self, kernel, input_type): - if input_type == "cvcuda.Tensor": + if kernel is F._augment._erase_image_cvcuda: input_type = _import_cvcuda().Tensor check_functional_kernel_signature_match(F.erase, kernel=kernel, input_type=input_type) @@ -3928,11 +3928,15 @@ def test_transform_image_correctness(self, make_input, param, value, dtype, devi actual = transform(image) if make_input is make_image_cvcuda: - image = cvcuda_to_pil_compatible_tensor(image) + image = F.cvcuda_to_tensor(image)[0].cpu() expected = self._reference_erase_image(image, **params) - assert_equal(actual, expected) + if make_input is make_image_cvcuda and value == "random": + # CV-CUDA doesnt not support random per-pixel fill types + assert_close(actual, expected, rtol=0, atol=255) + else: + assert_equal(actual, expected) def test_transform_errors(self): with pytest.raises(TypeError, match="Argument value should be either a number or str or a sequence"): diff --git a/torchvision/transforms/v2/_augment.py b/torchvision/transforms/v2/_augment.py index cfa4e3fd2e1..ccb61e57069 100644 --- a/torchvision/transforms/v2/_augment.py +++ b/torchvision/transforms/v2/_augment.py @@ -10,7 +10,7 @@ from torch.utils._pytree import tree_flatten, tree_unflatten from torchvision import transforms as _transforms, tv_tensors from torchvision.transforms.v2 import functional as F -from torchvision.transforms.v2.functional._utils import _is_cvcuda_available, is_cvcuda_tensor +from torchvision.transforms.v2.functional._utils import _is_cvcuda_available, _is_cvcuda_tensor from ._transform import _RandomApplyTransform, Transform from ._utils import _check_sequence_input, _parse_labels_getter, has_any, is_pure_tensor, query_chw, query_size @@ -53,7 +53,7 @@ class RandomErasing(_RandomApplyTransform): _v1_transform_cls = _transforms.RandomErasing if CVCUDA_AVAILABLE: - _transformed_types = _RandomApplyTransform._transformed_types + (is_cvcuda_tensor,) + _transformed_types = _RandomApplyTransform._transformed_types + (_is_cvcuda_tensor,) def _extract_params_for_v1_transform(self) -> dict[str, Any]: return dict( diff --git a/torchvision/transforms/v2/functional/_augment.py b/torchvision/transforms/v2/functional/_augment.py index c4adebfcf3e..df2d0415df4 100644 --- a/torchvision/transforms/v2/functional/_augment.py +++ b/torchvision/transforms/v2/functional/_augment.py @@ -1,8 +1,7 @@ +import ctypes import io from typing import TYPE_CHECKING -import numpy as np - import PIL.Image import torch @@ -81,21 +80,71 @@ def _erase_image_cvcuda( if inplace: raise ValueError("inplace is not supported for cvcuda.Tensor") - anchor = torch.tensor(np.array([j, i]), dtype=torch.int32, device="cuda") - cv_anchor = cvcuda.as_tensor(anchor, "NC").reshape((2,), "N") - erasing = torch.tensor(np.array([w, h, 7]), dtype=torch.int32, device="cuda") - cv_erasing = cvcuda.as_tensor(erasing, "NC").reshape((3,), "N") - imgIdx = torch.tensor(np.array([0]), dtype=torch.int32, device="cuda") - cv_imgIdx = cvcuda.as_tensor(imgIdx, "N").reshape((1,), "N") - - num_channels = image.shape[3] - # Flatten v and expand to match the number of channels if it's a single value - # CV-CUDA erase expects values as float32 - v_dup = v.clone() - v_flat = v_dup.flatten().to(dtype=torch.float32, device="cuda") + # Load CUDA runtime for memory copy + try: + cudart = ctypes.CDLL("libcudart.so") + except OSError: + cudart = ctypes.CDLL("libcudart.so.12") + cudart.cudaMemcpy.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, ctypes.c_int] + cudart.cudaMemcpy.restype = ctypes.c_int + CUDA_MEMCPY_D2D = 3 # cudaMemcpyDeviceToDevice + + num_erasing_areas = 1 + num_channels = image.shape[3] # NHWC layout + + # Create CV-CUDA tensors with proper compound types + cv_anchor = cvcuda.Tensor((num_erasing_areas,), cvcuda.Type._2S32, "N") + cv_erasing = cvcuda.Tensor((num_erasing_areas,), cvcuda.Type._3S32, "N") + cv_imgIdx = cvcuda.Tensor((num_erasing_areas,), cvcuda.Type.S32, "N") + # Values tensor - 4 floats per erasing area (CV-CUDA standard, supports up to RGBA) + cv_values = cvcuda.Tensor((num_erasing_areas * 4,), cvcuda.Type.F32, "N") + + # Create source torch tensors with the data + anchor_src = torch.tensor([j, i], dtype=torch.int32, device="cuda") + # The third value is a bitmask for which channels to fill: 7 = 0b111 = RGB all channels + channel_mask = (1 << num_channels) - 1 # e.g., 3 channels -> 0b111 = 7 + erasing_src = torch.tensor([w, h, channel_mask], dtype=torch.int32, device="cuda") + imgIdx_src = torch.tensor([0], dtype=torch.int32, device="cuda") + + # Get fill values for erasing - need 4 floats per erasing area (CV-CUDA format) + v_flat = v.flatten().to(dtype=torch.float32, device="cuda") + # Expand to 4 values (CV-CUDA always expects 4) if v_flat.numel() == 1: - v_flat = v_flat.repeat(num_channels) - cv_values = cvcuda.as_tensor(v_flat, "NC").reshape((num_channels,), "N") + # Single value - replicate to all 4 slots + values_src = v_flat.expand(4).contiguous() + elif v_flat.numel() >= 4: + # Has enough values, take first 4 + values_src = v_flat[:4].contiguous() + else: + # Has fewer than 4 values, pad with zeros + padding = torch.zeros(4 - v_flat.numel(), dtype=torch.float32, device="cuda") + values_src = torch.cat([v_flat, padding]) + + # Copy data from torch tensors to CV-CUDA tensors using cudaMemcpy + cudart.cudaMemcpy( + cv_anchor.cuda().__cuda_array_interface__["data"][0], + anchor_src.data_ptr(), + 8, + CUDA_MEMCPY_D2D, # 2 x int32 = 8 bytes + ) + cudart.cudaMemcpy( + cv_erasing.cuda().__cuda_array_interface__["data"][0], + erasing_src.data_ptr(), + 12, + CUDA_MEMCPY_D2D, # 3 x int32 = 12 bytes + ) + cudart.cudaMemcpy( + cv_imgIdx.cuda().__cuda_array_interface__["data"][0], + imgIdx_src.data_ptr(), + 4, + CUDA_MEMCPY_D2D, # 1 x int32 = 4 bytes + ) + cudart.cudaMemcpy( + cv_values.cuda().__cuda_array_interface__["data"][0], + values_src.data_ptr(), + 16, + CUDA_MEMCPY_D2D, # 4 x float32 = 16 bytes + ) result = cvcuda.erase( src=image, From 7a268116475a1607f4da7baa04d2ef081157303b Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Fri, 5 Dec 2025 13:10:09 -0800 Subject: [PATCH 13/14] cleanup cvcuda erase op considerably --- test/test_transforms_v2.py | 6 +- .../transforms/v2/functional/_augment.py | 130 +++++++++--------- 2 files changed, 66 insertions(+), 70 deletions(-) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index ae956de73c9..546879c3e86 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -3933,8 +3933,10 @@ def test_transform_image_correctness(self, make_input, param, value, dtype, devi expected = self._reference_erase_image(image, **params) if make_input is make_image_cvcuda and value == "random": - # CV-CUDA doesnt not support random per-pixel fill types - assert_close(actual, expected, rtol=0, atol=255) + # CV-CUDA doesnt have same random distribution as torchvision + # it uses its own seeding, but we have determinism + # set seed with torch.randint in the kernel + assert_close(actual, expected, rtol=0, atol=256) else: assert_equal(actual, expected) diff --git a/torchvision/transforms/v2/functional/_augment.py b/torchvision/transforms/v2/functional/_augment.py index df2d0415df4..5d51c243fb4 100644 --- a/torchvision/transforms/v2/functional/_augment.py +++ b/torchvision/transforms/v2/functional/_augment.py @@ -1,5 +1,5 @@ -import ctypes import io +from types import SimpleNamespace from typing import TYPE_CHECKING import PIL.Image @@ -80,84 +80,78 @@ def _erase_image_cvcuda( if inplace: raise ValueError("inplace is not supported for cvcuda.Tensor") - # Load CUDA runtime for memory copy - try: - cudart = ctypes.CDLL("libcudart.so") - except OSError: - cudart = ctypes.CDLL("libcudart.so.12") - cudart.cudaMemcpy.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, ctypes.c_int] - cudart.cudaMemcpy.restype = ctypes.c_int - CUDA_MEMCPY_D2D = 3 # cudaMemcpyDeviceToDevice - - num_erasing_areas = 1 - num_channels = image.shape[3] # NHWC layout - - # Create CV-CUDA tensors with proper compound types - cv_anchor = cvcuda.Tensor((num_erasing_areas,), cvcuda.Type._2S32, "N") - cv_erasing = cvcuda.Tensor((num_erasing_areas,), cvcuda.Type._3S32, "N") - cv_imgIdx = cvcuda.Tensor((num_erasing_areas,), cvcuda.Type.S32, "N") - # Values tensor - 4 floats per erasing area (CV-CUDA standard, supports up to RGBA) - cv_values = cvcuda.Tensor((num_erasing_areas * 4,), cvcuda.Type.F32, "N") - - # Create source torch tensors with the data - anchor_src = torch.tensor([j, i], dtype=torch.int32, device="cuda") - # The third value is a bitmask for which channels to fill: 7 = 0b111 = RGB all channels - channel_mask = (1 << num_channels) - 1 # e.g., 3 channels -> 0b111 = 7 - erasing_src = torch.tensor([w, h, channel_mask], dtype=torch.int32, device="cuda") - imgIdx_src = torch.tensor([0], dtype=torch.int32, device="cuda") - - # Get fill values for erasing - need 4 floats per erasing area (CV-CUDA format) - v_flat = v.flatten().to(dtype=torch.float32, device="cuda") - # Expand to 4 values (CV-CUDA always expects 4) - if v_flat.numel() == 1: - # Single value - replicate to all 4 slots - values_src = v_flat.expand(4).contiguous() - elif v_flat.numel() >= 4: - # Has enough values, take first 4 - values_src = v_flat[:4].contiguous() + # the v tensor is random if it has spatial dimensions > 1x1 + is_random_fill = v.shape[-2:] != (1, 1) + + # allocate any space for standard torch tensors + mask = (1 << image.shape[3]) - 1 + src_anchor = torch.tensor([[j, i]], dtype=torch.int32, device="cuda") + src_erasing = torch.tensor([[w, h, mask]], dtype=torch.int32, device="cuda") + src_idx = torch.tensor([0], dtype=torch.int32, device="cuda") + + # allocate the fill values based on if random or not + # use zeros for random fill since we have to pass the tensor to the kernel anyway + if is_random_fill: + src_vals = torch.zeros(4, device="cuda", dtype=torch.float32) + # CV-CUDA requires that the fill values is a flat size 4 tensor + # so we need to flatten the fill values and pad with zeros if needed else: - # Has fewer than 4 values, pad with zeros - padding = torch.zeros(4 - v_flat.numel(), dtype=torch.float32, device="cuda") - values_src = torch.cat([v_flat, padding]) - - # Copy data from torch tensors to CV-CUDA tensors using cudaMemcpy - cudart.cudaMemcpy( - cv_anchor.cuda().__cuda_array_interface__["data"][0], - anchor_src.data_ptr(), - 8, - CUDA_MEMCPY_D2D, # 2 x int32 = 8 bytes + v_flat = v.flatten().to(dtype=torch.float32, device="cuda") + if v_flat.numel() == 1: + src_vals = v_flat.expand(4).contiguous() + else: + if v_flat.numel() >= 4: + src_vals = v_flat[:4] + else: + pad_len = 4 - v_flat.numel() + src_vals = torch.cat([v_flat, torch.zeros(pad_len, device="cuda", dtype=torch.float32)]) + src_vals = src_vals.contiguous() + + # the simple tensors can be read directly by CV-CUDA + cv_imgIdx = cvcuda.as_tensor( + src_idx.reshape( + 1, + ), + "N", ) - cudart.cudaMemcpy( - cv_erasing.cuda().__cuda_array_interface__["data"][0], - erasing_src.data_ptr(), - 12, - CUDA_MEMCPY_D2D, # 3 x int32 = 12 bytes - ) - cudart.cudaMemcpy( - cv_imgIdx.cuda().__cuda_array_interface__["data"][0], - imgIdx_src.data_ptr(), - 4, - CUDA_MEMCPY_D2D, # 1 x int32 = 4 bytes - ) - cudart.cudaMemcpy( - cv_values.cuda().__cuda_array_interface__["data"][0], - values_src.data_ptr(), - 16, - CUDA_MEMCPY_D2D, # 4 x float32 = 16 bytes + cv_values = cvcuda.as_tensor( + src_vals.reshape( + 1 * 4, + ), + "N", ) - result = cvcuda.erase( + # packed types (_2S32, _3S32) need to be copied into pre-allocated tensors + # torch does not support these packed types directly, so we create a helper function + # which will enable torch copy into the data directly (by overriding type/strides info) + def _to_torch(cv_tensor: cvcuda.Tensor, shape: tuple[int, ...], typestr: str) -> torch.Tensor: + iface = cv_tensor.cuda().__cuda_array_interface__ + iface.update(shape=shape, typestr=typestr, strides=None) + return torch.as_tensor(SimpleNamespace(__cuda_array_interface__=iface), device="cuda") + + # allocate the data for packed types + cv_anchor = cvcuda.Tensor((1,), cvcuda.Type._2S32, "N") + cv_erasing = cvcuda.Tensor((1,), cvcuda.Type._3S32, "N") + + # do a memcpy with torch, pretending data is scalar type contiguous + _to_torch(cv_anchor, (1, 2), " Date: Fri, 5 Dec 2025 13:11:01 -0800 Subject: [PATCH 14/14] remove cvcuda refs that arent used --- torchvision/transforms/v2/functional/_color.py | 12 +----------- torchvision/transforms/v2/functional/_misc.py | 11 ++--------- 2 files changed, 3 insertions(+), 20 deletions(-) diff --git a/torchvision/transforms/v2/functional/_color.py b/torchvision/transforms/v2/functional/_color.py index 5be9c62902a..be254c0d63a 100644 --- a/torchvision/transforms/v2/functional/_color.py +++ b/torchvision/transforms/v2/functional/_color.py @@ -1,5 +1,3 @@ -from typing import TYPE_CHECKING - import PIL.Image import torch from torch.nn.functional import conv2d @@ -11,15 +9,7 @@ from ._misc import _num_value_bits, to_dtype_image from ._type_conversion import pil_to_tensor, to_pil_image -from ._utils import _get_kernel, _import_cvcuda, _is_cvcuda_available, _register_kernel_internal - - -CVCUDA_AVAILABLE = _is_cvcuda_available() - -if TYPE_CHECKING: - import cvcuda # type: ignore[import-not-found] -if CVCUDA_AVAILABLE: - cvcuda = _import_cvcuda() # noqa: F811 +from ._utils import _get_kernel, _register_kernel_internal def rgb_to_grayscale(inpt: torch.Tensor, num_output_channels: int = 1) -> torch.Tensor: diff --git a/torchvision/transforms/v2/functional/_misc.py b/torchvision/transforms/v2/functional/_misc.py index 0fa05a2113c..daf263df046 100644 --- a/torchvision/transforms/v2/functional/_misc.py +++ b/torchvision/transforms/v2/functional/_misc.py @@ -1,5 +1,5 @@ import math -from typing import Optional, TYPE_CHECKING +from typing import Optional import PIL.Image import torch @@ -13,14 +13,7 @@ from ._meta import _convert_bounding_box_format -from ._utils import _get_kernel, _import_cvcuda, _is_cvcuda_available, _register_kernel_internal, is_pure_tensor - -CVCUDA_AVAILABLE = _is_cvcuda_available() - -if TYPE_CHECKING: - import cvcuda # type: ignore[import-not-found] -if CVCUDA_AVAILABLE: - cvcuda = _import_cvcuda() # noqa: F811 +from ._utils import _get_kernel, _register_kernel_internal, is_pure_tensor def normalize(