From 76a9041bf6b2b81cf7bb0ac5c384b5ace68db697 Mon Sep 17 00:00:00 2001
From: Fabien Racape
Date: Fri, 2 Feb 2024 12:48:53 -0800
Subject: [PATCH 01/27] fix rebase with master

---
 compressai/layers/layers.py | 63 ++++++++++++++++++++++++++++++++++++-
 1 file changed, 62 insertions(+), 1 deletion(-)

diff --git a/compressai/layers/layers.py b/compressai/layers/layers.py
index 73fcbce1..e802fec1 100644
--- a/compressai/layers/layers.py
+++ b/compressai/layers/layers.py
@@ -29,7 +29,7 @@
 import math
 
-from typing import Any
+from typing import Any, Tuple
 
 import torch
 import torch.nn as nn
@@ -47,6 +47,8 @@
     "ResidualBlockUpsample",
     "ResidualBlockWithStride",
     "conv1x1",
+    "SpectralConv2d",
+    "SpectralConvTranspose2d",
     "conv3x3",
     "subpel_conv3x3",
     "QReLU",
@@ -54,6 +56,65 @@
 ]
 
 
+class _SpectralConvNdMixin:
+    def __init__(self, dim: Tuple[int, ...]):
+        self.dim = dim
+        self.weight_transformed = nn.Parameter(self._to_transform_domain(self.weight))
+        del self._parameters["weight"]  # Unregister weight, and fallback to property.
+
+    @property
+    def weight(self) -> Tensor:
+        return self._from_transform_domain(self.weight_transformed)
+
+    def _to_transform_domain(self, x: Tensor) -> Tensor:
+        return torch.fft.rfftn(x, s=self.kernel_size, dim=self.dim, norm="ortho")
+
+    def _from_transform_domain(self, x: Tensor) -> Tensor:
+        return torch.fft.irfftn(x, s=self.kernel_size, dim=self.dim, norm="ortho")
+
+
+class SpectralConv2d(nn.Conv2d, _SpectralConvNdMixin):
+    r"""Spectral 2D convolution.
+
+    Introduced in [Balle2018efficient].
+    Reparameterizes the weights to be derived from weights stored in the
+    frequency domain.
+    In the original paper, this is referred to as "spectral Adam" or
+    "Sadam" due to its effect on the Adam optimizer update rule.
+    The motivation for representing the weights in the frequency
+    domain is that optimizer updates then affect all frequencies
+    equally.
+    This improves the gradient conditioning, leading to faster
+    convergence and increased stability at larger learning rates.
+
+    For comparison, see the TensorFlow Compression implementations of
+    `SignalConv2D
+    `_
+    and
+    `RDFTParameter
+    `_.
+
+    [Balle2018efficient]: `"Efficient Nonlinear Transforms for Lossy
+    Image Compression" `_,
+    by Johannes Ballé, PCS 2018.
+    """
+
+    def __init__(self, *args: Any, **kwargs: Any):
+        super().__init__(*args, **kwargs)
+        _SpectralConvNdMixin.__init__(self, dim=(-2, -1))
+
+
+class SpectralConvTranspose2d(nn.ConvTranspose2d, _SpectralConvNdMixin):
+    r"""Spectral 2D transposed convolution.
+
+    Transposed version of :class:`SpectralConv2d`.
+    """
+
+    def __init__(self, *args: Any, **kwargs: Any):
+        super().__init__(*args, **kwargs)
+        _SpectralConvNdMixin.__init__(self, dim=(-2, -1))
+
+
 class MaskedConv2d(nn.Conv2d):
     r"""Masked 2D convolution implementation, mask future "unseen" pixels.
     Useful for building auto-regressive network components.

From fdb884cf6ccc69cfd77d091021f0ff6a8c4a3bcd Mon Sep 17 00:00:00 2001
From: Mateen Ulhaq
Date: Thu, 1 Feb 2024 00:17:17 -0800
Subject: [PATCH 02/27] docs: add missing autodocs

---
 docs/source/datasets.rst | 20 +++++++++++++++++++-
 docs/source/index.rst | 1 +
 docs/source/losses.rst | 16 ++++++++++++++++
 docs/source/models.rst | 2 +-
 4 files changed, 37 insertions(+), 2 deletions(-)
 create mode 100644 docs/source/losses.rst

diff --git a/docs/source/datasets.rst b/docs/source/datasets.rst
index 45a23a1b..05a1535d 100644
--- a/docs/source/datasets.rst
+++ b/docs/source/datasets.rst
@@ -4,13 +4,31 @@ compressai.datasets
 .. 
currentmodule:: compressai.datasets + +Image/video datasets +~~~~~~~~~~~~~~~~~~~~ + + ImageFolder ----------- .. autoclass:: ImageFolder :members: +PreGeneratedMemmapDataset +------------------------- +.. autoclass:: PreGeneratedMemmapDataset + :members: + + VideoFolder ----------- .. autoclass:: VideoFolder - :members: \ No newline at end of file + :members: + + +Vimeo90kDataset +--------------- +.. autoclass:: Vimeo90kDataset + :members: + diff --git a/docs/source/index.rst b/docs/source/index.rst index fe6f240b..e283c511 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -29,6 +29,7 @@ end-to-end compression research. entropy_models latent_codecs layers + losses models ops transforms diff --git a/docs/source/losses.rst b/docs/source/losses.rst new file mode 100644 index 00000000..12e61aaa --- /dev/null +++ b/docs/source/losses.rst @@ -0,0 +1,16 @@ +compressai.losses +================= + +.. currentmodule:: compressai.losses + + + +Image/video losses +~~~~~~~~~~~~~~~~~~ + + +RateDistortionLoss +------------------ +.. autoclass:: RateDistortionLoss + :members: + diff --git a/docs/source/models.rst b/docs/source/models.rst index cd64f1b4..d509f96b 100644 --- a/docs/source/models.rst +++ b/docs/source/models.rst @@ -63,6 +63,6 @@ Elic2022Chandelier .. currentmodule:: compressai.models.video ScaleSpaceFlow ------------------- +-------------- .. autoclass:: ScaleSpaceFlow From 07a7f0ac8d9c783c291645c4a9e63e2333dcd855 Mon Sep 17 00:00:00 2001 From: Mateen Ulhaq Date: Wed, 31 Jan 2024 20:10:00 -0800 Subject: [PATCH 03/27] feat: compressai.typing.TTransform --- compressai/typing/__init__.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/compressai/typing/__init__.py b/compressai/typing/__init__.py index 1c162a63..4ad93626 100644 --- a/compressai/typing/__init__.py +++ b/compressai/typing/__init__.py @@ -27,6 +27,8 @@ # OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +from typing import Callable + from .torch import ( TCriterion, TDataLoader, @@ -45,4 +47,7 @@ "TModule", "TOptimizer", "TScheduler", + "TTransform", ] + +TTransform = Callable From af6396bb69a3f1fbe69ed90f4cec4c9ac9cd6076 Mon Sep 17 00:00:00 2001 From: Mateen Ulhaq Date: Wed, 31 Jan 2024 20:34:19 -0800 Subject: [PATCH 04/27] refactor: compressai.registry.transforms --- compressai/registry/__init__.py | 3 +- compressai/registry/torchvision.py | 11 +++--- compressai/registry/transforms.py | 55 ++++++++++++++++++++++++++++++ 3 files changed, 62 insertions(+), 7 deletions(-) create mode 100644 compressai/registry/transforms.py diff --git a/compressai/registry/__init__.py b/compressai/registry/__init__.py index a9aa8b21..d15d940d 100644 --- a/compressai/registry/__init__.py +++ b/compressai/registry/__init__.py @@ -41,7 +41,7 @@ register_optimizer, register_scheduler, ) -from .torchvision import TRANSFORMS +from .transforms import TRANSFORMS, register_transform __all__ = [ "CRITERIONS", @@ -57,4 +57,5 @@ "register_module", "register_optimizer", "register_scheduler", + "register_transform", ] diff --git a/compressai/registry/torchvision.py b/compressai/registry/torchvision.py index c3c57d70..f1b564ec 100644 --- a/compressai/registry/torchvision.py +++ b/compressai/registry/torchvision.py @@ -27,10 +27,9 @@ # OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-from typing import Callable, Dict +from .transforms import TRANSFORMS, register_transform -from torchvision import transforms - -TRANSFORMS: Dict[str, Callable[..., Callable]] = { - k: v for k, v in transforms.__dict__.items() if k[0].isupper() -} +__all__ = [ + "TRANSFORMS", + "register_transform", +] diff --git a/compressai/registry/transforms.py b/compressai/registry/transforms.py new file mode 100644 index 00000000..8ebbb8b4 --- /dev/null +++ b/compressai/registry/transforms.py @@ -0,0 +1,55 @@ +# Copyright (c) 2021-2022, InterDigital Communications, Inc +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted (subject to the limitations in the disclaimer +# below) provided that the following conditions are met: + +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * Neither the name of InterDigital Communications, Inc nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. + +# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY +# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND +# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT +# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR +# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF +# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from typing import Callable, Dict, Type, TypeVar + +import torchvision.transforms + +from compressai.typing import TTransform + +__all__ = [ + "TRANSFORMS", + "register_transform", +] + +TRANSFORMS: Dict[str, Callable[..., TTransform]] = { + **{k: v for k, v in torchvision.transforms.__dict__.items() if k[0].isupper()}, +} + +TTransform_b = TypeVar("TTransform_b", bound=TTransform) + + +def register_transform(name: str): + """Decorator for registering a transform.""" + + def decorator(cls: Type[TTransform_b]) -> Type[TTransform_b]: + TRANSFORMS[name] = cls + return cls + + return decorator From 5b2b8d0a84a9be9bfd8805e293c4d36de4785c65 Mon Sep 17 00:00:00 2001 From: Mateen Ulhaq Date: Wed, 26 Apr 2023 21:59:58 -0700 Subject: [PATCH 05/27] feat: fast EntropyBottleneck aux_loss minimization via bisection search [Currently disabled by default.] This method completes in <1 second and reduces aux_loss to <0.01. This makes the aux_loss optimization during training unnecessary. Another alternative would be to run the following post-training: ```python while aux_loss > 0.1: aux_loss = model.aux_loss() aux_loss.backward() aux_optimizer.step() aux_optimizer.zero_grad() ``` ...but since we do not manage aux_loss learning rates, the bisection search method might converge better. 
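
For reference, a usage sketch (hedged: `net` is assumed to be some
trained CompressionModel; the signatures are the ones added in this
patch):

```python
# One-time call after training. update() rebuilds the entropy coder
# CDF tables; update_quantiles=True first refines each
# EntropyBottleneck's quantiles via the bisection search added here.
net.update(force=True, update_quantiles=True)
```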
---
 compressai/entropy_models/entropy_models.py | 6 +++++-
 compressai/models/base.py | 5 +++--
 2 files changed, 8 insertions(+), 3 deletions(-)

diff --git a/compressai/entropy_models/entropy_models.py b/compressai/entropy_models/entropy_models.py
index 1dd46609..5d87305e 100644
--- a/compressai/entropy_models/entropy_models.py
+++ b/compressai/entropy_models/entropy_models.py
@@ -386,7 +386,7 @@ def _get_medians(self) -> Tensor:
         medians = self.quantiles[:, :, 1:2]
         return medians
 
-    def update(self, force: bool = False, update_quantiles: bool = True) -> bool:
+    def update(self, force: bool = False, update_quantiles: bool = False) -> bool:
         # Check if we need to update the bottleneck parameters, the offsets are
         # only computed and stored when the conditional model is update()'d.
         if self._offset.numel() > 0 and not force:
@@ -526,6 +526,10 @@ def _extend_ndims(tensor, n):
 
     @torch.no_grad()
     def _update_quantiles(self, search_radius=1e5, rtol=1e-4, atol=1e-3):
+        """Fast quantile update via bisection search.
+
+        Often faster and much more precise than minimizing aux loss.
+        """
         device = self.quantiles.device
         shape = (self.channels, 1, 1)
         low = torch.full(shape, -search_radius, device=device)
diff --git a/compressai/models/base.py b/compressai/models/base.py
index f1175d16..e7ff0b53 100644
--- a/compressai/models/base.py
+++ b/compressai/models/base.py
@@ -114,7 +114,7 @@ def load_state_dict(self, state_dict, strict=True):
 
         return nn.Module.load_state_dict(self, state_dict, strict=strict)
 
-    def update(self, scale_table=None, force=False):
+    def update(self, scale_table=None, force=False, update_quantiles: bool = False):
         """Updates EntropyBottleneck and GaussianConditional CDFs.
 
         Needs to be called once after training to be able to later perform the
@@ -125,6 +125,7 @@
                 for initializing the Gaussian distributions
                 (default: 64 logarithmically spaced scales from 0.11 to 256)
             force (bool): overwrite previous values (default: False)
+            update_quantiles (bool): update quantiles via fast bisection search (default: False)
 
         Returns:
             updated (bool): True if at least one of the modules was updated.
@@ -134,7 +135,7 @@
         updated = False
         for _, module in self.named_modules():
             if isinstance(module, EntropyBottleneck):
-                updated |= module.update(force=force)
+                updated |= module.update(force=force, update_quantiles=update_quantiles)
             if isinstance(module, GaussianConditional):
                 updated |= module.update_scale_table(scale_table, force=force)
         return updated

From 246200d7aa744988b4754544d6d778bc0d0983eb Mon Sep 17 00:00:00 2001
From: Mateen Ulhaq
Date: Wed, 31 Jan 2024 00:03:51 -0800
Subject: [PATCH 06/27] feat: net optimizer

---
 compressai/optimizers/__init__.py | 2 ++
 compressai/optimizers/net.py | 59 +++++++++++++++++++++++++++++++
 2 files changed, 61 insertions(+)
 create mode 100644 compressai/optimizers/net.py

diff --git a/compressai/optimizers/__init__.py b/compressai/optimizers/__init__.py
index 7028cffc..aa749d02 100644
--- a/compressai/optimizers/__init__.py
+++ b/compressai/optimizers/__init__.py
@@ -27,8 +27,10 @@
 # OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
 # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+from .net import net_optimizer from .net_aux import net_aux_optimizer __all__ = [ + "net_optimizer", "net_aux_optimizer", ] diff --git a/compressai/optimizers/net.py b/compressai/optimizers/net.py new file mode 100644 index 00000000..c583b89d --- /dev/null +++ b/compressai/optimizers/net.py @@ -0,0 +1,59 @@ +# Copyright (c) 2021-2022, InterDigital Communications, Inc +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted (subject to the limitations in the disclaimer +# below) provided that the following conditions are met: + +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * Neither the name of InterDigital Communications, Inc nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. + +# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY +# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND +# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT +# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR +# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF +# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from __future__ import annotations + +from typing import Any, Dict, Mapping, cast + +import torch.nn as nn +import torch.optim as optim + +from compressai.registry import OPTIMIZERS, register_optimizer + + +@register_optimizer("net") +def net_optimizer( + net: nn.Module, conf: Mapping[str, Any] +) -> Dict[str, optim.Optimizer]: + """Returns optimizer for net loss.""" + parameters = { + "net": {name for name, param in net.named_parameters() if param.requires_grad}, + } + + params_dict = dict(net.named_parameters()) + + def make_optimizer(key): + kwargs = dict(conf[key]) + del kwargs["type"] + params = (params_dict[name] for name in sorted(parameters[key])) + return OPTIMIZERS[conf[key]["type"]](params, **kwargs) + + optimizer = {key: make_optimizer(key) for key in ["net"]} + + return cast(Dict[str, optim.Optimizer], optimizer) From 4d4192ba3f689a30917d54db3ea0b414f962358c Mon Sep 17 00:00:00 2001 From: Mateen Ulhaq Date: Tue, 30 Jan 2024 22:53:57 -0800 Subject: [PATCH 07/27] feat: basic layers (Lambda, Reshape, Transpose, Interleave, etc.) 
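
These are small declarative building blocks intended for use inside
nn.Sequential-style models. A usage sketch (hedged: the module names
are the ones added in this patch; shapes are arbitrary and chosen
only for illustration):

```python
import torch
import torch.nn as nn

from compressai.layers import Gain, Interleave, Lambda, Reshape, Transpose

m = nn.Sequential(
    Lambda(lambda x: 2 * x),    # Apply an arbitrary function.
    Reshape((8, 16)),           # (N, 128) -> (N, 8, 16)
    Transpose(1, 2),            # (N, 8, 16) -> (N, 16, 8)
    Interleave(groups=4),       # Interleave the 16 channels in 4 groups.
    Gain((16, 1), factor=1.0),  # Learnable per-channel gain.
)

y = m(torch.rand(2, 128))
assert y.shape == (2, 16, 8)
```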
--- compressai/layers/__init__.py | 1 + compressai/layers/basic.py | 119 ++++++++++++++++++++++++++++++++++ 2 files changed, 120 insertions(+) create mode 100644 compressai/layers/basic.py diff --git a/compressai/layers/__init__.py b/compressai/layers/__init__.py index 3464f4f4..080131e9 100644 --- a/compressai/layers/__init__.py +++ b/compressai/layers/__init__.py @@ -27,5 +27,6 @@ # OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +from .basic import * from .gdn import * from .layers import * diff --git a/compressai/layers/basic.py b/compressai/layers/basic.py new file mode 100644 index 00000000..4d7cd355 --- /dev/null +++ b/compressai/layers/basic.py @@ -0,0 +1,119 @@ +# Copyright (c) 2021-2022, InterDigital Communications, Inc +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted (subject to the limitations in the disclaimer +# below) provided that the following conditions are met: + +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * Neither the name of InterDigital Communications, Inc nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. + +# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY +# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND +# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT +# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR +# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF +# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +from __future__ import annotations + +import torch +import torch.nn as nn + +from torch import Tensor + +__all__ = [ + "Lambda", + "NamedLayer", + "Reshape", + "Transpose", + "Interleave", + "Gain", +] + + +class Lambda(nn.Module): + def __init__(self, func): + super().__init__() + self.func = func + + def __repr__(self): + return f"{self.__class__.__name__}(func={self.func})" + + def forward(self, x): + return self.func(x) + + +class NamedLayer(nn.Module): + def __init__(self, name: str): + super().__init__() + self.name = name + + def __repr__(self): + return f"{self.__class__.__name__}(name={self.name})" + + def forward(self, x): + return x + + +class Reshape(nn.Module): + def __init__(self, shape): + super().__init__() + self.shape = shape + + def __repr__(self): + return f"{self.__class__.__name__}(shape={self.shape})" + + def forward(self, x): + output_shape = (x.shape[0], *self.shape) + try: + return x.reshape(output_shape) + except RuntimeError as e: + e.args += (f"Cannot reshape input {tuple(x.shape)} to {output_shape}",) + raise e + + +class Transpose(nn.Module): + def __init__(self, dim0, dim1): + super().__init__() + self.dim0 = dim0 + self.dim1 = dim1 + + def __repr__(self): + return f"{self.__class__.__name__}(dim0={self.dim0}, dim1={self.dim1})" + + def forward(self, x): + return x.transpose(self.dim0, self.dim1).contiguous() + + +class Interleave(nn.Module): + def __init__(self, groups: int): + super().__init__() + self.groups = groups + + def forward(self, x: Tensor) -> Tensor: + g = self.groups + n, c, *tail = x.shape + return x.reshape(n, g, c // g, *tail).transpose(1, 2).reshape(x.shape) + + +class Gain(nn.Module): + def __init__(self, shape=None, factor: float = 1.0): + super().__init__() + self.factor = factor + self.gain = nn.Parameter(torch.ones(shape)) + + def forward(self, x: Tensor) -> Tensor: + return self.factor * self.gain * x From 5e69c5b3650c0a1639cfd9e2698b8b98e2feb247 Mon Sep 17 00:00:00 2001 From: Mateen Ulhaq Date: Wed, 31 Jan 2024 18:24:53 -0800 Subject: [PATCH 08/27] chore(deps): einops, pandas, torch-geometric, tqdm --- setup.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/setup.py b/setup.py index 9a2aa22e..042369b2 100644 --- a/setup.py +++ b/setup.py @@ -128,12 +128,16 @@ def get_extra_requirements(): zip_safe=False, python_requires=">=3.6", install_requires=[ + "einops", "numpy", + "pandas", "scipy", "matplotlib", "torch>=1.7.1", + "torch-geometric>=2.3.0", "torchvision", "pytorch-msssim", + "tqdm", ], extras_require=get_extra_requirements(), license="BSD 3-Clause Clear License", From 6b0d638953a3ddc6b306269dc25254fff5fd814b Mon Sep 17 00:00:00 2001 From: Mateen Ulhaq Date: Wed, 31 Jan 2024 20:00:17 -0800 Subject: [PATCH 09/27] chore(deps): pointops, pyntcloud [pointcloud] --- setup.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/setup.py b/setup.py index 042369b2..902f9169 100644 --- a/setup.py +++ b/setup.py @@ -104,6 +104,10 @@ def get_extensions(): "isort", "mypy", ] +POINTCLOUD_REQUIRES = [ + "pointops @ git+https://github.com/YodaEmbedding/pointops.git", + "pyntcloud", +] def get_extra_requirements(): @@ -112,6 +116,7 @@ def get_extra_requirements(): "dev": DEV_REQUIRES, "doc": ["sphinx", "sphinx-book-theme", "Jinja2<3.1"], "tutorials": ["jupyter", "ipywidgets"], + "pointcloud": POINTCLOUD_REQUIRES, } extras_require["all"] = {req for reqs in extras_require.values() for req in reqs} return extras_require From 09ed45c8cfa5f70b21ed3038442f16bc4cf9d057 Mon Sep 17 00:00:00 2001 From: Mateen Ulhaq Date: Thu, 1 Feb 2024 
21:14:23 -0800 Subject: [PATCH 10/27] chore(deps): pyntcloud use PR with *.off header fix Improves robustness of header parsing. In particular, ModelNet40 has faulty headers: ```bash $ head -n 1 ModelNet40/chair/train/chair_0856.off OFF6586 5534 0 ``` For reference, the correct format is: ``` OFF 6586 5534 0 ``` --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 902f9169..c97958d5 100644 --- a/setup.py +++ b/setup.py @@ -106,7 +106,7 @@ def get_extensions(): ] POINTCLOUD_REQUIRES = [ "pointops @ git+https://github.com/YodaEmbedding/pointops.git", - "pyntcloud", + "pyntcloud @ git+https://github.com/YodaEmbedding/pyntcloud.git@12ee9f2208f4207844be80ac5fdbafaf9f0652fa", ] From 7f1f81c05eec72d0d28efec43bf6f02caa294255 Mon Sep 17 00:00:00 2001 From: Mateen Ulhaq Date: Wed, 31 Jan 2024 20:34:42 -0800 Subject: [PATCH 11/27] feat: compressai.registry.transforms torch_geometric --- compressai/registry/transforms.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/compressai/registry/transforms.py b/compressai/registry/transforms.py index 8ebbb8b4..9515d6c2 100644 --- a/compressai/registry/transforms.py +++ b/compressai/registry/transforms.py @@ -29,6 +29,7 @@ from typing import Callable, Dict, Type, TypeVar +import torch_geometric.transforms import torchvision.transforms from compressai.typing import TTransform @@ -40,6 +41,7 @@ TRANSFORMS: Dict[str, Callable[..., TTransform]] = { **{k: v for k, v in torchvision.transforms.__dict__.items() if k[0].isupper()}, + **{k: v for k, v in torch_geometric.transforms.__dict__.items() if k[0].isupper()}, } TTransform_b = TypeVar("TTransform_b", bound=TTransform) From f9d3708bd6732dee574d2a25be748a6471cc9598 Mon Sep 17 00:00:00 2001 From: Mateen Ulhaq Date: Tue, 30 Jan 2024 23:18:34 -0800 Subject: [PATCH 12/27] feat: point cloud datasets (ModelNet, ShapeNet, S3DIS, SemanticKITTI) --- compressai/datasets/__init__.py | 3 + compressai/datasets/cache.py | 126 ++++++++ compressai/datasets/ndarray.py | 65 +++++ compressai/datasets/pointcloud/__init__.py | 40 +++ compressai/datasets/pointcloud/modelnet.py | 175 +++++++++++ compressai/datasets/pointcloud/s3dis.py | 200 +++++++++++++ .../datasets/pointcloud/semantic_kitti.py | 264 +++++++++++++++++ compressai/datasets/pointcloud/shapenet.py | 274 ++++++++++++++++++ compressai/datasets/stack.py | 91 ++++++ compressai/datasets/utils.py | 74 +++++ docs/source/datasets.rst | 29 ++ 11 files changed, 1341 insertions(+) create mode 100644 compressai/datasets/cache.py create mode 100644 compressai/datasets/ndarray.py create mode 100644 compressai/datasets/pointcloud/__init__.py create mode 100644 compressai/datasets/pointcloud/modelnet.py create mode 100644 compressai/datasets/pointcloud/s3dis.py create mode 100644 compressai/datasets/pointcloud/semantic_kitti.py create mode 100644 compressai/datasets/pointcloud/shapenet.py create mode 100644 compressai/datasets/stack.py create mode 100644 compressai/datasets/utils.py diff --git a/compressai/datasets/__init__.py b/compressai/datasets/__init__.py index d22b8c76..c13a9cc5 100644 --- a/compressai/datasets/__init__.py +++ b/compressai/datasets/__init__.py @@ -27,13 +27,16 @@ # OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +from . 
import pointcloud from .image import ImageFolder +from .pointcloud import * from .pregenerated import PreGeneratedMemmapDataset from .rawvideo import * from .video import VideoFolder from .vimeo90k import Vimeo90kDataset __all__ = [ + *pointcloud.__all__, "ImageFolder", "PreGeneratedMemmapDataset", "VideoFolder", diff --git a/compressai/datasets/cache.py b/compressai/datasets/cache.py new file mode 100644 index 00000000..4cb3c18c --- /dev/null +++ b/compressai/datasets/cache.py @@ -0,0 +1,126 @@ +# Copyright (c) 2021-2022, InterDigital Communications, Inc +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted (subject to the limitations in the disclaimer +# below) provided that the following conditions are met: + +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * Neither the name of InterDigital Communications, Inc nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. + +# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY +# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND +# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT +# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR +# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF +# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +import json +import os +import os.path + +from pathlib import Path + +import numpy as np + +from torch.utils.data import Dataset +from tqdm import tqdm + + +class CacheDataset(Dataset): + def __init__( + self, + cache_root=None, + pre_transform=None, + transform=None, + ): + self.__cache_root = Path(cache_root) + self.pre_transform = pre_transform + self.transform = transform + self._store = {} + + def __len__(self): + return len(self._store[next(iter(self._store))]) + + def __getitem__(self, index): + data = {k: v[index].copy() for k, v in self._store.items()} + if self.transform is not None: + data = self.transform(data) + return data + + def _ensure_cache(self): + try: + self._load_cache(mode="r") + except FileNotFoundError: + self._generate_cache() + self._load_cache(mode="r") + + def _load_cache(self, mode): + with open(self.__cache_root / "info.json", "r") as f: + info = json.load(f) + + self._store = { + k: np.memmap( + self.__cache_root / f"{k}.npy", + mode=mode, + dtype=settings["dtype"], + shape=tuple(settings["shape"]), + ) + for k, settings in info.items() + } + + def _generate_cache(self, verbose=True): + if verbose: + print(f"Generating cache at {self.__cache_root}...") + + items = self._get_items() + + if verbose: + items = tqdm(items) + + for i, item in enumerate(items): + data = self._load_item(item) + + if self.pre_transform is not None: + data = self.pre_transform(data) + + if not self._store: + self._write_cache_info(len(items), data) + self._load_cache(mode="w+") + + for k, v in data.items(): + self._store[k][i] = v + + def _write_cache_info(self, num_samples, data): + info = { + k: { + "dtype": _removeprefix(str(v.dtype), "torch."), + "shape": (num_samples, *v.shape), + } + for k, v in data.items() + } + os.makedirs(self.__cache_root, exist_ok=True) + with open(self.__cache_root / "info.json", "w") as f: + json.dump(info, f, indent=2) + + def _get_items(self): + raise NotImplementedError + + def _load_item(self, item): + raise NotImplementedError + + +def _removeprefix(s: str, prefix: str) -> str: + return s[len(prefix) :] if s.startswith(prefix) else s diff --git a/compressai/datasets/ndarray.py b/compressai/datasets/ndarray.py new file mode 100644 index 00000000..8309045a --- /dev/null +++ b/compressai/datasets/ndarray.py @@ -0,0 +1,65 @@ +# Copyright (c) 2021-2022, InterDigital Communications, Inc +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted (subject to the limitations in the disclaimer +# below) provided that the following conditions are met: + +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * Neither the name of InterDigital Communications, Inc nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. + +# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY +# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND +# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT +# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +# PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR +# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF +# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +# Adapted via https://github.com/pytorch/pytorch/blob/v2.1.0/torch/utils/data/dataset.py +# BSD-style license: https://github.com/pytorch/pytorch/blob/v2.1.0/LICENSE + +from typing import Tuple, Union + +import numpy as np + +from torch.utils.data import Dataset + + +class NdArrayDataset(Dataset[Union[np.ndarray, Tuple[np.ndarray, ...]]]): + r"""Dataset wrapping arrays. + + Each sample will be retrieved by indexing arrays along the first dimension. + + Args: + *arrays (np.ndarray): arrays that have the same size of the first dimension. + """ + + arrays: Tuple[np.ndarray, ...] + + def __init__(self, *arrays: np.ndarray, single: bool = False) -> None: + assert all( + arrays[0].shape[0] == array.shape[0] for array in arrays + ), "Size mismatch between arrays" + self.arrays = arrays + self.single = single + + def __getitem__(self, index): + if self.single: + [array] = self.arrays + return array[index] + return tuple(array[index] for array in self.arrays) + + def __len__(self): + return self.arrays[0].shape[0] diff --git a/compressai/datasets/pointcloud/__init__.py b/compressai/datasets/pointcloud/__init__.py new file mode 100644 index 00000000..809b7d64 --- /dev/null +++ b/compressai/datasets/pointcloud/__init__.py @@ -0,0 +1,40 @@ +# Copyright (c) 2021-2022, InterDigital Communications, Inc +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted (subject to the limitations in the disclaimer +# below) provided that the following conditions are met: + +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * Neither the name of InterDigital Communications, Inc nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. + +# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY +# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND +# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT +# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR +# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF +# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +from .modelnet import ModelNetDataset +from .s3dis import S3disDataset +from .semantic_kitti import SemanticKittiDataset +from .shapenet import ShapeNetCorePartDataset + +__all__ = [ + "ModelNetDataset", + "S3disDataset", + "SemanticKittiDataset", + "ShapeNetCorePartDataset", +] diff --git a/compressai/datasets/pointcloud/modelnet.py b/compressai/datasets/pointcloud/modelnet.py new file mode 100644 index 00000000..65554e63 --- /dev/null +++ b/compressai/datasets/pointcloud/modelnet.py @@ -0,0 +1,175 @@ +# Copyright (c) 2021-2022, InterDigital Communications, Inc +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted (subject to the limitations in the disclaimer +# below) provided that the following conditions are met: + +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * Neither the name of InterDigital Communications, Inc nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. + +# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY +# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND +# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT +# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR +# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF +# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import os +import os.path +import re +import shutil + +from pathlib import Path + +import numpy as np + +from pyntcloud import PyntCloud + +from compressai.datasets.cache import CacheDataset +from compressai.datasets.utils import download_url, hash_file +from compressai.registry import register_dataset + + +@register_dataset("ModelNetDataset") +class ModelNetDataset(CacheDataset): + """ModelNet dataset. + + This dataset of 3D CAD models of objects was introduced by + [Wu2015]_, consisting of 10 or 40 classes, with 4899 and 12311 + aligned items, respectively. + Each 3D model is represented in the OFF file format by a triangle + mesh (i.e. faces) and has a single label (e.g. airplane). + To convert the triangle meshes to point clouds, one may use a mesh + sampling method (e.g. ``SamplePoints``). + + See also: [PapersWithCode_ModelNet]_. + + References: + + .. [Wu2015] `"3D ShapeNets: A deep representation for volumetric + shapes," `_, by Zhirong Wu, + Shuran Song, Aditya Khosla, Fisher Yu, Linguang Zhang, + Xiaoou Tang, and Jianxiong Xiao, CVPR 2015. + + .. 
[PapersWithCode_ModelNet] `PapersWithCode: ModelNet
+        `_
+    """
+
+    # fmt: off
+    LABEL_LIST = {
+        "10": [
+            "bathtub", "bed", "chair", "desk", "dresser",
+            "monitor", "night_stand", "sofa", "table", "toilet",
+        ],
+        "40": [
+            "airplane", "bathtub", "bed", "bench", "bookshelf",
+            "bottle", "bowl", "car", "chair", "cone", "cup",
+            "curtain", "desk", "door", "dresser", "flower_pot",
+            "glass_box", "guitar", "keyboard", "lamp", "laptop",
+            "mantel", "monitor", "night_stand", "person", "piano",
+            "plant", "radio", "range_hood", "sink", "sofa",
+            "stairs", "stool", "table", "tent", "toilet",
+            "tv_stand", "vase", "wardrobe", "xbox",
+        ],
+    }
+    # fmt: on
+
+    LABEL_STR_TO_LABEL_INDEX = {
+        "10": {label: idx for idx, label in enumerate(LABEL_LIST["10"])},
+        "40": {label: idx for idx, label in enumerate(LABEL_LIST["40"])},
+    }
+
+    URLS = {
+        "10": "http://3dvision.princeton.edu/projects/2014/3DShapeNets/ModelNet10.zip",
+        "40": "http://modelnet.cs.princeton.edu/ModelNet40.zip",
+    }
+
+    HASHES = {
+        "10": "9d8679435fc07d1d26f13009878db164a7aa8ea5e7ea3c8880e42794b7307d51",
+        "40": "42dc3e656932e387f554e25a4eb2cc0e1a1bd3ab54606e2a9eae444c60e536ac",
+    }
+
+    def __init__(
+        self,
+        root=None,
+        cache_root=None,
+        split="train",
+        split_name=None,
+        name="40",
+        pre_transform=None,
+        transform=None,
+        download=True,
+    ):
+        if cache_root is None:
+            assert root is not None
+            cache_root = f"{str(root).rstrip('/')}_cache"
+
+        self.root = Path(root) if root else None
+        self.cache_root = Path(cache_root)
+        self.split = split
+        self.split_name = split if split_name is None else split_name
+        self.name = name
+
+        if download and self.root:
+            self.download()
+
+        super().__init__(
+            cache_root=self.cache_root / self.split_name,
+            pre_transform=pre_transform,
+            transform=transform,
+        )
+
+        self._ensure_cache()
+
+    def download(self, force=False):
+        if not force and self.root.exists():
+            return
+        tmpdir = self.root.parent / "tmp"
+        os.makedirs(tmpdir, exist_ok=True)
+        filepath = download_url(self.URLS[self.name], tmpdir, overwrite=force)
+        assert self.HASHES[self.name] == hash_file(filepath, method="sha256")
+        shutil.unpack_archive(filepath, tmpdir)
+        shutil.move(tmpdir / f"ModelNet{self.name}", self.root)
+
+    def _get_items(self):
+        return sorted(self.root.glob(f"**/{self.split}/*.off"))
+
+    def _load_item(self, path):
+        label_index, file_index = self._parse_path(path)
+        cloud = PyntCloud.from_file(str(path))
+        return {
+            "file_index": np.array([file_index], dtype=np.int32),
+            "label": np.array([label_index], dtype=np.uint8),
+            "pos": cloud.points.values,
+            "face": cloud.mesh.values.T,
+        }
+
+    def _parse_path(self, path):
+        pattern = (
+            r"^.*?/?"
+            r"(?P<label_str>[a-zA-Z_]+)/"
+            r"(?P<split>[a-zA-Z_]+)/"
+            r"(?P<label_str_again>[a-zA-Z_]+)_(?P<file_index>\d+)\.off$"
+        )
+        match = re.match(pattern, str(path))
+        if match is None:
+            raise ValueError(f"Could not parse path: {path}")
+        assert match.group("split") == self.split
+        assert match.group("label_str") == match.group("label_str_again")
+        label_str = match.group("label_str")
+        label_index = self.LABEL_STR_TO_LABEL_INDEX[self.name][label_str]
+        file_index = int(match.group("file_index"))
+        return label_index, file_index
diff --git a/compressai/datasets/pointcloud/s3dis.py b/compressai/datasets/pointcloud/s3dis.py
new file mode 100644
index 00000000..b370c3c9
--- /dev/null
+++ b/compressai/datasets/pointcloud/s3dis.py
@@ -0,0 +1,200 @@
+# Copyright (c) 2021-2022, InterDigital Communications, Inc
+# All rights reserved.
+
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted (subject to the limitations in the disclaimer
+# below) provided that the following conditions are met:
+
+# * Redistributions of source code must retain the above copyright notice,
+#   this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+# * Neither the name of InterDigital Communications, Inc nor the names of its
+#   contributors may be used to endorse or promote products derived from this
+#   software without specific prior written permission.
+
+# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY
+# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
+# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT
+# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
+# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
+# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
+# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
+# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import os
+import shutil
+
+from pathlib import Path
+
+from torch.utils.data import ConcatDataset
+
+from compressai.datasets.cache import CacheDataset
+from compressai.datasets.ndarray import NdArrayDataset
+from compressai.datasets.stack import StackDataset
+from compressai.datasets.utils import download_url, hash_file
+from compressai.registry import register_dataset
+
+
+@register_dataset("S3disDataset")
+class S3disDataset(CacheDataset):
+    """S3DIS dataset.
+
+    The Stanford 3D Indoor Scene Dataset (S3DIS), introduced by
+    [Armeni2016]_, contains 3D point clouds of 6 large-scale indoor
+    areas. There are multiple rooms (e.g. office, lounge, hallway)
+    per area.
+    See the project page [ProjectPage_S3DIS]_ for a visualization.
+
+    The ``semantic_index`` is a number between 0 and 12 (inclusive),
+    which can be used as the semantic label for each point.
+
+    See also: [PapersWithCode_S3DIS]_.
+
+    References:
+
+        .. [Armeni2016] `"3D Semantic Parsing of Large-Scale Indoor
+            Spaces," `_,
+            by Iro Armeni, Ozan Sener, Amir R. Zamir, Helen Jiang,
+            Ioannis Brilakis, Martin Fischer, and Silvio Savarese,
+            CVPR 2016.
+
+        .. [ProjectPage_S3DIS] `Project page
+            `_
+
+        .. 
[PapersWithCode_S3DIS] `PapersWithCode: S3DIS + `_ + """ + + URLS = [ + "https://shapenet.cs.stanford.edu/media/indoor3d_sem_seg_hdf5_data.zip", + ] + + HASHES = [ + "587bb63b296d542c24910c384c41028f2caa1d749042ae891d0d64968c773185", # indoor3d_sem_seg_hdf5_data.zip + ] + + # Suggested splits: + AREAS = { + "train": (1, 2, 3, 4, 6), + "valid": (5,), + "test": (5,), + } + + NUM_SAMPLES_PER_AREA = [0, 3687, 4440, 1650, 3662, 6852, 3294] + + LABELS = [ + "ceiling", + "floor", + "wall", + "beam", + "column", + "window", + "door", + "table", + "chair", + "sofa", + "bookcase", + "board", + "clutter", + ] + + ROOMS = [ + "auditorium", + "conferenceRoom", + "copyRoom", + "hallway", + "lobby", + "lounge", + "office", + "openspace", + "pantry", + "storage", + "WC", + ] + + def __init__( + self, + root=None, + cache_root=None, + split="train", + split_name=None, + areas=AREAS["train"], + pre_transform=None, + transform=None, + download=True, + ): + if cache_root is None: + assert root is not None + cache_root = f"{str(root).rstrip('/')}_cache" + + self.root = Path(root) if root else None + self.cache_root = Path(cache_root) + self.split = split + self.split_name = split if split_name is None else split_name + self.areas = areas + + if download and self.root: + self.download() + + self._root_dataset = self._get_root_dataset() + + super().__init__( + cache_root=self.cache_root / self.split_name, + pre_transform=pre_transform, + transform=transform, + ) + + self._ensure_cache() + + def download(self, force=False): + if not force and self.root.exists(): + return + tmpdir = self.root.parent / "tmp" + os.makedirs(tmpdir, exist_ok=True) + for expected_hash, url in zip(self.HASHES, self.URLS): + filepath = download_url( + url, tmpdir, check_certificate=False, overwrite=force + ) + shutil.unpack_archive(filepath, tmpdir) + assert expected_hash == hash_file(filepath, method="sha256") + shutil.move(tmpdir / "indoor3d_sem_seg_hdf5_data", self.root) + + def _get_root_dataset(self): + import h5py + + h5_files = [h5py.File(path, "r") for path in sorted(self.root.glob("**/*.h5"))] + keys = ["data", "label"] + + return ConcatDataset( + StackDataset(**{k: NdArrayDataset(h5_file[k], single=True) for k in keys}) + for h5_file in h5_files + ) + + def _get_items(self): + with open(self.root / "room_filelist.txt") as f: + lines = f.read().splitlines() + return [ + (i, line) + for i, line in enumerate(lines) + if int(line.split("_")[1]) in self.areas + ] + + def _load_item(self, item): + index, name = item + _, area_index_str, room_str, *_ = name.split("_") + data = self._root_dataset[index] + + return { + "file_index": index, + "area_index": int(area_index_str), + "room_index": self.ROOMS.index(room_str), + "semantic_index": data["label"], + "pos": data["data"][:, 0:3], # xyz + "color": data["data"][:, 3:6], # rgb + "pos_normalized": data["data"][:, 6:9], # Normalized xyz + } diff --git a/compressai/datasets/pointcloud/semantic_kitti.py b/compressai/datasets/pointcloud/semantic_kitti.py new file mode 100644 index 00000000..3e43a941 --- /dev/null +++ b/compressai/datasets/pointcloud/semantic_kitti.py @@ -0,0 +1,264 @@ +# Copyright (c) 2021-2022, InterDigital Communications, Inc +# All rights reserved. 
+
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted (subject to the limitations in the disclaimer
+# below) provided that the following conditions are met:
+
+# * Redistributions of source code must retain the above copyright notice,
+#   this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+# * Neither the name of InterDigital Communications, Inc nor the names of its
+#   contributors may be used to endorse or promote products derived from this
+#   software without specific prior written permission.
+
+# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY
+# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
+# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT
+# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
+# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
+# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
+# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
+# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import os
+import re
+import shutil
+
+from pathlib import Path
+
+import numpy as np
+
+from compressai.datasets.cache import CacheDataset
+from compressai.datasets.utils import download_url, hash_file
+from compressai.registry import register_dataset
+
+
+@register_dataset("SemanticKittiDataset")
+class SemanticKittiDataset(CacheDataset):
+    """SemanticKITTI dataset.
+
+    The KITTI dataset, introduced by [Geiger2012]_, contains 3D point
+    cloud sequences (i.e. video) of LiDAR sensor data from the
+    perspective of a driving vehicle.
+    The SemanticKITTI dataset, introduced by [Behley2019]_ and
+    [Behley2021]_, provides semantic annotations for all 22 sequences
+    of the KITTI odometry task [Odometry_KITTI]_.
+    See the project page [ProjectPage_SemanticKITTI]_ for a
+    visualization.
+    Note that the test set is unlabelled and must be evaluated on the
+    server, as mentioned at [ProjectPageTasks_SemanticKITTI]_.
+
+    The ``semantic_index`` is a number between 0 and 33 (inclusive),
+    which can be used as the semantic label for each point.
+
+    See also: [PapersWithCode_SemanticKITTI]_.
+
+    References:
+
+        .. [Geiger2012] `"Are we ready for Autonomous Driving? The KITTI
+            Vision Benchmark Suite," `_,
+            by Andreas Geiger, Philip Lenz, and Raquel Urtasun,
+            CVPR 2012.
+
+        .. [Behley2019] `"SemanticKITTI: A Dataset for Semantic Scene
+            Understanding of LiDAR Sequences," `_,
+            by Jens Behley, Martin Garbade, Andres Milioto, Jan Quenzel,
+            Sven Behnke, Cyrill Stachniss, and Juergen Gall, ICCV 2019.
+
+        .. [Behley2021] `"Towards 3D LiDAR-based semantic scene
+            understanding of 3D point cloud sequences: The SemanticKITTI
+            Dataset," `_,
+            by Jens Behley, Martin Garbade, Andres Milioto, Jan Quenzel,
+            Sven Behnke, Jürgen Gall, and Cyrill Stachniss, IJRR 2021.
+
+        .. [ProjectPage_SemanticKITTI] `Project page (SemanticKITTI)
+            `_
+
+        .. 
[ProjectPageTasks_SemanticKITTI] `Project page: Tasks + (SemanticKITTI) + `_ + + .. [Odometry_KITTI] `"Visual Odometry / SLAM Evaluation 2012" + `_ + + .. [PapersWithCode_SemanticKITTI] `PapersWithCode: SemanticKITTI + `_ + """ + + URLS = [ + "https://s3.eu-central-1.amazonaws.com/avg-kitti/data_odometry_calib.zip", + "https://s3.eu-central-1.amazonaws.com/avg-kitti/data_odometry_velodyne.zip", + "http://www.semantic-kitti.org/assets/data_odometry_labels.zip", + "http://www.semantic-kitti.org/assets/data_odometry_voxels_all.zip", + "http://www.semantic-kitti.org/assets/data_odometry_voxels.zip", + ] + + HASHES = [ + "fa45d2bbff828776e6df689b161415fb7cd719345454b6d3567c2ff81fa4d075", # data_odometry_calib.zip + "062a45667bec6874ac27f733bd6809919f077265e7ac0bb25ac885798fa85ab5", # data_odometry_velodyne.zip + "408ec524636a393bae0288a0b2f48bf5418a1af988e82dee8496f89ddb7e6dda", # data_odometry_labels.zip + "10f333faa63426a519a573fbf0b4e3b56513511af30583473fa6a5782e037f3a", # data_odometry_voxels_all.zip + "d92c253e88e5e30c0a0b88f028510760e1db83b7e262d75c5931bf9b8d6dd51b", # data_odometry_voxels.zip + ] + + # Suggested splits: + SEQUENCES = { + "train": (0, 1, 2, 3, 4, 5, 6, 7, 9, 10), + "valid": (8,), + "infer": (8,), + "test": (11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21), # Unlabelled. + } + + # fmt: off + NUM_SAMPLES_PER_SEQUENCE = [ + 4541, 1101, 4661, 801, 271, 2761, 1101, 1101, 4071, 1591, 1201, + 921, 1061, 3281, 631, 1901, 1731, 491, 1801, 4981, 831, 2721 + ] + # fmt: on + + RAW_SEMANTIC_INDEX_TO_LABEL = { + 0: "unlabeled", + 1: "outlier", + 10: "car", + 11: "bicycle", + 13: "bus", + 15: "motorcycle", + 16: "on-rails", + 18: "truck", + 20: "other-vehicle", + 30: "person", + 31: "bicyclist", + 32: "motorcyclist", + 40: "road", + 44: "parking", + 48: "sidewalk", + 49: "other-ground", + 50: "building", + 51: "fence", + 52: "other-structure", + 60: "lane-marking", + 70: "vegetation", + 71: "trunk", + 72: "terrain", + 80: "pole", + 81: "traffic-sign", + 99: "other-object", + 252: "moving-car", + 253: "moving-bicyclist", + 254: "moving-person", + 255: "moving-motorcyclist", + 256: "moving-on-rails", + 257: "moving-bus", + 258: "moving-truck", + 259: "moving-other-vehicle", + } + + RAW_SEMANTIC_INDEX_TO_SEMANTIC_INDEX = { + idx: i for i, idx in enumerate(RAW_SEMANTIC_INDEX_TO_LABEL) + } + + def __init__( + self, + root=None, + cache_root=None, + split="train", + split_name=None, + sequences=SEQUENCES["train"], + pre_transform=None, + transform=None, + download=True, + ): + if cache_root is None: + assert root is not None + cache_root = f"{str(root).rstrip('/')}_cache" + + self.root = Path(root) if root else None + self.cache_root = Path(cache_root) + self.split = split + self.split_name = split if split_name is None else split_name + self.sequences = sequences + + if download and self.root: + self.download() + + super().__init__( + cache_root=self.cache_root / self.split_name, + pre_transform=pre_transform, + transform=transform, + ) + + self._ensure_cache() + + def download(self, force=False): + if not force and self.root.exists(): + return + tmpdir = self.root.parent / "tmp" + os.makedirs(tmpdir, exist_ok=True) + for expected_hash, url in zip(self.HASHES, self.URLS): + filepath = download_url( + url, tmpdir, check_certificate=False, overwrite=force + ) + shutil.unpack_archive(filepath, tmpdir) + assert expected_hash == hash_file(filepath, method="sha256") + shutil.move(tmpdir / "dataset", self.root) + + def _get_items(self): + return sorted( + x + for i in self.sequences + for x in 
self.root.glob(f"**/{i:02}/velodyne/*.bin") + ) + + def _load_item(self, path): + path_prefix, sequence_index, file_index = self._parse_path(path) + assert str(path) == f"{path_prefix}{sequence_index}/velodyne/{file_index}.bin" + point_data = np.fromfile(path, dtype=np.float32).reshape(-1, 4) + label_data = ( + np.fromfile( + f"{path_prefix}{sequence_index}/labels/{file_index}.label", dtype=".*?/?)" + r"(?P\d+)/" + r"velodyne/" + r"(?P\d{6})\.\w+$" + ) + match = re.match(pattern, str(path)) + if match is None: + raise ValueError(f"Could not parse path: {path}") + path_prefix = match.group("path_prefix") + sequence_index = match.group("sequence_index") + file_index = match.group("file_index") + return path_prefix, sequence_index, file_index + + +def np_remap(arr, d): + values, inverse = np.unique(arr, return_inverse=True) + values = np.array([d[x] for x in values], dtype=arr.dtype) + return values[inverse].reshape(arr.shape) diff --git a/compressai/datasets/pointcloud/shapenet.py b/compressai/datasets/pointcloud/shapenet.py new file mode 100644 index 00000000..3075c859 --- /dev/null +++ b/compressai/datasets/pointcloud/shapenet.py @@ -0,0 +1,274 @@ +# Copyright (c) 2021-2022, InterDigital Communications, Inc +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted (subject to the limitations in the disclaimer +# below) provided that the following conditions are met: + +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * Neither the name of InterDigital Communications, Inc nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. + +# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY +# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND +# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT +# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR +# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF +# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from __future__ import annotations + +import json +import os +import re +import shutil + +from pathlib import Path + +import numpy as np +import pandas as pd + +from compressai.datasets.cache import CacheDataset +from compressai.datasets.utils import download_url, hash_file +from compressai.registry import register_dataset + + +@register_dataset("ShapeNetCorePartDataset") +class ShapeNetCorePartDataset(CacheDataset): + """ShapeNet-Part dataset. + + The ShapeNet dataset of 3D CAD models of objects was introduced by + [Yi2016]_, consisting of over 3000000 models. 
+ The ShapeNetCore (v2) dataset is a "clean" subset of ShapeNet, + consisting of 51127 aligned items from 55 object categories. + The ShapeNet-Part dataset is a further subset of this dataset, + consisting of 16881 items from 16 object categories. + See page 2 of [Yi2017]_ for additional description. + + Object categories are labeled with two to six segmentation parts + each, as shown in the image below. + (Purple represents a "miscellaneous" part.) + + .. image:: https://cs.stanford.edu/~ericyi/project_page/part_annotation/figures/categoriesNumbers.png + + [ProjectPage_ShapeNetPart]_ also releases a processed version of + ShapeNet-Part containing point cloud and normals with + expert-verified segmentations, which we use here. + + The ``semantic_index`` is a number between 0 and 49 (inclusive), + which can be used as the semantic label for each point. + + See also: [PapersWithCode_ShapeNetPart]_ (benchmarks). + + References: + + .. [Yi2016] `"A scalable active framework for region annotation + in 3D shape collections," + `_, + by Li Yi, Vladimir G. Kim, Duygu Ceylan, I-Chao Shen, + Mengyan Yan, Hao Su, Cewu Lu, Qixing Huang, Alla Sheffer, + and Leonidas Guibas, ACM Transactions on Graphics, 2016. + + .. [Yi2017] `"Large-scale 3D shape reconstruction and + segmentation from ShapeNet Core55," + `_, + by Li Yi et al. (total 50 authors), ICCV 2017. + + .. [ProjectPage_ShapeNetPart] `Project page (ShapeNet-Part) + `_ + + .. [PapersWithCode_ShapeNetPart] `PapersWithCode: ShapeNet-Part Benchmark + (3D Part Segmentation) + `_ + """ + + URLS = { + "shapenetcore_partanno_segmentation_benchmark_v0": "https://shapenet.cs.stanford.edu/media/shapenetcore_partanno_segmentation_benchmark_v0.zip", + "shapenetcore_partanno_segmentation_benchmark_v0_normal": "https://shapenet.cs.stanford.edu/media/shapenetcore_partanno_segmentation_benchmark_v0_normal.zip", + } + # Related: https://shapenet.cs.stanford.edu/media/shapenet_part_seg_hdf5_data.zip + + HASHES = { + "shapenetcore_partanno_segmentation_benchmark_v0": "f1dc7bad73237060946f13e1fa767b40d9adba52a79d42d64de31552b8c0b65e", + "shapenetcore_partanno_segmentation_benchmark_v0_normal": "0e26411700bae2da38ee8ecc719ba4db2e6e0133486e258665952ad5dfced0fe", + } + + CATEGORY_ID_TO_CATEGORY_STR = { + "02691156": "Airplane", + "02773838": "Bag", + "02954340": "Cap", + "02958343": "Car", + "03001627": "Chair", + "03261776": "Earphone", + "03467517": "Guitar", + "03624134": "Knife", + "03636649": "Lamp", + "03642806": "Laptop", + "03790512": "Motorbike", + "03797390": "Mug", + "03948459": "Pistol", + "04099429": "Rocket", + "04225987": "Skateboard", + "04379243": "Table", + } + + NUM_PARTS = { + "02691156": 4, # Airplane + "02773838": 2, # Bag + "02954340": 2, # Cap + "02958343": 4, # Car + "03001627": 4, # Chair + "03261776": 3, # Earphone + "03467517": 3, # Guitar + "03624134": 2, # Knife + "03636649": 4, # Lamp + "03642806": 2, # Laptop + "03790512": 6, # Motorbike + "03797390": 2, # Mug + "03948459": 3, # Pistol + "04099429": 3, # Rocket + "04225987": 3, # Skateboard + "04379243": 3, # Table + } + + def __init__( + self, + root=None, + cache_root=None, + split="train", + split_name=None, + pre_transform=None, + transform=None, + name="shapenetcore_partanno_segmentation_benchmark_v0_normal", + download=True, + ): + if cache_root is None: + assert root is not None + cache_root = f"{str(root).rstrip('/')}_cache" + + self.root = Path(root) if root else None + self.cache_root = Path(cache_root) + self.split = split + self.split_name = split if split_name is None else 
split_name
+        self.name = name
+
+        if download and self.root:
+            self.download()
+
+        self._verify_category_ids()
+
+        self.category_id_info = {
+            category_id: {
+                "category_str": category_str,
+                "category_index": category_index,
+            }
+            for category_index, (category_id, category_str) in enumerate(
+                self.CATEGORY_ID_TO_CATEGORY_STR.items()
+            )
+        }
+
+        self.category_offsets = np.cumsum([0] + list(self.NUM_PARTS.values()))
+
+        super().__init__(
+            cache_root=self.cache_root / self.split_name,
+            pre_transform=pre_transform,
+            transform=transform,
+        )
+
+        self._ensure_cache()
+
+    def download(self, force=False):
+        if not force and self.root.exists():
+            return
+        tmpdir = self.root.parent / "tmp"
+        os.makedirs(tmpdir, exist_ok=True)
+        filepath = download_url(
+            self.URLS[self.name], tmpdir, check_certificate=False, overwrite=force
+        )
+        assert self.HASHES[self.name] == hash_file(filepath, method="sha256")
+        shutil.unpack_archive(filepath, tmpdir)
+        shutil.move(tmpdir / f"{self.name}", self.root)
+
+    def _verify_category_ids(self):
+        with open(self.root / "synsetoffset2category.txt") as f:
+            pairs = [line.split() for line in f.readlines()]
+        category_id_to_category_str = {
+            category_id: category_str for category_str, category_id in pairs
+        }
+        assert category_id_to_category_str == self.CATEGORY_ID_TO_CATEGORY_STR
+
+    def _get_items(self):
+        file_list = f"shuffled_{self.split}_file_list.json"
+        with open(self.root / "train_test_split" / file_list) as f:
+            paths = json.load(f)
+        return paths
+
+    def _load_item(self, path):
+        category_id, file_hash = self._parse_path(path)
+        category_index = self.category_id_info[category_id]["category_index"]
+        category_offset = self.category_offsets[category_index]
+        read_csv_kwargs = {"sep": " ", "header": None, "index_col": False}
+
+        if self.name == "shapenetcore_partanno_segmentation_benchmark_v0_normal":
+            names = ["x", "y", "z", "nx", "ny", "nz", "semantic_index"]
+            df = pd.read_csv(
+                f"{self.root}/{category_id}/{file_hash}.txt",
+                names=names,
+                dtype={k: np.float32 for k in names},
+                **read_csv_kwargs,
+            )
+            df["semantic_index"] = df["semantic_index"].astype(np.uint8)
+            df["part_index"] = df["semantic_index"] - category_offset
+
+        elif self.name == "shapenetcore_partanno_segmentation_benchmark_v0":
+            df_points = pd.read_csv(
+                f"{self.root}/{category_id}/points/{file_hash}.pts",
+                names=["x", "y", "z"],
+                dtype={k: np.float32 for k in ["x", "y", "z"]},
+                **read_csv_kwargs,
+            )
+            df_points_label = pd.read_csv(
+                f"{self.root}/{category_id}/points_label/{file_hash}.seg",
+                names=["part_index"],
+                dtype={"part_index": np.uint8},
+                **read_csv_kwargs,
+            )
+            df = pd.concat([df_points, df_points_label], axis="columns")
+            assert df["part_index"].min() >= 1
+            df["part_index"] -= 1
+            df["semantic_index"] = category_offset + df["part_index"]
+
+        else:
+            raise ValueError(f"Unknown name: {self.name}")
+
+        data = {
+            "category_index": np.array([category_index], dtype=np.uint8),
+            "part_index": df["part_index"].values,
+            "semantic_index": df["semantic_index"].values,
+            "pos": df[["x", "y", "z"]].values,
+        }
+
+        if self.name == "shapenetcore_partanno_segmentation_benchmark_v0_normal":
+            data["normal"] = df[["nx", "ny", "nz"]].values
+
+        return data
+
+    def _parse_path(self, path):
+        pattern = r"^.*?/?(?P<category_id>\d+)/(?P<file_hash>[-a-fu\d]+)$"
+        match = re.match(pattern, str(path))
+        if match is None:
+            raise ValueError(f"Could not parse path: {path}")
+        category_id = match.group("category_id")
+        file_hash = match.group("file_hash")
+        return category_id, file_hash
diff --git 
a/compressai/datasets/stack.py b/compressai/datasets/stack.py new file mode 100644 index 00000000..fcae5cdb --- /dev/null +++ b/compressai/datasets/stack.py @@ -0,0 +1,91 @@ +# Copyright (c) 2021-2022, InterDigital Communications, Inc +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted (subject to the limitations in the disclaimer +# below) provided that the following conditions are met: + +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * Neither the name of InterDigital Communications, Inc nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. + +# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY +# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND +# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT +# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR +# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF +# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +# Copied from https://github.com/pytorch/pytorch/blob/v2.1.0/torch/utils/data/dataset.py +# BSD-style license: https://github.com/pytorch/pytorch/blob/v2.1.0/LICENSE + +from typing import Dict, Tuple, TypeVar, Union + +from torch.utils.data import Dataset + +T_co = TypeVar("T_co", covariant=True) +T = TypeVar("T") +T_dict = Dict[str, T_co] +T_tuple = Tuple[T_co, ...] +T_stack = TypeVar("T_stack", T_tuple, T_dict) + + +class StackDataset(Dataset[T_stack]): + r"""Dataset as a stacking of multiple datasets. + + This class is useful to assemble different parts of complex input data, given as datasets. + + Example: + >>> # xdoctest: +SKIP + >>> images = ImageDataset() + >>> texts = TextDataset() + >>> tuple_stack = StackDataset(images, texts) + >>> tuple_stack[0] == (images[0], texts[0]) + >>> dict_stack = StackDataset(image=images, text=texts) + >>> dict_stack[0] == {'image': images[0], 'text': texts[0]} + + Args: + *args (Dataset): Datasets for stacking returned as tuple. + **kwargs (Dataset): Datasets for stacking returned as dict. + """ + + datasets: Union[tuple, dict] + + def __init__(self, *args: Dataset[T_co], **kwargs: Dataset[T_co]) -> None: + if args: + if kwargs: + raise ValueError( + "Supported either ``tuple``- (via ``args``) or" + "``dict``- (via ``kwargs``) like input/output, but both types are given." 
+ ) + self._length = len(args[0]) # type: ignore[arg-type] + if any(self._length != len(dataset) for dataset in args): # type: ignore[arg-type] + raise ValueError("Size mismatch between datasets") + self.datasets = args + elif kwargs: + tmp = list(kwargs.values()) + self._length = len(tmp[0]) # type: ignore[arg-type] + if any(self._length != len(dataset) for dataset in tmp): # type: ignore[arg-type] + raise ValueError("Size mismatch between datasets") + self.datasets = kwargs + else: + raise ValueError("At least one dataset should be passed") + + def __getitem__(self, index): + if isinstance(self.datasets, dict): + return {k: dataset[index] for k, dataset in self.datasets.items()} + return tuple(dataset[index] for dataset in self.datasets) + + def __len__(self): + return self._length diff --git a/compressai/datasets/utils.py b/compressai/datasets/utils.py new file mode 100644 index 00000000..1aa94f6f --- /dev/null +++ b/compressai/datasets/utils.py @@ -0,0 +1,74 @@ +# Copyright (c) 2021-2022, InterDigital Communications, Inc +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted (subject to the limitations in the disclaimer +# below) provided that the following conditions are met: + +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * Neither the name of InterDigital Communications, Inc nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. + +# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY +# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND +# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT +# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR +# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF +# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
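+
+# Helpers for fetching and verifying dataset archives: ``download_url``
+# streams a URL to disk with a progress bar and skips the transfer when the
+# destination file already matches the server-reported Content-Length;
+# ``hash_file`` computes a streaming SHA-256 digest (currently the only
+# supported ``method``) for checking archives against pinned hashes.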
+ +import hashlib +import urllib +import urllib.parse +import urllib.request + +from pathlib import Path + +import requests + +from tqdm import tqdm + + +def download_url(url, path, chunk_size=65536, check_certificate=True, overwrite=False): + path = Path(path) + + if path.is_dir(): + path = path / urllib.parse.unquote(url.split("/")[-1]) + + print(f"Downloading {url} to {path}...") + response = requests.get(url, stream=True, verify=check_certificate) + total_size = int(response.headers.get("content-length", 0)) + file_size = path.stat().st_size if path.is_file() else None + + if not overwrite and file_size == total_size: + return path + + with tqdm(total=total_size, unit="B", unit_scale=True) as progress_bar: + with open(path, "wb") as f: + for data in response.iter_content(chunk_size): + progress_bar.update(len(data)) + f.write(data) + + if total_size != 0 and progress_bar.n != total_size: + raise RuntimeError("Could not download file") + + return path + + +def hash_file(path, method="sha256", bufsize=131072): + hash = hashlib.sha256() if method == "sha256" else None + mv = memoryview(bytearray(bufsize)) + with open(path, "rb", buffering=0) as f: + for n in iter(lambda: f.readinto(mv), 0): + hash.update(mv[:n]) + return hash.hexdigest() diff --git a/docs/source/datasets.rst b/docs/source/datasets.rst index 05a1535d..e0ef1124 100644 --- a/docs/source/datasets.rst +++ b/docs/source/datasets.rst @@ -32,3 +32,32 @@ Vimeo90kDataset .. autoclass:: Vimeo90kDataset :members: + + +Point cloud datasets +~~~~~~~~~~~~~~~~~~~~ + + +ModelNetDataset +--------------- +.. autoclass:: ModelNetDataset + :members: + + +S3disDataset +------------ +.. autoclass:: S3disDataset + :members: + + +SemanticKittiDataset +-------------------- +.. autoclass:: SemanticKittiDataset + :members: + + +ShapeNetCorePartDataset +----------------------- +.. autoclass:: ShapeNetCorePartDataset + :members: + From b7c1f1e4070cba998eeeea734b73254475dc33d7 Mon Sep 17 00:00:00 2001 From: Mateen Ulhaq Date: Wed, 31 Jan 2024 19:57:44 -0800 Subject: [PATCH 13/27] feat: point cloud transforms --- compressai/transforms/__init__.py | 2 + compressai/transforms/point/__init__.py | 46 +++++++ .../point/generate_position_normals.py | 83 ++++++++++++ .../transforms/point/normalize_scale_v2.py | 60 +++++++++ .../transforms/point/random_permutation.py | 51 ++++++++ .../transforms/point/random_rotate_full.py | 61 +++++++++ compressai/transforms/point/random_sample.py | 84 +++++++++++++ .../transforms/point/sample_points_v2.py | 118 ++++++++++++++++++ compressai/transforms/point/to_dict.py | 61 +++++++++ docs/source/transforms.rst | 7 ++ 10 files changed, 573 insertions(+) create mode 100644 compressai/transforms/point/__init__.py create mode 100644 compressai/transforms/point/generate_position_normals.py create mode 100644 compressai/transforms/point/normalize_scale_v2.py create mode 100644 compressai/transforms/point/random_permutation.py create mode 100644 compressai/transforms/point/random_rotate_full.py create mode 100644 compressai/transforms/point/random_sample.py create mode 100644 compressai/transforms/point/sample_points_v2.py create mode 100644 compressai/transforms/point/to_dict.py diff --git a/compressai/transforms/__init__.py b/compressai/transforms/__init__.py index 30a24c22..3cc90328 100644 --- a/compressai/transforms/__init__.py +++ b/compressai/transforms/__init__.py @@ -27,4 +27,6 @@ # OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +from . 
import point as point +from .point import * from .transforms import * diff --git a/compressai/transforms/point/__init__.py b/compressai/transforms/point/__init__.py new file mode 100644 index 00000000..7e13c66d --- /dev/null +++ b/compressai/transforms/point/__init__.py @@ -0,0 +1,46 @@ +# Copyright (c) 2021-2022, InterDigital Communications, Inc +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted (subject to the limitations in the disclaimer +# below) provided that the following conditions are met: + +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * Neither the name of InterDigital Communications, Inc nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. + +# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY +# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND +# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT +# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR +# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF +# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from .generate_position_normals import GeneratePositionNormals +from .normalize_scale_v2 import NormalizeScaleV2 +from .random_permutation import RandomPermutation +from .random_rotate_full import RandomRotateFull +from .random_sample import RandomSample +from .sample_points_v2 import SamplePointsV2 +from .to_dict import ToDict + +__all__ = [ + "GeneratePositionNormals", + "NormalizeScaleV2", + "RandomPermutation", + "RandomRotateFull", + "RandomSample", + "SamplePointsV2", + "ToDict", +] diff --git a/compressai/transforms/point/generate_position_normals.py b/compressai/transforms/point/generate_position_normals.py new file mode 100644 index 00000000..5be92532 --- /dev/null +++ b/compressai/transforms/point/generate_position_normals.py @@ -0,0 +1,83 @@ +# Copyright (c) 2021-2022, InterDigital Communications, Inc +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted (subject to the limitations in the disclaimer +# below) provided that the following conditions are met: + +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. 
+# * Neither the name of InterDigital Communications, Inc nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. + +# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY +# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND +# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT +# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR +# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF +# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from contextlib import suppress + +import torch + +from torch_geometric.data import Data +from torch_geometric.data.datapipes import functional_transform +from torch_geometric.transforms import BaseTransform + +from compressai.registry import register_transform + + +@functional_transform("generate_position_normals") +@register_transform("GeneratePositionNormals") +class GeneratePositionNormals(BaseTransform): + r"""Generates normals from node positions + (functional name: :obj:`generate_position_normals`). + """ + + def __init__(self, *, method="any", **kwargs): + self.method = method + self.kwargs = kwargs + + def __call__(self, data: Data) -> Data: + assert data.pos.ndim == 2 and data.pos.shape[1] == 3 + + if self.method == "open3d": + import open3d as o3d + + pcd = o3d.geometry.PointCloud() + pcd.points = o3d.utility.Vector3dVector(data.pos.cpu().numpy()) + pcd.estimate_normals(search_param=o3d.geometry.KDTreeSearchParamKNN()) + pcd.normalize_normals() + data.norm = torch.tensor( + pcd.normals, dtype=torch.float32, device=data.pos.device + ) + + return data + + if self.method == "pytorch3d": + import pytorch3d.ops + + data.norm = pytorch3d.ops.estimate_pointcloud_normals( + data.pos.unsqueeze(0), **self.kwargs + ).squeeze(0) + + return data + + if self.method == "any": + for self.method in ["open3d", "pytorch3d"]: + with suppress(ImportError): + return self(data) + raise RuntimeError("Please install open3d / pytorch3d to estimate normals.") + + raise ValueError(f"Unknown method: {self.method}") diff --git a/compressai/transforms/point/normalize_scale_v2.py b/compressai/transforms/point/normalize_scale_v2.py new file mode 100644 index 00000000..78638fc6 --- /dev/null +++ b/compressai/transforms/point/normalize_scale_v2.py @@ -0,0 +1,60 @@ +# Copyright (c) 2021-2022, InterDigital Communications, Inc +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted (subject to the limitations in the disclaimer +# below) provided that the following conditions are met: + +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. 
+# * Neither the name of InterDigital Communications, Inc nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. + +# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY +# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND +# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT +# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR +# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF +# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import torch + +from torch_geometric.data import Data +from torch_geometric.data.datapipes import functional_transform +from torch_geometric.transforms import BaseTransform, Center + +from compressai.registry import register_transform + + +@functional_transform("normalize_scale_v2") +@register_transform("NormalizeScaleV2") +class NormalizeScaleV2(BaseTransform): + r"""Centers and normalizes node positions + (functional name: :obj:`normalize_scale_v2`). + """ + + def __init__(self, *, center=True, scale_method="linf"): + self.scale_method = scale_method + self.center = Center() if center else lambda x: x + + def __call__(self, data: Data) -> Data: + data = self.center(data) + data.pos = data.pos / self._compute_scale(data) + return data + + def _compute_scale(self, data: Data) -> torch.Tensor: + if self.scale_method == "l2": + return (data.pos**2).sum(axis=-1).sqrt().max() + if self.scale_method == "linf": + return data.pos.abs().max() + raise ValueError(f"Unknown scale_method: {self.scale_method}") diff --git a/compressai/transforms/point/random_permutation.py b/compressai/transforms/point/random_permutation.py new file mode 100644 index 00000000..9ab04b6b --- /dev/null +++ b/compressai/transforms/point/random_permutation.py @@ -0,0 +1,51 @@ +# Copyright (c) 2021-2022, InterDigital Communications, Inc +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted (subject to the limitations in the disclaimer +# below) provided that the following conditions are met: + +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * Neither the name of InterDigital Communications, Inc nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. + +# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY +# THIS LICENSE. 
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND +# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT +# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR +# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF +# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import torch + +from torch_geometric.data import Data +from torch_geometric.data.datapipes import functional_transform +from torch_geometric.transforms import BaseTransform + +from compressai.registry import register_transform + + +@functional_transform("random_permutation") +@register_transform("RandomPermutation") +class RandomPermutation(BaseTransform): + r"""Randomly permutes points and associated attributes + (functional name: :obj:`random_permutation`). + """ + + def __init__(self, *, attrs=("pos",)): + self.attrs = attrs + + def __call__(self, data: Data) -> Data: + perm = torch.randperm(data.pos.shape[0]) + return Data(**{k: v[perm] if k in self.attrs else v for k, v in data.items()}) diff --git a/compressai/transforms/point/random_rotate_full.py b/compressai/transforms/point/random_rotate_full.py new file mode 100644 index 00000000..55100869 --- /dev/null +++ b/compressai/transforms/point/random_rotate_full.py @@ -0,0 +1,61 @@ +# Copyright (c) 2021-2022, InterDigital Communications, Inc +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted (subject to the limitations in the disclaimer +# below) provided that the following conditions are met: + +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * Neither the name of InterDigital Communications, Inc nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. + +# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY +# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND +# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT +# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR +# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF +# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
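+
+# Rotation matrices are sampled by QR-factorizing a standard Gaussian matrix
+# and correcting signs so the result is Haar-distributed, then flipping one
+# row where needed so the determinant is +1 (a proper rotation); see the
+# reference linked above ``random_rotation_matrix`` below.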
+ +import torch + +from torch_geometric.data import Data +from torch_geometric.data.datapipes import functional_transform +from torch_geometric.transforms import BaseTransform + +from compressai.registry import register_transform + + +@functional_transform("random_rotate_full") +@register_transform("RandomRotateFull") +class RandomRotateFull(BaseTransform): + r"""Randomly rotates node positions around the origin + (functional name: :obj:`random_rotate_full`). + """ + + def __call__(self, data: Data) -> Data: + _, ndim = data.pos.shape + rot = random_rotation_matrix(1, ndim).to(data.pos.device).squeeze(0) + data.pos = data.pos @ rot.T + return data + + +# See https://math.stackexchange.com/questions/442418/random-generation-of-rotation-matrices/4832876#4832876 +def random_rotation_matrix(batch_size: int, ndim=3, generator=None) -> torch.Tensor: + z = torch.randn((batch_size, ndim, ndim), generator=generator) + q, r = torch.linalg.qr(z) + sign = 2 * (r.diagonal(dim1=-2, dim2=-1) >= 0) - 1 + rot = q + rot *= sign[..., None, :] + rot[:, 0, :] *= torch.linalg.det(rot)[..., None] + return rot diff --git a/compressai/transforms/point/random_sample.py b/compressai/transforms/point/random_sample.py new file mode 100644 index 00000000..9a4e66f6 --- /dev/null +++ b/compressai/transforms/point/random_sample.py @@ -0,0 +1,84 @@ +# Copyright (c) 2021-2022, InterDigital Communications, Inc +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted (subject to the limitations in the disclaimer +# below) provided that the following conditions are met: + +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * Neither the name of InterDigital Communications, Inc nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. + +# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY +# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND +# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT +# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR +# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF +# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import torch + +from torch_geometric.data import Data +from torch_geometric.data.datapipes import functional_transform +from torch_geometric.transforms import BaseTransform + +from compressai.registry import register_transform + + +@functional_transform("random_sample") +@register_transform("RandomSample") +class RandomSample(BaseTransform): + r"""Randomly samples points and associated attributes + (functional name: :obj:`random_sample`). 
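+
+    If ``num`` is larger than the number of input points, some points are
+    repeated, since the sampled indices are reduced modulo the input size;
+    otherwise, ``num`` distinct points are kept.
+
+    Example (shapes are illustrative only):
+        >>> # xdoctest: +SKIP
+        >>> data = Data(pos=torch.rand(1000, 3))
+        >>> data = RandomSample(num=512)(data)
+        >>> data.pos.shape
+        torch.Size([512, 3])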
+ """ + + def __init__( + self, + num=None, + *, + attrs=("pos",), + remove_duplicates_by=None, + preserve_order=False, + seed=None, + static_seed=None, + ): + self.num = num + self.attrs = attrs + self.remove_duplicates_by = remove_duplicates_by + self.preserve_order = preserve_order + self.generator = torch.Generator() + if seed is not None: + self.generator.manual_seed(seed) + self.static_seed = static_seed + + def __call__(self, data: Data) -> Data: + if self.static_seed is not None: + self.generator.manual_seed(self.static_seed) + + if self.remove_duplicates_by is not None: + _, perm = data[self.remove_duplicates_by].unique(return_inverse=True, dim=0) + for attr in self.attrs: + data[attr] = data[attr][perm] + + num_input = data[self.attrs[0]].shape[0] + assert all(data[k].shape[0] == num_input for k in self.attrs) + + p = torch.ones(max(num_input, self.num), dtype=torch.float32) + perm = torch.multinomial(p, self.num, generator=self.generator) + perm %= num_input + + if self.preserve_order: + perm = perm.sort()[0] + + return Data(**{k: v[perm] if k in self.attrs else v for k, v in data.items()}) diff --git a/compressai/transforms/point/sample_points_v2.py b/compressai/transforms/point/sample_points_v2.py new file mode 100644 index 00000000..17ca6011 --- /dev/null +++ b/compressai/transforms/point/sample_points_v2.py @@ -0,0 +1,118 @@ +# Copyright (c) 2021-2022, InterDigital Communications, Inc +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted (subject to the limitations in the disclaimer +# below) provided that the following conditions are met: + +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * Neither the name of InterDigital Communications, Inc nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. + +# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY +# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND +# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT +# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR +# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF +# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +import torch + +from torch_geometric.data import Data +from torch_geometric.data.datapipes import functional_transform +from torch_geometric.transforms import BaseTransform + +from compressai.registry import register_transform + + +@functional_transform("sample_points_v2") +@register_transform("SamplePointsV2") +class SamplePointsV2(BaseTransform): + r"""Uniformly samples a fixed number of points on the mesh faces according + to their face area (functional name: :obj:`sample_points`). + + Adapted from PyTorch Geometric under MIT license at + https://github.com/pyg-team/pytorch_geometric/blob/master/LICENSE. + + Args: + num (int): The number of points to sample. + remove_faces (bool, optional): If set to :obj:`False`, the face tensor + will not be removed. (default: :obj:`True`) + include_normals (bool, optional): If set to :obj:`True`, then compute + normals for each sampled point. (default: :obj:`False`) + seed (int, optional): Initial random seed. + static_seed (int, optional): Reset random seed to this every call. + """ + + def __init__( + self, + num: int, + *, + remove_faces: bool = True, + include_normals: bool = False, + seed=None, + static_seed=None, + ): + self.num = num + self.remove_faces = remove_faces + self.include_normals = include_normals + self.generator = torch.Generator() + if seed is not None: + self.generator.manual_seed(seed) + self.static_seed = static_seed + + def __call__(self, data: Data) -> Data: + assert data.pos is not None + assert data.face is not None + + if self.static_seed is not None: + self.generator.manual_seed(self.static_seed) + + pos, face = data.pos, data.face + assert pos.size(1) == 3 and face.size(0) == 3 + + pos_max = pos.abs().max() + pos = pos / pos_max + + area = (pos[face[1]] - pos[face[0]]).cross(pos[face[2]] - pos[face[0]]) + area = area.norm(p=2, dim=1).abs() / 2 + + prob = area / area.sum() + sample = torch.multinomial(prob, self.num, replacement=True) + face = face[:, sample] + + frac = torch.rand(self.num, 2, device=pos.device, generator=self.generator) + mask = frac.sum(dim=-1) > 1 + frac[mask] = 1 - frac[mask] + + vec1 = pos[face[1]] - pos[face[0]] + vec2 = pos[face[2]] - pos[face[0]] + + if self.include_normals: + data.normal = torch.nn.functional.normalize(vec1.cross(vec2), p=2) + + pos_sampled = pos[face[0]] + pos_sampled += frac[:, :1] * vec1 + pos_sampled += frac[:, 1:] * vec2 + + pos_sampled = pos_sampled * pos_max + data.pos = pos_sampled + + if self.remove_faces: + data.face = None + + return data + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.num})" diff --git a/compressai/transforms/point/to_dict.py b/compressai/transforms/point/to_dict.py new file mode 100644 index 00000000..b180a086 --- /dev/null +++ b/compressai/transforms/point/to_dict.py @@ -0,0 +1,61 @@ +# Copyright (c) 2021-2022, InterDigital Communications, Inc +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted (subject to the limitations in the disclaimer +# below) provided that the following conditions are met: + +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. 
+# * Neither the name of InterDigital Communications, Inc nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. + +# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY +# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND +# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT +# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR +# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF +# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from typing import Any, Dict + +import torch + +from torch_geometric.data import Data +from torch_geometric.data.datapipes import functional_transform +from torch_geometric.transforms import BaseTransform + +from compressai.registry import register_transform + + +@functional_transform("to_dict") +@register_transform("ToDict") +class ToDict(BaseTransform): + r"""Convert :obj:`Mapping[str, Any]` + (functional name: :obj:`to_dict`). + """ + + def __init__(self, *, wrapper="dict"): + if wrapper == "dict": + self.wrap = dict + elif wrapper == "torch_geometric.data.Data": + self.wrap = Data + else: + raise ValueError(f"Unknown wrapper: {wrapper}") + + def __call__(self, data) -> Dict[str, Any]: + data = { + k: v if isinstance(v, torch.Tensor) else torch.tensor(v) + for k, v in data.items() + } + return self.wrap(**data) diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 8866d8ea..cd4ced50 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -23,3 +23,10 @@ Functional transforms can be used to define custom transform classes. .. automodule:: compressai.transforms.functional :members: + + +Point Cloud Transforms +---------------------- + +.. automodule:: compressai.transforms.point + :members: From bd8a0d903514fe8874d1fff82626601f816d3f01 Mon Sep 17 00:00:00 2001 From: Mateen Ulhaq Date: Tue, 30 Jan 2024 23:35:37 -0800 Subject: [PATCH 14/27] feat: point cloud losses --- compressai/losses/__init__.py | 3 + compressai/losses/pointcloud/__init__.py | 37 ++++ compressai/losses/pointcloud/chamfer.py | 100 +++++++++++ compressai/losses/pointcloud/hrtzxf2022.py | 196 +++++++++++++++++++++ compressai/losses/utils.py | 43 +++++ docs/source/losses.rst | 17 ++ 6 files changed, 396 insertions(+) create mode 100644 compressai/losses/pointcloud/__init__.py create mode 100644 compressai/losses/pointcloud/chamfer.py create mode 100644 compressai/losses/pointcloud/hrtzxf2022.py create mode 100644 compressai/losses/utils.py diff --git a/compressai/losses/__init__.py b/compressai/losses/__init__.py index 62b6e1e2..19e2bff7 100644 --- a/compressai/losses/__init__.py +++ b/compressai/losses/__init__.py @@ -27,8 +27,11 @@ # OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +from . 
import pointcloud +from .pointcloud import * from .rate_distortion import RateDistortionLoss __all__ = [ + *pointcloud.__all__, "RateDistortionLoss", ] diff --git a/compressai/losses/pointcloud/__init__.py b/compressai/losses/pointcloud/__init__.py new file mode 100644 index 00000000..2d833e22 --- /dev/null +++ b/compressai/losses/pointcloud/__init__.py @@ -0,0 +1,37 @@ +# Copyright (c) 2021-2022, InterDigital Communications, Inc +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted (subject to the limitations in the disclaimer +# below) provided that the following conditions are met: + +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * Neither the name of InterDigital Communications, Inc nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. + +# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY +# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND +# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT +# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR +# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF +# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from .chamfer import ChamferPccRateDistortionLoss, chamfer_distance +from .hrtzxf2022 import RateDistortionLoss_hrtzxf2022 + +__all__ = [ + "chamfer_distance", + "ChamferPccRateDistortionLoss", + "RateDistortionLoss_hrtzxf2022", +] diff --git a/compressai/losses/pointcloud/chamfer.py b/compressai/losses/pointcloud/chamfer.py new file mode 100644 index 00000000..10978afd --- /dev/null +++ b/compressai/losses/pointcloud/chamfer.py @@ -0,0 +1,100 @@ +# Copyright (c) 2021-2022, InterDigital Communications, Inc +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted (subject to the limitations in the disclaimer +# below) provided that the following conditions are met: + +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * Neither the name of InterDigital Communications, Inc nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. + +# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY +# THIS LICENSE. 
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND +# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT +# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR +# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF +# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from __future__ import annotations + +import torch +import torch.nn as nn + +from einops import rearrange +from pointops.functions import pointops + +from compressai.layers.pointcloud.hrtzxf2022 import index_points +from compressai.losses.utils import compute_rate_loss +from compressai.registry import register_criterion + + +@register_criterion("ChamferPccRateDistortionLoss") +class ChamferPccRateDistortionLoss(nn.Module): + """Simple loss for regular point cloud compression. + + For compression models that reconstruct the input point cloud. + """ + + LMBDA_DEFAULT = { + # "bpp": 1.0, + "rec": 1.0, + } + + def __init__(self, lmbda=None, rate_key="bpp"): + super().__init__() + self.lmbda = lmbda or dict(self.LMBDA_DEFAULT) + self.lmbda.setdefault(rate_key, 1.0) + + def forward(self, output, target): + out = { + **self.compute_rate_loss(output, target), + **self.compute_rec_loss(output, target), + } + + out["loss"] = sum( + self.lmbda[k] * out[f"{k}_loss"] + for k in self.lmbda.keys() + if f"{k}_loss" in out + ) + + return out + + def compute_rate_loss(self, output, target): + if "likelihoods" not in output: + return {} + N, P, _ = target["pos"].shape + return compute_rate_loss(output["likelihoods"], N, P) + + def compute_rec_loss(self, output, target): + dist1, dist2, _, _ = chamfer_distance( + target["pos"], output["x_hat"], order="b n c" + ) + loss_chamfer = dist1.mean() + dist2.mean() + return {"rec_loss": loss_chamfer} + + +def chamfer_distance(xyzs1, xyzs2, order="b n c"): + # idx1, dist1: (b, n1) + # idx2, dist2: (b, n2) + xyzs1_bcn = rearrange(xyzs1, f"{order} -> b c n").contiguous() + xyzs1_bnc = rearrange(xyzs1, f"{order} -> b n c").contiguous() + xyzs2_bcn = rearrange(xyzs2, f"{order} -> b c n").contiguous() + xyzs2_bnc = rearrange(xyzs2, f"{order} -> b n c").contiguous() + idx1 = pointops.knnquery_heap(1, xyzs2_bnc, xyzs1_bnc).long().squeeze(2) + idx2 = pointops.knnquery_heap(1, xyzs1_bnc, xyzs2_bnc).long().squeeze(2) + torch.cuda.empty_cache() + dist1 = ((xyzs1_bcn - index_points(xyzs2_bcn, idx1)) ** 2).sum(1) + dist2 = ((xyzs2_bcn - index_points(xyzs1_bcn, idx2)) ** 2).sum(1) + return dist1, dist2, idx1, idx2 diff --git a/compressai/losses/pointcloud/hrtzxf2022.py b/compressai/losses/pointcloud/hrtzxf2022.py new file mode 100644 index 00000000..5caf7d9f --- /dev/null +++ b/compressai/losses/pointcloud/hrtzxf2022.py @@ -0,0 +1,196 @@ +# Copyright (c) 2021-2022, InterDigital Communications, Inc +# All rights reserved. 
+ +# Redistribution and use in source and binary forms, with or without +# modification, are permitted (subject to the limitations in the disclaimer +# below) provided that the following conditions are met: + +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * Neither the name of InterDigital Communications, Inc nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. + +# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY +# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND +# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT +# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR +# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF +# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +# Code adapted from https://github.com/yunhe20/D-PCC + +from __future__ import annotations + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from compressai.layers.pointcloud.hrtzxf2022 import index_points +from compressai.losses.utils import compute_rate_loss +from compressai.registry import register_criterion + +from .chamfer import chamfer_distance + + +@register_criterion("RateDistortionLoss_hrtzxf2022") +class RateDistortionLoss_hrtzxf2022(nn.Module): + """Loss introduced in [He2022pcc]_ for "hrtzxf2022-pcc-rec" model. + + References: + + .. [He2022pcc] `"Density-preserving Deep Point Cloud Compression" + `_, by Yun He, Xinlin Ren, + Danhang Tang, Yinda Zhang, Xiangyang Xue, and Yanwei Fu, + CVPR 2022. + """ + + LMBDA_DEFAULT = { + "bpp": 1.0, + "chamfer": 1e4, + "chamfer_layers": (1.0, 0.1, 0.1), + "latent_xyzs": 1e2, + "mean_distance": 5e1, + "normal": 1e2, + "pts_num": 5e-3, + "upsample_num": 1.0, + } + + def __init__( + self, + lmbda=None, + compress_normal=False, + latent_xyzs_codec_mode="learned", + ): + super().__init__() + self.lmbda = lmbda or dict(self.LMBDA_DEFAULT) + self.compress_normal = compress_normal + self.latent_xyzs_codec_mode = latent_xyzs_codec_mode + + def forward(self, output, target): + device = target["pos"].device + B, P, _ = target["pos"].shape + + out = {} + + chamfer_loss_, nearest_gt_idx_ = get_chamfer_loss( + output["gt_xyz_"], + output["xyz_hat_"], + ) + + out["chamfer_loss"] = sum( + self.lmbda["chamfer_layers"][i] * chamfer_loss_[i] + for i in range(len(chamfer_loss_)) + ) + + out["rec_loss"] = chamfer_loss_[0] # Name as rec_loss for compatibility. 
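+
+        # Density-preservation terms: per downsampling layer, the predicted
+        # per-point upsampling counts and mean neighbor distances are matched
+        # (L1) against their ground-truth counterparts, aligned through the
+        # nearest-ground-truth indices from the chamfer computation above.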
+ + out["mean_distance_loss"], out["upsample_num_loss"] = get_density_loss( + output["gt_downsample_num_"], + output["gt_mean_distance_"], + output["upsample_num_hat_"], + output["mean_distance_hat_"], + nearest_gt_idx_, + ) + + out["pts_num_loss"] = get_pts_num_loss( + output["gt_xyz_"], + output["upsample_num_hat_"], + ) + + if self.latent_xyzs_codec_mode == "learned": + out["latent_xyzs_loss"] = get_latent_xyzs_loss( + output["gt_latent_xyz"], + output["latent_xyz_hat"], + ) + elif self.latent_xyzs_codec_mode == "float16": + out["latent_xyzs_loss"] = torch.tensor([0.0], device=device) + else: + raise ValueError( + f"Unknown latent_xyzs_codec_mode: {self.latent_xyzs_codec_mode}" + ) + + if self.compress_normal: + out["normal_loss"] = get_normal_loss( + output["gt_normal"], + output["feat_hat"].tanh(), + nearest_gt_idx_[0], + ) + else: + out["normal_loss"] = torch.tensor([0.0], device=device) + + if "likelihoods" in output: + out.update(compute_rate_loss(output["likelihoods"], B, P)) + + out["loss"] = sum( + self.lmbda[k] * out[f"{k}_loss"] + for k in self.lmbda.keys() + if f"{k}_loss" in out + ) + + return out + + +def get_chamfer_loss(gt_xyzs_, xyzs_hat_): + num_layers = len(gt_xyzs_) + chamfer_loss_ = [] + nearest_gt_idx_ = [] + + for i in range(num_layers): + xyzs1 = gt_xyzs_[i] + xyzs2 = xyzs_hat_[num_layers - i - 1] + dist1, dist2, _, idx2 = chamfer_distance(xyzs1, xyzs2, order="b c n") + chamfer_loss_.append(dist1.mean() + dist2.mean()) + nearest_gt_idx_.append(idx2.long()) + + return chamfer_loss_, nearest_gt_idx_ + + +def get_density_loss(gt_dnums_, gt_mdis_, unums_hat_, mdis_hat_, nearest_gt_idx_): + num_layers = len(gt_dnums_) + l1_loss = nn.L1Loss(reduction="mean") + mean_distance_loss_ = [] + upsample_num_loss_ = [] + + for i in range(num_layers): + if i == num_layers - 1: + # At the final downsample layer, gt_latent_xyzs ≈ latent_xyzs_hat. + mdis_i = gt_mdis_[i] + dnum_i = gt_dnums_[i] + else: + idx = nearest_gt_idx_[i + 1] + mdis_i = index_points(gt_mdis_[i].unsqueeze(1), idx).squeeze(1) + dnum_i = index_points(gt_dnums_[i].unsqueeze(1), idx).squeeze(1) + + mean_distance_loss_.append(l1_loss(mdis_hat_[num_layers - i - 1], mdis_i)) + upsample_num_loss_.append(l1_loss(unums_hat_[num_layers - i - 1], dnum_i)) + + return sum(mean_distance_loss_), sum(upsample_num_loss_) + + +def get_pts_num_loss(gt_xyzs_, unums_hat_): + num_layers = len(gt_xyzs_) + b, _, _ = gt_xyzs_[0].shape + gt_num_points_ = [x.shape[2] for x in gt_xyzs_] + return sum( + torch.abs(unums_hat_[num_layers - i - 1].sum() - gt_num_points_[i] * b) + for i in range(num_layers) + ) + + +def get_normal_loss(gt_normals, pred_normals, nearest_gt_idx): + nearest_normal = index_points(gt_normals, nearest_gt_idx) + return F.mse_loss(pred_normals, nearest_normal) + + +def get_latent_xyzs_loss(gt_latent_xyzs, latent_xyzs_hat): + return F.mse_loss(gt_latent_xyzs, latent_xyzs_hat) diff --git a/compressai/losses/utils.py b/compressai/losses/utils.py new file mode 100644 index 00000000..611b545b --- /dev/null +++ b/compressai/losses/utils.py @@ -0,0 +1,43 @@ +# Copyright (c) 2021-2022, InterDigital Communications, Inc +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted (subject to the limitations in the disclaimer +# below) provided that the following conditions are met: + +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. 
+
+# * Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+# * Neither the name of InterDigital Communications, Inc nor the names of its
+# contributors may be used to endorse or promote products derived from this
+# software without specific prior written permission.
+
+# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY
+# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
+# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT
+# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
+# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
+# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
+# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
+# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+
+def compute_rate_loss(likelihoods, batch_size, bit_per_bpp):
+ """Computes rate (bit and bpp) losses from likelihoods.
+
+ Args:
+ likelihoods: Dict mapping stream names to likelihood tensors.
+ batch_size: Number of samples in the batch.
+ bit_per_bpp: Number of bits corresponding to 1 bpp, i.e. the
+ number of points (or pixels) per sample; dividing a bit cost
+ by this value yields the equivalent bpp cost.
+
+ Returns:
+ Dict with per-stream ``bit_{name}_loss`` and ``bpp_{name}_loss``
+ entries, plus their aggregates ``bit_loss`` and ``bpp_loss``.
+ """
+ out_bit = {
+ f"bit_{name}_loss": lh.log2().sum() / -batch_size
+ for name, lh in likelihoods.items()
+ }
+ out_bpp = {
+ f"bpp_{name}_loss": out_bit[f"bit_{name}_loss"] / bit_per_bpp
+ for name in likelihoods.keys()
+ }
+ out = {**out_bit, **out_bpp}
+ out["bit_loss"] = sum(out_bit.values())
+ out["bpp_loss"] = out["bit_loss"] / bit_per_bpp
+ return out
diff --git a/docs/source/losses.rst b/docs/source/losses.rst
index 12e61aaa..7dfbf0c3 100644
--- a/docs/source/losses.rst
+++ b/docs/source/losses.rst
@@ -14,3 +14,20 @@ RateDistortionLoss
 .. autoclass:: RateDistortionLoss
 :members:
+
+
+Point cloud losses
+~~~~~~~~~~~~~~~~~~
+
+
+ChamferPccRateDistortionLoss
+----------------------------
+.. autoclass:: ChamferPccRateDistortionLoss
+ :members:
+
+
+RateDistortionLoss_hrtzxf2022
+-----------------------------
+.. autoclass:: RateDistortionLoss_hrtzxf2022
+ :members:
+

From 4ece4582b88859273520be817410e97d05e89f79 Mon Sep 17 00:00:00 2001
From: Mateen Ulhaq
Date: Tue, 30 Jan 2024 23:38:49 -0800
Subject: [PATCH 15/27] feat: point cloud layers (pointnet, pointnet2, hrtzxf2022)
---
 compressai/layers/pointcloud/__init__.py | 28 +
 compressai/layers/pointcloud/hrtzxf2022.py | 844 ++++++++++++++++++
 compressai/layers/pointcloud/pointnet.py | 89 ++
 compressai/layers/pointcloud/pointnet2.py | 469 ++++++++++
 compressai/layers/pointcloud/pointnet2_sfu.py | 57 ++
 compressai/layers/pointcloud/utils.py | 244 +++++
 6 files changed, 1731 insertions(+)
 create mode 100644 compressai/layers/pointcloud/__init__.py
 create mode 100644 compressai/layers/pointcloud/hrtzxf2022.py
 create mode 100644 compressai/layers/pointcloud/pointnet.py
 create mode 100644 compressai/layers/pointcloud/pointnet2.py
 create mode 100644 compressai/layers/pointcloud/pointnet2_sfu.py
 create mode 100644 compressai/layers/pointcloud/utils.py

diff --git a/compressai/layers/pointcloud/__init__.py b/compressai/layers/pointcloud/__init__.py
new file mode 100644
index 00000000..e4861cbe
--- /dev/null
+++ b/compressai/layers/pointcloud/__init__.py
@@ -0,0 +1,28 @@
+# Copyright (c) 2021-2022, InterDigital Communications, Inc
+# All rights reserved.
+ +# Redistribution and use in source and binary forms, with or without +# modification, are permitted (subject to the limitations in the disclaimer +# below) provided that the following conditions are met: + +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * Neither the name of InterDigital Communications, Inc nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. + +# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY +# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND +# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT +# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR +# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF +# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/compressai/layers/pointcloud/hrtzxf2022.py b/compressai/layers/pointcloud/hrtzxf2022.py new file mode 100644 index 00000000..293f6aa0 --- /dev/null +++ b/compressai/layers/pointcloud/hrtzxf2022.py @@ -0,0 +1,844 @@ +# Copyright (c) 2021-2022, InterDigital Communications, Inc +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted (subject to the limitations in the disclaimer +# below) provided that the following conditions are met: + +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * Neither the name of InterDigital Communications, Inc nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. + +# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY +# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND +# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT +# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +# PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR +# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF +# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +# Code adapted from https://github.com/yunhe20/D-PCC + +import math + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from einops import rearrange, repeat +from pointops.functions import pointops + +from .utils import index_points + + +class DownsampleLayer(nn.Module): + """Downsampling layer used in [He2022pcc]_. + + Downsamples positions into a smaller number of centroids. + Each centroid is grouped with nearby points, + and the local point density is estimated for that group. + Then, the positions, features, and density for the group + are embedded into a single aggregate vector from which the + group of points may later be reconstructed. + + References: + + .. [He2022pcc] `"Density-preserving Deep Point Cloud Compression" + `_, by Yun He, Xinlin Ren, + Danhang Tang, Yinda Zhang, Xiangyang Xue, and Yanwei Fu, + CVPR 2022. + """ + + def __init__(self, downsample_rate, dim, hidden_dim, k, ngroups): + super().__init__() + self.k = k + self.downsample_rate = downsample_rate + self.pre_conv = nn.Conv1d(dim, dim, 1) + self.embed_features = PointTransformerLayer(dim, dim, hidden_dim, ngroups) + self.embed_positions = PositionEmbeddingLayer(hidden_dim, dim, ngroups) + self.embed_densities = DensityEmbeddingLayer(hidden_dim, dim, ngroups) + self.post_conv = nn.Conv1d(dim * 3, dim, 1) + + def get_density(self, downsampled_xyzs, input_xyzs): + # downsampled_xyzs: (b, 3, m) + # input_xyzs: (b, 3, n) + # nn_idx: (b, n, 1) + # downsample_num: (b, m) + # knn_idx: (b, m, k) + # mask: (b, m, k) + # distance: (b, m) + # mean_distance: (b, m) + _, _, n = input_xyzs.shape + distance, mask, knn_idx, _ = nearby_distance_sum( + downsampled_xyzs, input_xyzs, min(self.k, n) + ) + downsample_num = mask.sum(dim=-1).float() + mean_distance = distance / downsample_num + return downsample_num, mean_distance, mask, knn_idx + + def forward(self, xyzs, feats): + # xyzs: (b, 3, n) + # features: (b, cin, n) + # sample_idx: (b, m) + # sampled_xyzs: (b, 3, m) + # sampled_feats: (b, c, m) + + # Downsample positions into a smaller number of centroids. + sampled_xyzs, sample_idx = self.downsample_positions(xyzs, feats) + + # For each centroid, form a group with nearby points. + # Also, estimate local point density ("mean distance") for each centroid. + downsample_num, mean_distance, mask, knn_idx = self.get_density( + sampled_xyzs, xyzs + ) + + # Embed features, positions, and density for each downsampled + # point group into a single aggregate vector for that group. 
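+ # (The aggregation happens in downsample_features below: the three
+ # embeddings are concatenated along channels, fused by a 1x1 convolution,
+ # and added to a residual of the sampled input features.)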
+
+ sampled_feats = self.downsample_features(
+ sampled_xyzs, xyzs, feats, downsample_num, sample_idx, knn_idx, mask
+ )
+
+ return sampled_xyzs, sampled_feats, downsample_num, mean_distance
+
+ def downsample_positions(self, xyzs, feats=None):
+ # NOTE: feats is unused; it is accepted only so that the call in
+ # forward() mirrors downsample_features().
+ _, _, n = xyzs.shape
+ sample_num = round(n * self.downsample_rate)
+ xyzs_tr = xyzs.permute(0, 2, 1).contiguous()
+ sample_idx = pointops.furthestsampling(xyzs_tr, sample_num).long()
+ sampled_xyzs = index_points(xyzs, sample_idx)
+ return sampled_xyzs, sample_idx
+
+ def downsample_features(
+ self, sampled_xyzs, xyzs, feats, downsample_num, sample_idx, knn_idx, mask
+ ):
+ # sampled_xyzs: (b, 3, m)
+ # sampled_feats: (b, c, m)
+
+ identity = index_points(feats, sample_idx)
+
+ feats = self.pre_conv(feats)
+ sampled_feats = index_points(feats, sample_idx)
+ embeddings = [
+ self.embed_features(
+ sampled_xyzs, xyzs, sampled_feats, feats, feats, knn_idx, mask
+ ),
+ self.embed_positions(sampled_xyzs, xyzs, knn_idx, mask),
+ self.embed_densities(downsample_num.unsqueeze(1)),
+ ]
+ agg_embedding = self.post_conv(torch.cat(embeddings, dim=1))
+
+ sampled_feats_new = agg_embedding + identity
+ return sampled_feats_new
+
+
+class PointTransformerLayer(nn.Module):
+ """Point Transformer layer introduced by [Zhao2021]_.
+
+ References:
+
+ .. [Zhao2021] `"Point Transformer"
+ `_, by Hengshuang Zhao,
+ Li Jiang, Jiaya Jia, Philip Torr, and Vladlen Koltun,
+ CVPR 2021.
+ """
+
+ def __init__(self, in_fdim, out_fdim, hidden_dim, ngroups):
+ super().__init__()
+
+ self.w_qs = nn.Conv1d(in_fdim, hidden_dim, 1)
+ self.w_ks = nn.Conv1d(in_fdim, hidden_dim, 1)
+ self.w_vs = nn.Conv1d(in_fdim, hidden_dim, 1)
+
+ self.conv_delta = nn.Sequential(
+ nn.Conv2d(3, hidden_dim, 1),
+ nn.GroupNorm(ngroups, hidden_dim),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(hidden_dim, hidden_dim, 1),
+ )
+
+ self.conv_gamma = nn.Sequential(
+ nn.Conv2d(hidden_dim, hidden_dim, 1),
+ nn.GroupNorm(ngroups, hidden_dim),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(hidden_dim, hidden_dim, 1),
+ )
+
+ self.post_conv = nn.Conv1d(hidden_dim, out_fdim, 1)
+
+ def forward(self, q_xyzs, k_xyzs, q_feats, k_feats, v_feats, knn_idx, mask):
+ # q: (b, c, m)
+ # k: (b, c, n)
+ # knn_idx: (b, m, k)
+ # mask: (b, m, k)
+ # knn_xyzs: (b, 3, m, k)
+ # query: (b, c, m)
+ # key: (b, c, m, k)
+ # pos_enc: (b, c, m, k)
+ # attn: (b, c, m, k)
+
+ knn_xyzs = index_points(k_xyzs, knn_idx)
+
+ # NOTE: it's q_feats, not v_feats!
+ identity = q_feats
+
+ query = self.w_qs(q_feats)
+ key = index_points(self.w_ks(k_feats), knn_idx)
+ value = index_points(self.w_vs(v_feats), knn_idx)
+
+ pos_enc = self.conv_delta(q_xyzs.unsqueeze(-1) - knn_xyzs)
+
+ attn = self.conv_gamma(query.unsqueeze(-1) - key + pos_enc)
+ attn = attn / math.sqrt(key.shape[1])
+ mask_value = -(torch.finfo(attn.dtype).max)
+ attn.masked_fill_(~mask[:, None], mask_value)
+ attn = F.softmax(attn, dim=-1)
+
+ result = torch.einsum("bcmk, bcmk -> bcm", attn, value + pos_enc)
+ result = self.post_conv(result) + identity
+
+ return result
+
+
+class PositionEmbeddingLayer(nn.Module):
+ """Position embedding for downsampling, as introduced in [He2022pcc]_.
+
+ For each group of feature vectors (f₁, ..., fₖ) with centroid fₒ,
+ represents the offsets (f₁ - fₒ, ..., fₖ - fₒ) as
+ magnitude-direction vectors, then applies an MLP to each vector,
+ then takes a softmax self-attention over the resulting vectors,
+ and finally reduces the vectors via a sum,
+ resulting in a single embedded vector for the group.
+
+ References:
+
+ ..
[He2022pcc] `"Density-preserving Deep Point Cloud Compression" + `_, by Yun He, Xinlin Ren, + Danhang Tang, Yinda Zhang, Xiangyang Xue, and Yanwei Fu, + CVPR 2022. + """ + + def __init__(self, hidden_dim, dim, ngroups): + super().__init__() + + self.embed_positions = nn.Sequential( + nn.Conv2d(4, hidden_dim, 1), + nn.GroupNorm(ngroups, hidden_dim), + nn.ReLU(inplace=True), + nn.Conv2d(hidden_dim, dim, 1), + ) + + self.attention = nn.Sequential( + nn.Conv2d(dim, hidden_dim, 1), + nn.GroupNorm(ngroups, hidden_dim), + nn.ReLU(inplace=True), + nn.Conv2d(hidden_dim, dim, 1), + ) + + def forward(self, q_xyzs, k_xyzs, knn_idx, mask): + # q_xyzs: (b, 3, m) + # k_xyzs: (b, 3, n) + # knn_idx: (b, m, k) + # mask: (b, m, k) + # knn_xyzs: (b, 3, m, k) + # repeated_xyzs: (b, 3, m, k) + # direction: (b, 3, m, k) + # distance: (b, 1, m, k) + # local_pattern: (b, 4, m, k) + # position_embedding_expanded: (b, c, m, k) + # attn: (b, c, m, k) + # position_embedding: (b, c, m) + + # "query" (q_xyzs) points are the centroids of each point group. + # "key" (k_xyzs) points are points in the neighborhood of each centroid. + + _, _, k = knn_idx.shape + knn_xyzs = index_points(k_xyzs, knn_idx) + repeated_xyzs = q_xyzs[..., None].repeat(1, 1, 1, k) + + # Represent points within a point group as (direction, distance) + # of offsets from the group centroid. + offset_xyzs = knn_xyzs - repeated_xyzs + direction = F.normalize(offset_xyzs, p=2, dim=1) + distance = torch.linalg.norm(offset_xyzs, dim=1, keepdim=True) + local_pattern = torch.cat((direction, distance), dim=1) + + # Apply a pointwise MLP to each point. + position_embedding_expanded = self.embed_positions(local_pattern) + + # Compute self-attention, ignoring points that are not in the + # neighborhood of the centroid. + attn = self.attention(position_embedding_expanded) + mask_value = -(torch.finfo(attn.dtype).max) + attn.masked_fill_(~mask[:, None], mask_value) + attn = F.softmax(attn, dim=-1) + position_embedding = (position_embedding_expanded * attn).sum(dim=-1) + + return position_embedding + + +class DensityEmbeddingLayer(nn.Module): + """Density embedding for downsampling, as introduced in [He2022pcc]_. + + Applies an embedding ℝ → ℝᶜ to the local point density (scalar). + The local point density is measured using the mean distance of the + points within the neighborhood of a "downsampled" centroid. + This information is useful when upsampling from the single centroid. + + References: + + .. [He2022pcc] `"Density-preserving Deep Point Cloud Compression" + `_, by Yun He, Xinlin Ren, + Danhang Tang, Yinda Zhang, Xiangyang Xue, and Yanwei Fu, + CVPR 2022. + """ + + def __init__(self, hidden_dim, dim, ngroups): + super().__init__() + self.embed_densities = nn.Sequential( + nn.Conv1d(1, hidden_dim, 1), + nn.GroupNorm(ngroups, hidden_dim), + nn.ReLU(inplace=True), + nn.Conv1d(hidden_dim, dim, 1), + ) + + def forward(self, downsample_num): + # downsample_num: (b, 1, n) + # density_embedding: (b, c, n) + density_embedding = self.embed_densities(downsample_num) + return density_embedding + + +class UpsampleLayer(nn.Module): + """Upsampling layer used in [He2022pcc]_. + + Upsamples many candidate points from a smaller number of centroids. + (Not all candidate upsampled points will be kept; some will be + thrown away to match the predicted local point density.) + + References: + + .. [He2022pcc] `"Density-preserving Deep Point Cloud Compression" + `_, by Yun He, Xinlin Ren, + Danhang Tang, Yinda Zhang, Xiangyang Xue, and Yanwei Fu, + CVPR 2022. 
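+
+ Example (illustrative sizes only; a CUDA device is assumed, since the
+ direction hypothesis in :class:`XyzsUpsampleLayer` is created on the
+ GPU)::
+
+ layer = UpsampleLayer(dim=64, hidden_dim=64, k=16,
+ sub_point_conv_mode="mlp", upsample_rate=4).cuda()
+ xyzs = torch.rand(2, 3, 100).cuda() # (b, 3, n)
+ feats = torch.rand(2, 64, 100).cuda() # (b, c, n)
+ up_xyzs, up_feats = layer(xyzs, feats)
+ # up_xyzs: (2, 3, 100, 4) == (b, 3, n, upsample_rate)
+ # up_feats: (2, 64, 100, 4) == (b, c, n, upsample_rate)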
+ """ + + def __init__(self, dim, hidden_dim, k, sub_point_conv_mode, upsample_rate): + super().__init__() + self.xyzs_upsample_nn = XyzsUpsampleLayer( + dim, hidden_dim, k, sub_point_conv_mode, upsample_rate + ) + self.feats_upsample_nn = FeatsUpsampleLayer( + dim, hidden_dim, k, sub_point_conv_mode, upsample_rate + ) + + def forward(self, xyzs, feats): + upsampled_xyzs = self.xyzs_upsample_nn(xyzs, feats) + upsampled_feats = self.feats_upsample_nn(feats) + return upsampled_xyzs, upsampled_feats + + +class UpsampleNumLayer(nn.Module): + """Predicts local point density while upsampling, as used in [He2022pcc]_. + + Extracts the number of candidate points to keep after upsampling + from a given "centroid" feature vector. + (Some candidate upsampled points will be thrown away to match the + predicted local point density.) + + References: + + .. [He2022pcc] `"Density-preserving Deep Point Cloud Compression" + `_, by Yun He, Xinlin Ren, + Danhang Tang, Yinda Zhang, Xiangyang Xue, and Yanwei Fu, + CVPR 2022. + """ + + def __init__(self, dim, hidden_dim, upsample_rate): + super().__init__() + self.upsample_rate = upsample_rate + self.upsample_num_nn = nn.Sequential( + nn.Conv1d(dim, hidden_dim, 1), + nn.ReLU(), + nn.Conv1d(hidden_dim, 1, 1), + nn.Sigmoid(), + ) + + def forward(self, feats): + # upsample_num: (b, n) + upsample_frac = self.upsample_num_nn(feats).squeeze(1) + upsample_num = upsample_frac * (self.upsample_rate - 1) + 1 + return upsample_num + + +class RefineLayer(nn.Module): + """Refines upsampled points, as used in [He2022pcc]_. + + After the centroids are upsampled, there may be overlapping + point groups between nearby centroids, and other artifacts. + Refinement should help correct various such artifacts. + + References: + + .. [He2022pcc] `"Density-preserving Deep Point Cloud Compression" + `_, by Yun He, Xinlin Ren, + Danhang Tang, Yinda Zhang, Xiangyang Xue, and Yanwei Fu, + CVPR 2022. + """ + + def __init__(self, dim, hidden_dim, k, sub_point_conv_mode, decompress_normal): + super().__init__() + + self.xyzs_refine_nn = XyzsUpsampleLayer( + dim, + hidden_dim, + k, + sub_point_conv_mode, + upsample_rate=1, + ) + + self.feats_refine_nn = FeatsUpsampleLayer( + dim, + hidden_dim, + k, + sub_point_conv_mode, + upsample_rate=1, + decompress_normal=decompress_normal, + ) + + def forward(self, xyzs, feats): + # refined_xyzs: (b, 3, n, 1) + # refined_xyzs: (b, 3, n) [after rearrange] + # refined_feats: (b, c, n, 1) + # refined_feats: (b, c, n) [after rearrange] + + refined_xyzs = self.xyzs_refine_nn(xyzs, feats) + refined_xyzs = rearrange(refined_xyzs, "b c n u -> b c (n u)") + + refined_feats = self.feats_refine_nn(feats) + refined_feats = rearrange(refined_feats, "b c n u -> b c (n u)") + + return refined_xyzs, refined_feats + + +class XyzsUpsampleLayer(nn.Module): + """Position upsampling layer used in [He2022pcc]_. + + Upsamples many positions from each "centroid" feature vector. + Each feature vector is upsampled into various offsets represented as + magnitude-direction vectors, where each direction is determined by a + weighted sum of various fixed hypothesized directions. + From this, the candidate upsampled positions are simply the + the offset vectors plus their original centroid position. + + References: + + .. [He2022pcc] `"Density-preserving Deep Point Cloud Compression" + `_, by Yun He, Xinlin Ren, + Danhang Tang, Yinda Zhang, Xiangyang Xue, and Yanwei Fu, + CVPR 2022. 
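+
+ The fixed direction set consists of the 42 unit vectors returned by
+ :func:`icosahedron2sphere` at subdivision level 1, plus the zero
+ vector (43 in total), so a centroid may also be reproduced in place.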
+ """ + + def __init__(self, dim, hidden_dim, k, sub_point_conv_mode, upsample_rate): + super().__init__() + + self.upsample_rate = upsample_rate + + # The hypothesis is a basis of 43 candidate direction vectors. + hypothesis, _ = icosahedron2sphere(1) + hypothesis = np.append(np.zeros((1, 3)), hypothesis, axis=0) + self.hypothesis = torch.from_numpy(hypothesis).float().cuda() + + self.weight_nn = SubPointConv( + hidden_dim, k, sub_point_conv_mode, dim, 43 * upsample_rate, upsample_rate + ) + + self.scale_nn = SubPointConv( + hidden_dim, k, sub_point_conv_mode, dim, 1 * upsample_rate, upsample_rate + ) + + def forward(self, xyzs, feats): + # xyzs: (b, 3, n) + # feats: (b, c, n) + # weights: (b, 43, n, u) + # weights: (b, 43, 1, n, u) [after unsqueeze] + # hypothesis: (b, 43, 3, n, u) + # directions: (b, 3, n, u) + # scales: (b, 1, n, u) + # deltas: (b, 3, n, u) + # repeated_xyzs: (b, 3, n, u) + + batch_size = xyzs.shape[0] + points_num = xyzs.shape[2] + + weights = self.weight_nn(feats) + weights = weights.unsqueeze(2) + weights = F.softmax(weights, dim=1) + + hypothesis = repeat( + self.hypothesis, + "h c -> b h c n u", + b=batch_size, + n=points_num, + u=self.upsample_rate, + ) + weighted_hypothesis = weights * hypothesis + directions = torch.sum(weighted_hypothesis, dim=1) + directions = F.normalize(directions, p=2, dim=1) + + scales = self.scale_nn(feats) + + deltas = directions * scales + + repeated_xyzs = repeat(xyzs, "b c n -> b c n u", u=self.upsample_rate) + upsampled_xyzs = repeated_xyzs + deltas + + return upsampled_xyzs + + +class FeatsUpsampleLayer(nn.Module): + """Feature upsampling layer used in [He2022pcc]_. + + Upsamples many features from each "centroid" feature vector. + The feature vector associated with each centroid is upsampled + into various candidate upsampled feature vectors. + + References: + + .. [He2022pcc] `"Density-preserving Deep Point Cloud Compression" + `_, by Yun He, Xinlin Ren, + Danhang Tang, Yinda Zhang, Xiangyang Xue, and Yanwei Fu, + CVPR 2022. + """ + + def __init__( + self, + dim, + hidden_dim, + k, + sub_point_conv_mode, + upsample_rate, + decompress_normal=False, + ): + super().__init__() + + self.upsample_rate = upsample_rate + self.decompress_normal = decompress_normal + out_fdim = (3 if decompress_normal else dim) * upsample_rate + + self.feats_nn = SubPointConv( + hidden_dim, k, sub_point_conv_mode, dim, out_fdim, upsample_rate + ) + + def forward(self, feats): + # upsampled_feats: (b, c, n, u) + upsampled_feats = self.feats_nn(feats) + if not self.decompress_normal: + repeated_feats = repeat(feats, "b c n -> b c n u", u=self.upsample_rate) + upsampled_feats = upsampled_feats + repeated_feats + return upsampled_feats + + +class SubPointConv(nn.Module): + """Sub-point convolution for upsampling, as introduced in [He2022pcc]_. + + Each feature vector (representing a "centroid" point) is sliced + into g feature vectors, where each feature vector represents a + point that has been upsampled from the original centroid point. + Then, an MLP is applied to each slice individually. + + References: + + .. [He2022pcc] `"Density-preserving Deep Point Cloud Compression" + `_, by Yun He, Xinlin Ren, + Danhang Tang, Yinda Zhang, Xiangyang Xue, and Yanwei Fu, + CVPR 2022. 
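+
+ For example (illustrative numbers), with ``in_fdim=64``,
+ ``out_fdim=256``, and ``group_num=4``, an input of shape ``(b, 64, n)``
+ is viewed as 4 groups of 16 channels, and each group is mapped to
+ ``256 / 4 = 64`` output channels, giving an output of shape
+ ``(b, 64, n, 4)``. In "mlp" mode, the same two-layer MLP weights are
+ shared across the groups.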
+ """ + + def __init__(self, hidden_dim, k, mode, in_fdim, out_fdim, group_num): + super().__init__() + + self.mode = mode + self.group_num = group_num + group_in_fdim = in_fdim // group_num + group_out_fdim = out_fdim // group_num + + if self.mode == "mlp": + self.mlp = nn.Sequential( + nn.Conv2d(group_in_fdim, hidden_dim, 1), + nn.ReLU(inplace=True), + nn.Conv2d(hidden_dim, group_out_fdim, 1), + ) + elif self.mode == "edge_conv": + self.edge_conv = EdgeConv(in_fdim, out_fdim, hidden_dim, k) + else: + raise ValueError(f"Unknown mode: {self.mode}") + + def forward(self, feats): + # feats: (b, cin * g, n) + # expanded_feats: (b, cout, n, g) + + if self.mode == "mlp": + feats = rearrange( + feats, "b (c g) n -> b c n g", g=self.group_num + ).contiguous() + expanded_feats = self.mlp(feats) + elif self.mode == "edge_conv": + expanded_feats = self.edge_conv(feats) + expanded_feats = rearrange( + expanded_feats, "b (c g) n -> b c n g", g=self.group_num + ).contiguous() + else: + raise ValueError(f"Unknown mode: {self.mode}") + + return expanded_feats + + +class EdgeConv(nn.Module): + """EdgeConv introduced by [Wang2019dgcnn]_. + + First, groups similar feature vectors together via k-nearest neighbors + using the following distance metric between feature vectors fᵢ and fⱼ: + distance[i, j] = 2fᵢᵀfⱼ - ||fᵢ||² - ||fⱼ||². + + Then, for each group of feature vectors (f₁, ..., fₖ) with centroid fₒ, + the residual feature vectors are each concatenated with the centroid, + then an MLP is applied to each resulting vector individually, + i.e., (MLP(f₁ - fₒ, fₒ), ..., MLP(fₖ - fₒ, fₒ)), + and finally the elementwise max is taken across the resulting vectors, + resulting in a single vector fₘₐₓ for the group. + + Original code located at [DGCNN]_ under MIT License. + + References: + + .. [Wang2019dgcnn] `"Dynamic Graph CNN for Learning on Point Clouds" + `_, by Yue Wang, Yongbin + Sun, Ziwei Liu, Sanjay E. Sarma, Michael M. Bronstein, + Justin M. Solomon, ACM Transactions on Graphics 2019. + + .. [DGCNN] `DGCNN + `_ + """ + + def __init__(self, in_fdim, out_fdim, hidden_dim, k): + super().__init__() + self.k = k + self.conv = nn.Sequential( + nn.Conv2d(2 * in_fdim, hidden_dim, 1), + nn.ReLU(), + nn.Conv2d(hidden_dim, out_fdim, 1), + ) + + # WARN: This requires at least O(n^2) memory. + def knn(self, feats, k): + # feats: (b, c, n) + # sq_norm: (b, 1, n) + # pairwise_dot: (b, n, n) + # pairwise_distance: (b, n, n) + # knn_idx: (b, n, k) + + sq_norm = (feats**2).sum(dim=1, keepdim=True) + + # Pairwise dot product fᵢᵀfⱼ between feature vectors fᵢ and fⱼ. 
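+ # Since -||fᵢ - fⱼ||² = 2fᵢᵀfⱼ - ||fᵢ||² - ||fⱼ||², the pairwise
+ # "distance" computed below is the negative squared distance, so
+ # .topk() returns the k nearest neighbors.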
+ pairwise_dot = torch.matmul(feats.transpose(2, 1), feats) + + # pairwise_distance[i, j] = 2fᵢᵀfⱼ - ||fᵢ||² - ||fⱼ||² + pairwise_distance = 2 * pairwise_dot - sq_norm - sq_norm.transpose(2, 1) + + _, knn_idx = pairwise_distance.topk(k=k, dim=-1) + return knn_idx + + def get_graph_features(self, feats, k): + # knn_feats: (b, c, n, k) + # graph_feats: (b, 2c, n, k) + dim = feats.shape[1] + if dim == 3: + feats_tr = feats.permute(0, 2, 1).contiguous() + knn_idx = pointops.knnquery_heap(k, feats_tr, feats_tr).long() + else: + knn_idx = self.knn(feats, k) + torch.cuda.empty_cache() + knn_feats = index_points(feats, knn_idx) + repeated_feats = repeat(feats, "b c n -> b c n k", k=k) + graph_feats = torch.cat((knn_feats - repeated_feats, repeated_feats), dim=1) + return graph_feats + + def forward(self, feats): + # feats: (b, c, n) + # graph_feats: (b, 2c, n, k) + # expanded_feats: (b, cout*g, n, k) + # feats_new: (b, cout*g, n) + _, _, n = feats.shape + graph_feats = self.get_graph_features(feats, k=min(self.k, n)) + expanded_feats = self.conv(graph_feats) + feats_new, _ = expanded_feats.max(dim=-1) + return feats_new + + +def icosahedron2sphere(level): + """Samples uniformly on a sphere using a icosahedron. + + Code adapted from [IcoSphere_MATLAB]_ and [IcoSphere_Python]_, + from paper [Xiao2009]_. + + References: + + .. [Xiao2009] `"Image-based street-side city modeling" + `_, + by Jianxiong Xiao, Tian Fang, Peng Zhao, Maxime Lhuillier, + and Long Quan, ACM Transactions on Graphics, 2009. + + .. [IcoSphere_MATLAB] https://github.com/jianxiongxiao/ProfXkit/blob/master/icosahedron2sphere/icosahedron2sphere.m + + .. [IcoSphere_Python] https://github.com/23michael45/PanoContextTensorflow/blob/master/PanoContextTensorflow/icosahedron2sphere.py + """ + a = 2 / (1 + np.sqrt(5)) + + # fmt: off + M = np.array([ + 0, a, -1, a, 1, 0, -a, 1, 0, # noqa: E241, E126 + 0, a, 1, -a, 1, 0, a, 1, 0, # noqa: E241 + 0, a, 1, 0, -a, 1, -1, 0, a, # noqa: E241 + 0, a, 1, 1, 0, a, 0, -a, 1, # noqa: E241 + 0, a, -1, 0, -a, -1, 1, 0, -a, # noqa: E241 + 0, a, -1, -1, 0, -a, 0, -a, -1, # noqa: E241 + 0, -a, 1, a, -1, 0, -a, -1, 0, # noqa: E241 + 0, -a, -1, -a, -1, 0, a, -1, 0, # noqa: E241 + -a, 1, 0, -1, 0, a, -1, 0, -a, # noqa: E241, E131 + -a, -1, 0, -1, 0, -a, -1, 0, a, # noqa: E241 + a, 1, 0, 1, 0, -a, 1, 0, a, # noqa: E241 + a, -1, 0, 1, 0, a, 1, 0, -a, # noqa: E241 + 0, a, 1, -1, 0, a, -a, 1, 0, # noqa: E241 + 0, a, 1, a, 1, 0, 1, 0, a, # noqa: E241 + 0, a, -1, -a, 1, 0, -1, 0, -a, # noqa: E241 + 0, a, -1, 1, 0, -a, a, 1, 0, # noqa: E241 + 0, -a, -1, -1, 0, -a, -a, -1, 0, # noqa: E241 + 0, -a, -1, a, -1, 0, 1, 0, -a, # noqa: E241 + 0, -a, 1, -a, -1, 0, -1, 0, a, # noqa: E241 + 0, -a, 1, 1, 0, a, a, -1, 0, # noqa: E241 + ]) + # fmt: on + + coor = M.reshape(60, 3) + coor, idx = np.unique(coor, return_inverse=True, axis=0) + tri = idx.reshape(20, 3) + + # extrude + coor_norm = np.linalg.norm(coor, axis=1, keepdims=True) + coor = list(coor / np.tile(coor_norm, (1, 3))) + + for _ in range(level): + tris = [] + + for t in range(len(tri)): + n = len(coor) + coor.extend( + [ + (coor[tri[t, 0]] + coor[tri[t, 1]]) / 2, + (coor[tri[t, 1]] + coor[tri[t, 2]]) / 2, + (coor[tri[t, 2]] + coor[tri[t, 0]]) / 2, + ] + ) + tris.extend( + [ + [n, tri[t, 0], n + 2], + [n, tri[t, 1], n + 1], + [n + 1, tri[t, 2], n + 2], + [n, n + 1, n + 2], + ] + ) + + tri = np.array(tris) + + # uniquefy + coor, idx = np.unique(coor, return_inverse=True, axis=0) + tri = idx[tri] + + # extrude + coor_norm = np.linalg.norm(coor, axis=1, keepdims=True) + 
coor = list(coor / np.tile(coor_norm, (1, 3))) + + return np.array(coor), np.array(tri) + + +def nearby_distance_sum(a_xyzs, b_xyzs, k): + """Computes sum of nearby distances to B for each point in A. + + Partitions a point set B into non-intersecting sets + C(a_1), ..., C(a_m) where each C(a_i) contains points that are + nearest to a_i ∈ A. + For each a_i ∈ A, computes the total distance from a_i to C(a_i). + (Note that C(a_1), ..., C(a_m) may not cover all of B.) + + In more precise terms: + For each a ∈ A, let C(a) ⊆ B denote its "collapsed point set" s.t. + (i) b ∈ C(a) ⇒ min_{a' ∈ A} ||a' - b|| = ||a - b||, + (ii) ⋃ _{a ∈ A} C(a) ⊆ B, + (iii) ⋂ _{a ∈ A} C(a) = ∅, and + (iv) |C(a)| ≤ k. + For each a ∈ A, we then compute d(a) = ∑_{b ∈ C(a)} ||a - b||. + + Args: + a_xyzs: (b, 3, m) Input point set A. + b_xyzs: (b, 3, n) Input point set B. + k: Maximum number of points in each collapsed point set C(a_i). + + Returns: + distance: (b, m) Sum of distances from each point in A to its + collapsed point set. + mask: (b, m, k) Mask indicating which points in the ``knn_idx`` + belong to the collapsed point set of each point in A. + knn_idx: (b, m, k) Indices of the points in B that are nearest + to each point in A. + nn_idx: (b, n, 1) Indices of the point in A that is nearest + to each point in B. + """ + # expect_idx: (b, m, k) + # actual_idx: (b, m, k) + # knn_xyzs: (b, 3, m, k) + # knn_distances: (b, m, k) + + device = a_xyzs.device + _, _, m = a_xyzs.shape + a_xyzs_tr = a_xyzs.permute(0, 2, 1).contiguous() + b_xyzs_tr = b_xyzs.permute(0, 2, 1).contiguous() + + # Determine which point in A each point in B is closest to. + nn_idx = pointops.knnquery_heap(1, a_xyzs_tr, b_xyzs_tr) + nn_idx_tr = nn_idx.permute(0, 2, 1).contiguous() + + # Determine k nearest neighbors in B for each point in A. + knn_idx = pointops.knnquery_heap(k, b_xyzs_tr, a_xyzs_tr).long() + torch.cuda.empty_cache() + + # Mask points that do not belong to the collapsed points set C(a). + expect_idx = torch.arange(m, device=device)[None, :, None] + actual_idx = index_points(nn_idx_tr, knn_idx).squeeze(1) + mask = expect_idx == actual_idx + + # Compute the distance from each A point to its k nearest neighbors in B. + knn_xyzs = index_points(b_xyzs, knn_idx) + knn_distances = torch.linalg.norm(knn_xyzs - a_xyzs[..., None], dim=1) + + # Zero away the distances for points that are not in C(a). + # knn_distances.masked_fill_(~mask, 0) + knn_distances = knn_distances * mask.float() + + # Compute masked distances. + # Notably, distance.sum(-1) is upper bounded by the sum of the + # distances from each point in B to its nearest point in A. + distance = knn_distances.sum(dim=-1) + + return distance, mask, knn_idx, nn_idx diff --git a/compressai/layers/pointcloud/pointnet.py b/compressai/layers/pointcloud/pointnet.py new file mode 100644 index 00000000..aec9c424 --- /dev/null +++ b/compressai/layers/pointcloud/pointnet.py @@ -0,0 +1,89 @@ +# Copyright (c) 2021-2022, InterDigital Communications, Inc +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted (subject to the limitations in the disclaimer +# below) provided that the following conditions are met: + +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. 
+# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * Neither the name of InterDigital Communications, Inc nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. + +# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY +# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND +# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT +# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR +# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF +# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from __future__ import annotations + +import torch.nn as nn + +from compressai.layers.basic import Gain, Interleave, Reshape + +GAIN = 10.0 + + +def conv1d_group_seq( + num_channels, + groups=None, + kernel_size=1, + enabled=("bn", "act"), + enabled_final=("bn", "act"), +): + if groups is None: + groups = [1] * (len(num_channels) - 1) + assert len(num_channels) == 0 or len(groups) == len(num_channels) - 1 + xs = [] + for i in range(len(num_channels) - 1): + is_final = i + 1 == len(num_channels) - 1 + xs.append( + nn.Conv1d( + num_channels[i], num_channels[i + 1], kernel_size, groups=groups[i] + ) + ) + # ChannelShuffle is only required between consecutive group convs. + if not is_final and groups[i] > 1 and groups[i + 1] > 1: + xs.append(Interleave(groups[i])) + if "bn" in enabled and (not is_final or "bn" in enabled_final): + xs.append(nn.BatchNorm1d(num_channels[i + 1])) + if "act" in enabled and (not is_final or "act" in enabled_final): + xs.append(nn.ReLU(inplace=True)) + return nn.Sequential(*xs) + + +def pointnet_g_a_simple(num_channels, groups=None, gain=GAIN): + return nn.Sequential( + *conv1d_group_seq(num_channels, groups), + nn.AdaptiveMaxPool1d(1), + Gain((num_channels[-1], 1), gain), + ) + + +def pointnet_g_s_simple(num_channels, gain=GAIN): + return nn.Sequential( + Gain((num_channels[0], 1), 1 / gain), + *conv1d_group_seq(num_channels, enabled=["act"], enabled_final=[]), + Reshape((num_channels[-1] // 3, 3)), + ) + + +def pointnet_classification_backend(num_channels): + return nn.Sequential( + *conv1d_group_seq(num_channels[:-1], enabled_final=[]), + nn.Dropout(0.3), + nn.Conv1d(num_channels[-2], num_channels[-1], 1), + Reshape((num_channels[-1],)), + ) diff --git a/compressai/layers/pointcloud/pointnet2.py b/compressai/layers/pointcloud/pointnet2.py new file mode 100644 index 00000000..e9b7774d --- /dev/null +++ b/compressai/layers/pointcloud/pointnet2.py @@ -0,0 +1,469 @@ +# Copyright (c) 2021-2022, InterDigital Communications, Inc +# All rights reserved. 
+
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted (subject to the limitations in the disclaimer
+# below) provided that the following conditions are met:
+
+# * Redistributions of source code must retain the above copyright notice,
+# this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+# * Neither the name of InterDigital Communications, Inc nor the names of its
+# contributors may be used to endorse or promote products derived from this
+# software without specific prior written permission.
+
+# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY
+# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
+# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT
+# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
+# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
+# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
+# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
+# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+# NOTE: This module has been adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch
+#
+# See https://github.com/yanx27/Pointnet_Pointnet2_pytorch/blob/master/LICENSE.
+# LICENSE is also reproduced below:
+#
+#
+# MIT License
+#
+# Copyright (c) 2019 benny
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+from concurrent.futures.thread import ThreadPoolExecutor
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def pc_normalize(pc):
+ centroid = np.mean(pc, axis=0)
+ pc = pc - centroid
+ m = np.max(np.sqrt(np.sum(pc**2, axis=1)))
+ pc = pc / m
+ return pc
+
+
+def square_distance(src, dst):
+ """
+ Calculate the squared Euclidean distance between each pair of points.
+ + src^T * dst = xn * xm + yn * ym + zn * zm; + sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn; + sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm; + dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2 + = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst + + Input: + src: source points, [B, N, C] + dst: target points, [B, M, C] + Output: + dist: per-point square distance, [B, N, M] + """ + B, N, _ = src.shape + _, M, _ = dst.shape + dist = -2 * torch.matmul(src, dst.permute(0, 2, 1)) + dist += torch.sum(src**2, -1).view(B, N, 1) + dist += torch.sum(dst**2, -1).view(B, 1, M) + return dist + + +def index_points(points, idx): + """ + + Input: + points: input points data, [B, N, C] + idx: sample index data, [B, S] + Return: + new_points:, indexed points data, [B, S, C] + """ + device = points.device + B = points.shape[0] + view_shape = list(idx.shape) + view_shape[1:] = [1] * (len(view_shape) - 1) + repeat_shape = list(idx.shape) + repeat_shape[0] = 1 + batch_indices = ( + torch.arange(B, dtype=torch.long) + .to(device) + .view(view_shape) + .repeat(repeat_shape) + ) + new_points = points[batch_indices, idx, :] + return new_points + + +def farthest_point_sample(xyz, npoint, _method="pointops"): + """ + Input: + xyz: pointcloud data, [B, N, 3] + npoint: number of samples + Return: + centroids: sampled pointcloud index, [B, npoint] + """ + if _method == "pointops": + from pointops.functions import pointops + + xyz_tr = xyz.permute(0, 2, 1).contiguous() + indices = pointops.furthestsampling(xyz_tr, npoint).long() + return indices + + if _method == "pytorch3d": + from pytorch3d.ops import sample_farthest_points + + sampled_points, indices = sample_farthest_points(xyz, K=npoint) + return indices + + if _method.startswith("fpsample"): + import fpsample + + if _method == "fpsample.npdu_kdtree": + func = fpsample.fps_npdu_kdtree_sampling + if _method == "fpsample.bucket": + func = lambda *args: fpsample.bucket_fps_kdline_sampling( # noqa: E731 + *args, h=5 + ) + + with ThreadPoolExecutor(max_workers=min(8, len(xyz))) as executor: + indices = list(executor.map(lambda pc: func(pc, npoint), xyz.cpu().numpy())) + + indices = torch.from_numpy(np.stack(indices, dtype=np.int64)).to(xyz.device) + return indices + + if _method == "yanx27": + return _farthest_point_sample_yanx27(xyz, npoint) + + raise ValueError(f"Unknown method {_method}") + + +def _farthest_point_sample_yanx27(xyz, npoint): + device = xyz.device + B, N, C = xyz.shape + centroids = torch.zeros(B, npoint, dtype=torch.long).to(device) + distance = torch.ones(B, N).to(device) * 1e10 + farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device) + batch_indices = torch.arange(B, dtype=torch.long).to(device) + for i in range(npoint): + centroids[:, i] = farthest + centroid = xyz[batch_indices, farthest, :].view(B, 1, 3) + dist = torch.sum((xyz - centroid) ** 2, -1) + mask = dist < distance + distance[mask] = dist[mask] + farthest = torch.max(distance, -1)[1] + return centroids + + +def query_ball_point(radius, nsample, xyz, new_xyz, _method="pointops"): + """ + Input: + radius: local region radius + nsample: max sample number in local region + xyz: all points, [B, N, 3] + new_xyz: query points, [B, S, 3] + Return: + group_idx: grouped points index, [B, S, nsample] + """ + if _method == "pointops": + from pointops.functions import pointops + + idx = pointops.ballquery(radius, nsample, xyz, new_xyz).long() + return idx + + if _method == "pytorch3d": + from pytorch3d.ops import ball_query + + dists, idx, neighbors = ball_query( + new_xyz, xyz, K=nsample, 
radius=radius, return_nn=False
+ )
+ return idx
+
+ if _method == "yanx27":
+ return _query_ball_point_yanx27(radius, nsample, xyz, new_xyz)
+
+ raise ValueError(f"Unknown method {_method}")
+
+
+def _query_ball_point_yanx27(radius, nsample, xyz, new_xyz):
+ device = xyz.device
+ B, N, C = xyz.shape
+ _, S, _ = new_xyz.shape
+ group_idx = (
+ torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1])
+ )
+ sqrdists = square_distance(new_xyz, xyz)
+ group_idx[sqrdists > radius**2] = N
+ group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]
+ group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
+ mask = group_idx == N
+ group_idx[mask] = group_first[mask]
+ return group_idx
+
+
+def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False):
+ """
+ Input:
+ npoint: number of sampled centroids
+ radius: local region radius
+ nsample: max sample number in local region
+ xyz: input points position data, [B, N, 3]
+ points: input points data, [B, N, D]
+ Return:
+ new_xyz: sampled points position data, [B, npoint, 3]
+ new_points: sampled points data, [B, npoint, nsample, 3+D]
+ """
+ B, N, C = xyz.shape
+ S = npoint
+ fps_idx = farthest_point_sample(xyz, npoint) # [B, npoint]
+ new_xyz = index_points(xyz, fps_idx)
+ idx = query_ball_point(radius, nsample, xyz, new_xyz)
+ grouped_xyz = index_points(xyz, idx) # [B, npoint, nsample, C]
+ grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C)
+
+ if points is not None:
+ grouped_points = index_points(points, idx)
+ new_points = torch.cat(
+ [grouped_xyz_norm, grouped_points], dim=-1
+ ) # [B, npoint, nsample, C+D]
+ else:
+ new_points = grouped_xyz_norm
+ if returnfps:
+ return new_xyz, new_points, grouped_xyz, fps_idx
+ else:
+ return new_xyz, new_points
+
+
+def sample_and_group_all(xyz, points, returnfps=False):
+ """
+ Input:
+ xyz: input points position data, [B, N, 3]
+ points: input points data, [B, N, D]
+ Return:
+ new_xyz: sampled points position data, [B, 1, 3]
+ new_points: sampled points data, [B, 1, N, 3+D]
+ """
+ device = xyz.device
+ B, N, C = xyz.shape
+ fps_idx = torch.zeros(B, 1, dtype=torch.long).to(device)
+ new_xyz = torch.zeros(B, 1, C).to(device)
+ grouped_xyz = xyz.view(B, 1, N, C)
+ if points is not None:
+ new_points = torch.cat([grouped_xyz, points.view(B, 1, N, -1)], dim=-1)
+ else:
+ new_points = grouped_xyz
+ if returnfps:
+ return new_xyz, new_points, grouped_xyz, fps_idx
+ else:
+ return new_xyz, new_points
+
+
+class PointNetSetAbstraction(nn.Module):
+ def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all):
+ super(PointNetSetAbstraction, self).__init__()
+ self.npoint = npoint
+ self.radius = radius
+ self.nsample = nsample
+ self.mlp_convs = nn.ModuleList()
+ self.mlp_bns = nn.ModuleList()
+ last_channel = in_channel
+ for out_channel in mlp:
+ self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1))
+ self.mlp_bns.append(nn.BatchNorm2d(out_channel))
+ last_channel = out_channel
+ self.group_all = group_all
+
+ def forward(self, xyz, points):
+ """
+ Input:
+ xyz: input points position data, [B, C, N]
+ points: input points data, [B, D, N]
+ Return:
+ dict with entries:
+ grouped_xyz: grouped points position data, [B, C, npoint, nsample]
+ grouped_features: grouped points feature data, [B, C+D, npoint, nsample]
+ new_xyz: sampled points position data, [B, C, npoint]
+ new_features: sampled points feature data, [B, mlp[-1], npoint]
+ idx: sampled point indices, [B, npoint]
+ """
+ B, C, N = xyz.shape
+ xyz = xyz.permute(0, 2, 1)
+
+ if points is not None:
+ _, D, _ = points.shape
+ points = points.permute(0, 2, 1)
+ else:
+ D = 0
+
+ if self.group_all:
+ new_xyz, grouped_points, grouped_xyz, idx = sample_and_group_all(
+ xyz, points, returnfps=True
+ )
+ npoint = 1
+ nsample = N
+ else:
+ new_xyz, grouped_points, grouped_xyz, idx = sample_and_group(
+ self.npoint, self.radius, self.nsample, xyz, points, returnfps=True
+ )
+ npoint = self.npoint
+ nsample = self.nsample
+
+ assert grouped_xyz.shape == (B, npoint, nsample, C)
+ assert grouped_points.shape == (B, npoint, nsample, C + D)
+ assert new_xyz.shape == (B, npoint, C)
+ assert idx.shape == (B, npoint)
+
+ grouped_xyz = grouped_xyz.permute(0, 3, 1, 2) # [B, C, npoint, nsample]
+ grouped_points = grouped_points.permute(0, 3, 1, 2) # [B, C+D, npoint, nsample]
+ new_xyz = new_xyz.permute(0, 2, 1) # [B, C, npoint]
+
+ new_points = grouped_points.permute(0, 1, 3, 2) # [B, C+D, nsample, npoint]
+ for i, conv in enumerate(self.mlp_convs):
+ bn = self.mlp_bns[i]
+ new_points = F.relu(bn(conv(new_points)))
+ new_points = torch.max(new_points, 2)[0]
+
+ # NOTE: "points" is misleading, so it's relabeled to "features" instead.
+ return {
+ "grouped_xyz": grouped_xyz, # [B, C, npoint, nsample]
+ "grouped_features": grouped_points, # [B, C+D, npoint, nsample]
+ "new_xyz": new_xyz, # [B, C, npoint]
+ "new_features": new_points, # [B, mlp[-1], npoint]
+ "idx": idx, # [B, npoint]
+ }
+
+
+class PointNetSetAbstractionMsg(nn.Module):
+ def __init__(self, npoint, radius_list, nsample_list, in_channel, mlp_list):
+ super(PointNetSetAbstractionMsg, self).__init__()
+ self.npoint = npoint
+ self.radius_list = radius_list
+ self.nsample_list = nsample_list
+ self.conv_blocks = nn.ModuleList()
+ self.bn_blocks = nn.ModuleList()
+ for i in range(len(mlp_list)):
+ convs = nn.ModuleList()
+ bns = nn.ModuleList()
+ last_channel = in_channel # + 3
+ for out_channel in mlp_list[i]:
+ convs.append(nn.Conv2d(last_channel, out_channel, 1))
+ bns.append(nn.BatchNorm2d(out_channel))
+ last_channel = out_channel
+ self.conv_blocks.append(convs)
+ self.bn_blocks.append(bns)
+
+ def forward(self, xyz, points):
+ """
+ Input:
+ xyz: input points position data, [B, C, N]
+ points: input points data, [B, D, N]
+ Return:
+ new_xyz: sampled points position data, [B, C, S]
+ new_points_concat: sample points feature data, [B, D', S]
+ """
+ xyz = xyz.permute(0, 2, 1)
+ if points is not None:
+ points = points.permute(0, 2, 1)
+
+ B, N, C = xyz.shape
+ S = self.npoint
+ new_xyz = index_points(xyz, farthest_point_sample(xyz, S))
+ new_points_list = []
+ for i, radius in enumerate(self.radius_list):
+ K = self.nsample_list[i]
+ group_idx = query_ball_point(radius, K, xyz, new_xyz)
+ grouped_xyz = index_points(xyz, group_idx)
+ grouped_xyz -= new_xyz.view(B, S, 1, C)
+ if points is not None:
+ grouped_points = index_points(points, group_idx)
+ grouped_points = torch.cat([grouped_points, grouped_xyz], dim=-1)
+ else:
+ grouped_points = grouped_xyz
+
+ grouped_points = grouped_points.permute(0, 3, 2, 1) # [B, D, K, S]
+ for j in range(len(self.conv_blocks[i])):
+ conv = self.conv_blocks[i][j]
+ bn = self.bn_blocks[i][j]
+ grouped_points = F.relu(bn(conv(grouped_points)))
+ new_points = torch.max(grouped_points, 2)[0] # [B, D', S]
+ new_points_list.append(new_points)
+
+ new_xyz = new_xyz.permute(0, 2, 1)
+ new_points_concat = torch.cat(new_points_list, dim=1)
+ return new_xyz, new_points_concat
+
+
+class PointNetFeaturePropagation(nn.Module):
+ def __init__(self, in_channel, mlp):
+ super(PointNetFeaturePropagation, self).__init__()
+ self.mlp_convs = nn.ModuleList()
+ self.mlp_bns = nn.ModuleList()
+ last_channel = in_channel
+ for out_channel in mlp:
+ self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1))
+ self.mlp_bns.append(nn.BatchNorm1d(out_channel))
+ last_channel = out_channel
+
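+ # forward() interpolates features from the S downsampled points back onto
+ # the N original points using inverse-distance weighting over the 3
+ # nearest neighbors, w_i = (1 / d_i) / sum_j (1 / d_j), then applies a
+ # pointwise MLP to the (optionally skip-concatenated) result.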
+ def forward(self, xyz1, xyz2, points1, points2): + """ + Input: + xyz1: input points position data, [B, C, N] + xyz2: sampled input points position data, [B, C, S] + points1: input points data, [B, D, N] + points2: input points data, [B, D, S] + Return: + new_points: upsampled points data, [B, D', N] + """ + xyz1 = xyz1.permute(0, 2, 1) + xyz2 = xyz2.permute(0, 2, 1) + + points2 = points2.permute(0, 2, 1) + B, N, C = xyz1.shape + _, S, _ = xyz2.shape + + if S == 1: + interpolated_points = points2.repeat(1, N, 1) + else: + dists = square_distance(xyz1, xyz2) + dists, idx = dists.sort(dim=-1) + dists, idx = dists[:, :, :3], idx[:, :, :3] # [B, N, 3] + + dist_recip = 1.0 / (dists + 1e-8) + norm = torch.sum(dist_recip, dim=2, keepdim=True) + weight = dist_recip / norm + interpolated_points = torch.sum( + index_points(points2, idx) * weight.view(B, N, 3, 1), dim=2 + ) + + if points1 is not None: + points1 = points1.permute(0, 2, 1) + new_points = torch.cat([points1, interpolated_points], dim=-1) + else: + new_points = interpolated_points + + new_points = new_points.permute(0, 2, 1) + for i, conv in enumerate(self.mlp_convs): + bn = self.mlp_bns[i] + new_points = F.relu(bn(conv(new_points))) + return new_points diff --git a/compressai/layers/pointcloud/pointnet2_sfu.py b/compressai/layers/pointcloud/pointnet2_sfu.py new file mode 100644 index 00000000..4bfad7ef --- /dev/null +++ b/compressai/layers/pointcloud/pointnet2_sfu.py @@ -0,0 +1,57 @@ +# Copyright (c) 2021-2022, InterDigital Communications, Inc +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted (subject to the limitations in the disclaimer +# below) provided that the following conditions are met: + +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * Neither the name of InterDigital Communications, Inc nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. + +# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY +# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND +# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT +# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR +# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF +# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +import torch.nn as nn + +from torch import Tensor + +from compressai.layers.basic import Interleave, Reshape, Transpose + + +class UpsampleBlock(nn.Module): + def __init__(self, D, E, M, P, S, i, extra_in_ch=3, groups=(1, 1)): + super().__init__() + self.block = nn.Sequential( + nn.Conv1d( + E[i + 1] + (D[i] + extra_in_ch) * bool(M[i]), D[i], 1, groups=groups[0] + ), + Interleave(groups=groups[0]), + nn.BatchNorm1d(D[i]), + nn.ReLU(inplace=True), + nn.Conv1d(D[i], E[i] * S[i], 1, groups=groups[1]), + Interleave(groups=groups[1]), + nn.BatchNorm1d(E[i] * S[i]), + nn.ReLU(inplace=True), + Reshape((E[i], S[i], P[i])), + Transpose(-2, -1), + Reshape((E[i], P[i] * S[i])), + ) + + def forward(self, x: Tensor) -> Tensor: + return self.block(x) diff --git a/compressai/layers/pointcloud/utils.py b/compressai/layers/pointcloud/utils.py new file mode 100644 index 00000000..cb933837 --- /dev/null +++ b/compressai/layers/pointcloud/utils.py @@ -0,0 +1,244 @@ +# Copyright (c) 2021-2022, InterDigital Communications, Inc +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted (subject to the limitations in the disclaimer +# below) provided that the following conditions are met: + +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * Neither the name of InterDigital Communications, Inc nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. + +# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY +# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND +# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT +# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR +# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF +# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from math import ceil + +import torch + +from pointops.functions import pointops + + +def index_points(xyzs, idx): + """Index points. + + Args: + xyzs: (b, c, n) + idx: (b, ...) + + Returns: + xyzs_out: (b, c, ...) + """ + _, c, _ = xyzs.shape + b, *idx_dims = idx.shape + idx_out = idx.reshape(b, 1, -1).repeat(1, c, 1) + xyzs_out = xyzs.gather(2, idx_out).reshape(b, c, *idx_dims) + return xyzs_out + + +def select_xyzs_and_feats( + candidate_xyzs, + candidate_feats, + upsample_num, + upsample_rate=None, + method="batch_loop", +): + """Selects subset of points to match predicted local point cloud densities. + + Args: + candidate_xyzs: (b, 3, n, s) + candidate_feats: (b, c, n, s) + upsample_num: (b, n) + upsample_rate: Maximum number of points per group. + method: "batch_loop" or "batch_noloop". 
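+ Only these two values are supported. "batch_loop" requires
+ ``upsample_rate`` whenever the batch size is greater than 1;
+ "batch_noloop" can infer the output size, but fixed output
+ shapes are usually faster on the GPU.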
+
+    Returns:
+        xyzs: (b, 3, m)
+        feats: (b, c, m)
+    """
+    device = candidate_xyzs.device
+    b, c, n, s = candidate_feats.shape
+
+    # Prefer faster method, if applicable.
+    if b == 1 and upsample_rate is None:
+        xyzs, feats = _select_xyzs_and_feats_single(
+            candidate_xyzs, candidate_feats, upsample_num
+        )
+        return xyzs, feats
+
+    # Select upsample_num points, then resample to a fixed number of points.
+    if method == "batch_loop":
+        max_points = ceil(n * upsample_rate)
+        xyzs = []
+        feats = []
+
+        for i in range(b):
+            xyzs_i, feats_i = _select_xyzs_and_feats_single(
+                candidate_xyzs[[i]], candidate_feats[[i]], upsample_num[[i]]
+            )
+            xyzs_i, feats_i = resample_points(xyzs_i, feats_i, max_points)
+            xyzs.append(xyzs_i)
+            feats.append(feats_i)
+
+        xyzs = torch.cat(xyzs)
+        feats = torch.cat(feats)
+
+    # Sometimes a bit faster than batch_loop.
+    elif method == "batch_noloop":
+        upsample_num = upsample_num.round().long().clip(1, s)
+
+        # Select upsample_num points from each group.
+        idx = torch.arange(s, device=device).repeat(b, n, 1)
+        mask = idx < upsample_num.unsqueeze(-1)
+
+        # Initialize random permutation.
+        idx = randperm((b, n, s), device=device, dim=-1)
+
+        # Flatten point groups.
+        idx += torch.arange(n, device=device).view(1, -1, 1) * s
+        idx = idx.view(b, -1)
+        mask = mask.view(b, n * s)
+
+        # Reorder selected points from various groups to the beginning.
+        # TODO(perf): A batch-wise .nonzero() could be faster.
+        perm = mask.to(torch.uint8).argsort(dim=-1, descending=True)
+        idx = idx.gather(-1, perm)
+
+        # NOTE(perf): Using consistent shapes is much faster on the GPU,
+        # so specifying upsample_rate is preferred.
+        max_points = (
+            mask.sum(dim=-1).max().item()
+            if upsample_rate is None
+            else ceil(n * upsample_rate)
+        )
+
+        # Truncate to the maximum number of points.
+        idx = idx[..., :max_points]
+
+        # Cycle selected points if there are not enough.
+        idx, _, _ = cycle_after(idx, mask.sum(dim=-1))
+
+        # Shuffle points (usually not necessary).
+        idx = idx.gather(-1, randperm(idx.shape, device=device, dim=-1))
+
+        xyzs = index_points(candidate_xyzs.view(b, 3, -1), idx)
+        feats = index_points(candidate_feats.view(b, c, -1), idx)
+
+    else:
+        raise ValueError(f"Unknown method: {method}")
+
+    return xyzs, feats
+
+
+def _select_xyzs_and_feats_single(candidate_xyzs, candidate_feats, upsample_num):
+    # candidate_xyzs: (b, 3, n, max_upsample_num)
+    # candidate_feats: (b, c, n, max_upsample_num)
+    # upsample_num: (b, n)
+    # mask: (n*max_upsample_num)
+    # idx: (b, m)
+    # xyzs: (b, 3, m)
+    # feats: (b, c, m)
+
+    batch_size, _, points_num, max_upsample_num = candidate_xyzs.shape
+    assert batch_size == 1
+
+    # Create mask denoting the first upsample_num points per group:
+    # NOTE: Use the input's device rather than assuming CUDA is available.
+    device = candidate_xyzs.device
+    upsample_num = upsample_num.round().long().squeeze(0).view(-1, 1)
+    mask = torch.arange(max_upsample_num, device=device).view(1, -1)
+    mask = mask.repeat(points_num, 1)
+    mask = (mask < upsample_num).view(-1)
+
+    # Convert mask to indices:
+    [idx] = mask.nonzero(as_tuple=True)
+    idx = idx.unsqueeze(0)
+
+    # Select the first upsample_num xyzs and feats:
+    xyzs = index_points(candidate_xyzs.view(*candidate_xyzs.shape[:2], -1), idx)
+    feats = index_points(candidate_feats.view(*candidate_feats.shape[:2], -1), idx)
+
+    return xyzs, feats
+
+
+def resample_points(xyzs, feats, num_points):
+    """Resample points to a target number.
+
+    Args:
+        xyzs: (b, 3, n)
+        feats: (b, c, n)
+        num_points: Target number of output points.
+
+    Returns:
+        new_xyzs: (b, 3, num_points)
+        new_feats: (b, c, num_points)
+    """
+    b, _, n = xyzs.shape
+    device = xyzs.device
+    assert b == 1
+
+    if n == num_points:
+        return xyzs, feats
+
+    # Subsample points if there are too many.
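+    # (Farthest point sampling keeps the subset spatially well spread.)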
+    if n > num_points:
+        xyzs_tr = xyzs.permute(0, 2, 1).contiguous()
+        idx = pointops.furthestsampling(xyzs_tr, num_points).long()
+
+    # Repeat and create randomly duplicated points if there are not enough.
+    elif n < num_points:
+        idx_repeated = torch.arange(n, device=device).repeat(num_points // n)
+        idx_random = (
+            torch.multinomial(torch.ones(n, device=device), num_points % n)
+            if num_points % n > 0
+            else torch.arange(0, device=device)
+        )
+        idx = torch.cat((idx_repeated, idx_random))
+
+        # Shuffle; probably unnecessary:
+        perm = torch.randperm(len(idx), device=device)
+        idx = idx[perm]
+
+    idx = idx.reshape(1, -1)
+    new_xyzs = index_points(xyzs, idx)
+    new_feats = index_points(feats, idx)
+
+    return new_xyzs, new_feats
+
+
+def randperm(shape, device=None, dim=-1):
+    """Random permutation, like `torch.randperm`, but with a shape."""
+    if dim != -1:
+        raise NotImplementedError
+    idx = torch.rand(shape, device=device).argsort(dim=dim)
+    return idx
+
+
+def cycle_after(x, end):
+    """Cycle tensor after a given index, repeating the first ``end`` values.
+
+    Example:
+
+    .. code-block:: python
+
+        >>> x = torch.tensor([[5, 0, 7, 6, 2], [3, 1, 4, 8, 9]])
+        >>> end = torch.tensor([2, 3])
+        >>> x_cycled, _, _ = cycle_after(x, end)
+        >>> x_cycled
+        tensor([[5, 0, 5, 0, 5], [3, 1, 4, 3, 1]])
+    """
+    *dims, n = x.shape
+    assert end.shape == tuple(dims)
+    idx = torch.arange(n, device=x.device).repeat(*dims, 1)
+    mask = idx >= end.unsqueeze(-1)
+    idx[mask] %= end.unsqueeze(-1).repeat([1] * len(dims) + [n])[mask]
+    x = x.gather(-1, idx)
+    return x, idx, ~mask

From 8ac44b948a20ef35efa4c8b2b30922a0a01a3ef9 Mon Sep 17 00:00:00 2001
From: Mateen Ulhaq
Date: Tue, 30 Jan 2024 22:40:43 -0800
Subject: [PATCH 16/27] feat: point cloud compression models

---
 compressai/models/__init__.py                 |   1 +
 compressai/models/pointcloud/__init__.py      |  32 ++
 compressai/models/pointcloud/hrtzxf2022.py    | 453 ++++++++++++++++++
 compressai/models/pointcloud/sfu_pointnet.py  | 133 +++++
 compressai/models/pointcloud/sfu_pointnet2.py | 333 +++++++++++++
 docs/source/models.rst                        |  17 +
 6 files changed, 969 insertions(+)
 create mode 100644 compressai/models/pointcloud/__init__.py
 create mode 100644 compressai/models/pointcloud/hrtzxf2022.py
 create mode 100644 compressai/models/pointcloud/sfu_pointnet.py
 create mode 100644 compressai/models/pointcloud/sfu_pointnet2.py

diff --git a/compressai/models/__init__.py b/compressai/models/__init__.py
index 94876ccb..b817633c 100644
--- a/compressai/models/__init__.py
+++ b/compressai/models/__init__.py
@@ -29,5 +29,6 @@
 from .base import *
 from .google import *
+from .pointcloud import *
 from .sensetime import *
 from .waseda import *
diff --git a/compressai/models/pointcloud/__init__.py b/compressai/models/pointcloud/__init__.py
new file mode 100644
index 00000000..da566a42
--- /dev/null
+++ b/compressai/models/pointcloud/__init__.py
@@ -0,0 +1,32 @@
+# Copyright (c) 2021-2022, InterDigital Communications, Inc
+# All rights reserved.
+
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted (subject to the limitations in the disclaimer
+# below) provided that the following conditions are met:
+
+# * Redistributions of source code must retain the above copyright notice,
+#   this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+# * Neither the name of InterDigital Communications, Inc nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. + +# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY +# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND +# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT +# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR +# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF +# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from .hrtzxf2022 import * +from .sfu_pointnet import * +from .sfu_pointnet2 import * diff --git a/compressai/models/pointcloud/hrtzxf2022.py b/compressai/models/pointcloud/hrtzxf2022.py new file mode 100644 index 00000000..1c8da845 --- /dev/null +++ b/compressai/models/pointcloud/hrtzxf2022.py @@ -0,0 +1,453 @@ +# Copyright (c) 2021-2022, InterDigital Communications, Inc +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted (subject to the limitations in the disclaimer +# below) provided that the following conditions are met: + +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * Neither the name of InterDigital Communications, Inc nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. + +# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY +# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND +# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT +# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR +# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF +# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +# Code adapted from https://github.com/yunhe20/D-PCC + +from __future__ import annotations + +import numpy as np +import torch +import torch.nn as nn + +from compressai.entropy_models import EntropyBottleneck +from compressai.latent_codecs import EntropyBottleneckLatentCodec +from compressai.layers.pointcloud.hrtzxf2022 import ( + DownsampleLayer, + EdgeConv, + RefineLayer, + UpsampleLayer, + UpsampleNumLayer, + nearby_distance_sum, +) +from compressai.layers.pointcloud.utils import select_xyzs_and_feats +from compressai.models import CompressionModel +from compressai.registry import register_model + +__all__ = [ + "DensityPreservingReconstructionPccModel", +] + + +@register_model("hrtzxf2022-pcc-rec") +class DensityPreservingReconstructionPccModel(CompressionModel): + """Density-preserving deep point cloud compression. + + Model introduced by [He2022pcc]_. + + References: + + .. [He2022pcc] `"Density-preserving Deep Point Cloud Compression" + `_, by Yun He, Xinlin Ren, + Danhang Tang, Yinda Zhang, Xiangyang Xue, and Yanwei Fu, + CVPR 2022. + """ + + def __init__( + self, + downsample_rate=(1 / 3, 1 / 3, 1 / 3), + candidate_upsample_rate=(8, 8, 8), + in_dim=3, + feat_dim=8, + hidden_dim=64, + k=16, + ngroups=1, + sub_point_conv_mode="mlp", + compress_normal=False, + latent_xyzs_codec=None, + **kwargs, + ): + super().__init__() + + self.compress_normal = compress_normal + + self.pre_conv = nn.Sequential( + nn.Conv1d(in_dim, hidden_dim, 1), + nn.GroupNorm(ngroups, hidden_dim), + nn.ReLU(), + nn.Conv1d(hidden_dim, feat_dim, 1), + ) + + self.encoder = Encoder( + downsample_rate, + feat_dim, + hidden_dim, + k, + ngroups, + ) + + self.decoder = Decoder( + downsample_rate, + candidate_upsample_rate, + feat_dim, + hidden_dim, + k, + sub_point_conv_mode, + compress_normal, + ) + + self.latent_codec = nn.ModuleDict( + { + "feat": EntropyBottleneckLatentCodec(channels=feat_dim), + "xyz": XyzsLatentCodec( + feat_dim, hidden_dim, k, ngroups, **(latent_xyzs_codec or {}) + ), + } + ) + + def _prepare_input(self, input): + input_data = [input["pos"]] + if self.compress_normal: + input_data.append(input["normal"]) + input_data = torch.cat(input_data, dim=1).permute(0, 2, 1).contiguous() + + xyzs = input_data[:, :3].contiguous() + gt_normals = input_data[:, 3 : 3 + 3 * self.compress_normal].contiguous() + feats = input_data + + return xyzs, gt_normals, feats + + def forward(self, input): + # xyzs: (b, 3, n) + + xyzs, gt_normals, feats = self._prepare_input(input) + + feats = self.pre_conv(feats) + + gt_xyzs_, gt_dnums_, gt_mdis_, latent_xyzs, latent_feats = self.encoder( + xyzs, feats + ) + + gt_latent_xyzs = latent_xyzs + + # NOTE: Temporarily reshape to (b, c, m, 1) for compatibility. 
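+        # (The entropy bottleneck latent codec expects a trailing spatial
+        # dimension, which is squeezed away again after coding.)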
+        latent_feats = latent_feats.unsqueeze(-1)
+        latent_feats_out = self.latent_codec["feat"](latent_feats)
+        latent_feats_hat = latent_feats_out["y_hat"].squeeze(-1)
+
+        latent_xyzs_out = self.latent_codec["xyz"](latent_xyzs)
+        latent_xyzs_hat = latent_xyzs_out["y_hat"]
+
+        xyzs_hat_, unums_hat_, mdis_hat_, feats_hat = self.decoder(
+            latent_xyzs_hat, latent_feats_hat
+        )
+
+        # Permute final xyzs_hat back to (b, n, c)
+        xyzs_hat = xyzs_hat_[-1].permute(0, 2, 1).contiguous()
+
+        return {
+            "x_hat": xyzs_hat,
+            "xyz_hat_": xyzs_hat_,
+            "latent_xyz_hat": latent_xyzs_hat,
+            "feat_hat": feats_hat,
+            "upsample_num_hat_": unums_hat_,
+            "mean_distance_hat_": mdis_hat_,
+            "gt_xyz_": gt_xyzs_,
+            "gt_latent_xyz": gt_latent_xyzs,
+            "gt_normal": gt_normals,
+            "gt_downsample_num_": gt_dnums_,
+            "gt_mean_distance_": gt_mdis_,
+            "likelihoods": {
+                "latent_feat": latent_feats_out["likelihoods"]["y"],
+                "latent_xyz": latent_xyzs_out["likelihoods"]["y"],
+            },
+        }
+
+    def compress(self, input):
+        xyzs, _, feats = self._prepare_input(input)
+
+        feats = self.pre_conv(feats)
+
+        _, _, _, latent_xyzs, latent_feats = self.encoder(xyzs, feats)
+
+        latent_feats = latent_feats.unsqueeze(-1)
+        latent_feats_out = self.latent_codec["feat"].compress(latent_feats)
+
+        latent_xyzs_out = self.latent_codec["xyz"].compress(latent_xyzs)
+
+        return {
+            "strings": [
+                latent_feats_out["strings"],
+                latent_xyzs_out["strings"],
+            ],
+            "shape": [
+                latent_feats_out["shape"],
+                latent_xyzs_out["shape"],
+            ],
+        }
+
+    def decompress(self, strings, shape):
+        assert isinstance(strings, list) and len(strings) == 2
+
+        latent_feats_out = self.latent_codec["feat"].decompress(strings[0], shape[0])
+        latent_feats_hat = latent_feats_out["y_hat"].squeeze(-1)
+
+        latent_xyzs_out = self.latent_codec["xyz"].decompress(strings[1], shape[1])
+        latent_xyzs_hat = latent_xyzs_out["y_hat"]
+
+        xyzs_hat_, _, _, feats_hat = self.decoder(latent_xyzs_hat, latent_feats_hat)
+
+        # Permute final xyzs_hat back to (b, n, c)
+        xyzs_hat = xyzs_hat_[-1].permute(0, 2, 1).contiguous()
+
+        return {
+            "x_hat": xyzs_hat,
+            "feat_hat": feats_hat,
+        }
+
+
+class XyzsLatentCodec(nn.Module):
+    def __init__(self, dim, hidden_dim, k, ngroups, mode="learned", conv_mode="mlp"):
+        super().__init__()
+        self.mode = mode
+        if mode == "learned":
+            if conv_mode == "edge_conv":
+                self.analysis = EdgeConv(3, dim, hidden_dim, k)
+                self.synthesis = EdgeConv(dim, 3, hidden_dim, k)
+            elif conv_mode == "mlp":
+                self.analysis = nn.Sequential(
+                    nn.Conv1d(3, hidden_dim, 1),
+                    nn.GroupNorm(ngroups, hidden_dim),
+                    nn.ReLU(inplace=True),
+                    nn.Conv1d(hidden_dim, dim, 1),
+                )
+                self.synthesis = nn.Sequential(
+                    nn.Conv1d(dim, hidden_dim, 1),
+                    nn.GroupNorm(ngroups, hidden_dim),
+                    nn.ReLU(inplace=True),
+                    nn.Conv1d(hidden_dim, 3, 1),
+                )
+            else:
+                raise ValueError(f"Unknown conv_mode: {conv_mode}")
+            self.entropy_bottleneck = EntropyBottleneck(dim)
+        else:
+            self.placeholder = nn.Parameter(torch.empty(0))
+
+    def forward(self, latent_xyzs):
+        if self.mode == "learned":
+            z = self.analysis(latent_xyzs)
+            z_hat, z_likelihoods = self.entropy_bottleneck(z)
+            latent_xyzs_hat = self.synthesis(z_hat)
+        elif self.mode == "float16":
+            z_likelihoods = latent_xyzs.new_full(latent_xyzs.shape, 2**-16)
+            latent_xyzs_hat = latent_xyzs.to(torch.float16).float()
+        else:
+            raise ValueError(f"Unknown mode: {self.mode}")
+        return {"likelihoods": {"y": z_likelihoods}, "y_hat": latent_xyzs_hat}
+
+    def compress(self, latent_xyzs):
+        if self.mode == "learned":
+            z = self.analysis(latent_xyzs)
shape = z.shape[2:]
+            z_strings = self.entropy_bottleneck.compress(z)
+            z_hat = self.entropy_bottleneck.decompress(z_strings, shape)
+            latent_xyzs_hat = self.synthesis(z_hat)
+        elif self.mode == "float16":
+            z = latent_xyzs
+            shape = z.shape[2:]
+            z_hat = latent_xyzs.to(torch.float16)
+            z_strings = [
+                np.ascontiguousarray(x, dtype=">f2").tobytes()
+                for x in z_hat.cpu().numpy()
+            ]
+            latent_xyzs_hat = z_hat.float()
+        else:
+            raise ValueError(f"Unknown mode: {self.mode}")
+        return {"strings": [z_strings], "shape": shape, "y_hat": latent_xyzs_hat}
+
+    def decompress(self, strings, shape):
+        [z_strings] = strings
+        if self.mode == "learned":
+            z_hat = self.entropy_bottleneck.decompress(z_strings, shape)
+            latent_xyzs_hat = self.synthesis(z_hat)
+        elif self.mode == "float16":
+            # Each string holds one (3, n) half-precision tensor, while shape
+            # records only (n,); convert to native byte order for torch.
+            z_hat = [
+                np.frombuffer(s, dtype=">f2").astype(np.float16).reshape(-1, *shape)
+                for s in z_strings
+            ]
+            z_hat = torch.from_numpy(np.stack(z_hat)).to(self.placeholder.device)
+            latent_xyzs_hat = z_hat.float()
+        else:
+            raise ValueError(f"Unknown mode: {self.mode}")
+        return {"y_hat": latent_xyzs_hat}
+
+
+class Encoder(nn.Module):
+    def __init__(self, downsample_rate, dim, hidden_dim, k, ngroups):
+        super().__init__()
+        downsample_layers = [
+            DownsampleLayer(downsample_rate[i], dim, hidden_dim, k, ngroups)
+            for i in range(len(downsample_rate))
+        ]
+        self.downsample_layers = nn.ModuleList(downsample_layers)
+
+    def forward(self, xyzs, feats):
+        # xyzs: (b, 3, n)
+        # feats: (b, c, n)
+
+        gt_xyzs_ = []
+        gt_dnums_ = []
+        gt_mdis_ = []
+
+        for downsample_layer in self.downsample_layers:
+            gt_xyzs_.append(xyzs)
+            xyzs, feats, downsample_num, mean_distance = downsample_layer(xyzs, feats)
+            gt_dnums_.append(downsample_num)
+            gt_mdis_.append(mean_distance)
+
+        latent_xyzs = xyzs
+        latent_feats = feats
+
+        return gt_xyzs_, gt_dnums_, gt_mdis_, latent_xyzs, latent_feats
+
+
+class Decoder(nn.Module):
+    def __init__(
+        self,
+        downsample_rate,
+        candidate_upsample_rate,
+        dim,
+        hidden_dim,
+        k,
+        sub_point_conv_mode,
+        compress_normal,
+    ):
+        super().__init__()
+
+        self.k = k
+        self.compress_normal = compress_normal
+        self.num_layers = len(downsample_rate)
+        self.downsample_rate = downsample_rate
+
+        self.upsample_layers = nn.ModuleList(
+            [
+                UpsampleLayer(
+                    dim,
+                    hidden_dim,
+                    k,
+                    sub_point_conv_mode,
+                    candidate_upsample_rate[i],
+                )
+                for i in range(self.num_layers)
+            ]
+        )
+
+        self.upsample_num_layers = nn.ModuleList(
+            [
+                UpsampleNumLayer(
+                    dim,
+                    hidden_dim,
+                    candidate_upsample_rate[i],
+                )
+                for i in range(self.num_layers)
+            ]
+        )
+
+        self.refine_layers = nn.ModuleList(
+            [
+                RefineLayer(
+                    dim,
+                    hidden_dim,
+                    k,
+                    sub_point_conv_mode,
+                    compress_normal and i == self.num_layers - 1,
+                )
+                for i in range(self.num_layers)
+            ]
+        )
+
+    def forward(self, xyzs, feats):
+        # xyzs: (b, 3, n)
+        # feats: (b, c, n)
+
+        latent_xyzs = xyzs.clone()
+
+        xyzs_hat_ = []
+        unums_hat_ = []
+
+        for i, (upsample_nn, upsample_num_nn, refine_nn) in enumerate(
+            zip(self.upsample_layers, self.upsample_num_layers, self.refine_layers)
+        ):
+            # candidate_xyzs: (b, 3, n, u)
+            # candidate_feats: (b, c, n, u)
+            # upsample_num: (b, n)
+            # xyzs: (b, 3, m) [after upsample and select]
+            # feats: (b, c, m) [after upsample and select]
+
+            # For each point within the current set of "n" points,
+            # upsample a fixed number "u" of candidate points.
+            # The resulting candidate points have the shape (n, u).
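+            # (Here, u == candidate_upsample_rate[i] is fixed per level.)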
+ candidate_xyzs, candidate_feats = upsample_nn(xyzs, feats) + + # Determine local point cloud density near each upsampled group: + upsample_num = upsample_num_nn(feats) + + # Subsample each point group to match the desired local density. + # That is, from the i-th point group, select upsample_num[..., i] points. + # Then, collect all the points so the resulting point set has shape (m_i,). + # + # If the batch size is >1, then the "m_i"s may be different. + # In that case, resample each point set within the batch + # until they all have the same shape (m,). + # This can be done by either selecting a subset or + # duplicating points as necessary. + # + # Since one of the goals is to reduce local point cloud + # density in certain regions, we are happy with throwing + # away distinct points, and then duplicating the remaining + # points until they can fit within the desired tensor shape. + + # Select subset of points to match predicted local point cloud densities: + xyzs, feats = select_xyzs_and_feats( + candidate_xyzs, + candidate_feats, + upsample_num, + upsample_rate=1 / self.downsample_rate[self.num_layers - i - 1], + ) + + # Refine upsampled points. + xyzs, feats = refine_nn(xyzs, feats) + + xyzs_hat_.append(xyzs) + unums_hat_.append(upsample_num) + + # Compute mean distance between centroids and the upsampled points. + mdis_hat_ = self.get_pred_mdis([latent_xyzs, *xyzs_hat_], unums_hat_) + + return xyzs_hat_, unums_hat_, mdis_hat_, feats + + def get_pred_mdis(self, xyzs_hat_, unums_hat_): + mdis_hat_ = [] + + for prev_xyzs, curr_xyzs, curr_unums in zip( + xyzs_hat_[:-1], xyzs_hat_[1:], unums_hat_ + ): + # Compute mean distance for each point in "prev" to upsampled "curr". + distance, _, _, _ = nearby_distance_sum(prev_xyzs, curr_xyzs, self.k) + curr_mdis = distance / curr_unums + mdis_hat_.append(curr_mdis) + + return mdis_hat_ diff --git a/compressai/models/pointcloud/sfu_pointnet.py b/compressai/models/pointcloud/sfu_pointnet.py new file mode 100644 index 00000000..2e221c14 --- /dev/null +++ b/compressai/models/pointcloud/sfu_pointnet.py @@ -0,0 +1,133 @@ +# Copyright (c) 2021-2022, InterDigital Communications, Inc +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted (subject to the limitations in the disclaimer +# below) provided that the following conditions are met: + +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * Neither the name of InterDigital Communications, Inc nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. + +# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY +# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND +# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT +# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +# PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR +# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF +# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from __future__ import annotations + +from compressai.latent_codecs import LatentCodec +from compressai.latent_codecs.entropy_bottleneck import EntropyBottleneckLatentCodec +from compressai.layers.pointcloud.pointnet import ( + pointnet_g_a_simple, + pointnet_g_s_simple, +) +from compressai.models import CompressionModel +from compressai.registry import register_model + +__all__ = [ + "PointNetReconstructionPccModel", +] + + +@register_model("sfu2023-pcc-rec-pointnet") +class PointNetReconstructionPccModel(CompressionModel): + """PointNet-based PCC reconstruction model. + + Model based on PointNet [Qi2017PointNet]_, modified for compression + by [Yan2019]_, with layer configurations and other modifications as + used in [Ulhaq2023]_. + + References: + + .. [Qi2017PointNet] `"PointNet: Deep Learning on Point Sets for + 3D Classification and Segmentation" + `_, by Charles R. Qi, + Hao Su, Kaichun Mo, and Leonidas J. Guibas, CVPR 2017. + + .. [Yan2019] `"Deep AutoEncoder-based Lossy Geometry Compression + for Point Clouds" `_, + by Wei Yan, Yiting Shao, Shan Liu, Thomas H Li, Zhu Li, + and Ge Li, 2019. + + .. [Ulhaq2023] `"Learned Point Cloud Compression for + Classification" `_, + by Mateen Ulhaq and Ivan V. Bajić, MMSP 2023. 
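+
+    Example (a minimal sketch; assumes a ``(batch, num_points, 3)``
+    input tensor, as in :meth:`forward`):
+
+    .. code-block:: python
+
+        model = PointNetReconstructionPccModel()
+        out = model({"pos": torch.rand(8, 1024, 3)})
+        x_hat = out["x_hat"]  # Reconstruction with the same shape as input.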
+    """
+
+    latent_codec: LatentCodec
+
+    def __init__(
+        self,
+        num_points=1024,
+        num_channels={  # noqa: B006
+            "g_a": [3, 64, 64, 64, 128, 1024],
+            "g_s": [1024, 256, 512, 1024 * 3],
+        },
+        groups={  # noqa: B006
+            "g_a": [1, 1, 1, 1, 1],
+        },
+    ):
+        super().__init__()
+
+        assert num_channels["g_a"][-1] == num_channels["g_s"][0]
+        assert num_channels["g_s"][-1] == num_points * 3
+
+        self.g_a = pointnet_g_a_simple(num_channels["g_a"], groups["g_a"])
+
+        self.g_s = pointnet_g_s_simple(num_channels["g_s"])
+
+        self.latent_codec = EntropyBottleneckLatentCodec(
+            channels=num_channels["g_a"][-1],
+            tail_mass=1e-4,
+        )
+
+    def forward(self, input):
+        x = input["pos"]
+        x_t = x.transpose(-2, -1)
+        y = self.g_a(x_t)
+        y_out = self.latent_codec(y)
+        y_hat = y_out["y_hat"]
+        x_hat = self.g_s(y_hat)
+        assert x_hat.shape == x.shape
+
+        return {
+            "x_hat": x_hat,
+            "likelihoods": {
+                "y": y_out["likelihoods"]["y"],
+            },
+            # Additional outputs:
+            "y": y,
+            "y_hat": y_hat,
+            "debug_outputs": {
+                "y_hat": y_hat,
+            },
+        }
+
+    def compress(self, input):
+        x = input["pos"]
+        x_t = x.transpose(-2, -1)
+        y = self.g_a(x_t)
+        y_out = self.latent_codec.compress(y)
+        [y_strings] = y_out["strings"]
+        return {"strings": [y_strings], "shape": (1,)}
+
+    def decompress(self, strings, shape):
+        assert isinstance(strings, list) and len(strings) == 1
+        [y_strings] = strings
+        # Extract the decoded latent from the codec output, as in forward().
+        y_out = self.latent_codec.decompress([y_strings], shape)
+        y_hat = y_out["y_hat"]
+        x_hat = self.g_s(y_hat)
+        return {"x_hat": x_hat}
diff --git a/compressai/models/pointcloud/sfu_pointnet2.py b/compressai/models/pointcloud/sfu_pointnet2.py
new file mode 100644
index 00000000..ffee09ca
--- /dev/null
+++ b/compressai/models/pointcloud/sfu_pointnet2.py
@@ -0,0 +1,333 @@
+# Copyright (c) 2021-2022, InterDigital Communications, Inc
+# All rights reserved.
+
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted (subject to the limitations in the disclaimer
+# below) provided that the following conditions are met:
+
+# * Redistributions of source code must retain the above copyright notice,
+#   this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+# * Neither the name of InterDigital Communications, Inc nor the names of its
+#   contributors may be used to endorse or promote products derived from this
+#   software without specific prior written permission.
+
+# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY
+# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
+# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT
+# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
+# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
+# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
+# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
+# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ +from __future__ import annotations + +import torch +import torch.nn as nn + +from compressai.latent_codecs import EntropyBottleneckLatentCodec +from compressai.layers.basic import Gain, Interleave, Reshape, Transpose +from compressai.layers.pointcloud.pointnet import GAIN +from compressai.layers.pointcloud.pointnet2 import PointNetSetAbstraction +from compressai.layers.pointcloud.pointnet2_sfu import UpsampleBlock +from compressai.models import CompressionModel +from compressai.registry import register_model + +__all__ = [ + "PointNet2SsgReconstructionPccModel", +] + + +@register_model("sfu2024-pcc-rec-pointnet2-ssg") +class PointNet2SsgReconstructionPccModel(CompressionModel): + """PointNet++-based PCC reconstruction model. + + Model based on PointNet++ [Qi2017PointNetPlusPlus]_, and modified + for compression by [Ulhaq2024]_. + Uses single-scale grouping (SSG) for point set abstraction. + + References: + + .. [Qi2017PointNetPlusPlus] `"PointNet++: Deep Hierarchical + Feature Learning on Point Sets in a Metric Space" + `_, by Charles R. Qi, + Li Yi, Hao Su, and Leonidas J. Guibas, NIPS 2017. + + .. [Ulhaq2024] `"Scalable Human-Machine Point Cloud Compression" + `_, + by Mateen Ulhaq and Ivan V. Bajić, PCS 2024. + """ + + def __init__( + self, + num_points=1024, + num_classes=40, + D=(0, 128, 192, 256), + P=(1024, 256, 64, 1), + S=(None, 4, 4, 64), + R=(None, 0.2, 0.4, None), + E=(3, 64, 32, 16, 0), + M=(0, 0, 64, 64), + normal_channel=False, + ): + """ + Args: + num_points: Number of input points. [unused] + num_classes: Number of object classes. [unused] + D: Number of input feature channels. + P: Number of output points. + S: Number of samples per centroid. + R: Radius of the ball to query points within. + E: Number of output feature channels after each upsample. + M: Number of latent channels for the bottleneck. + normal_channel: Whether the input includes normals. + """ + super().__init__() + + self.num_points = num_points + self.num_classes = num_classes + self.D = D + self.P = P + self.S = S + self.R = R + self.E = E + self.M = M + self.normal_channel = bool(normal_channel) + + # Original PointNet++ architecture: + # D = [3 * self.normal_channel, 128, 256, 1024] + # P = [None, 512, 128, 1] + # S = [None, 32, 64, 128] + # R = [None, 0.2, 0.4, None] + + # NOTE: P[0] is only used to determine the number of output points. 
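+        # With the defaults P = (1024, 256, 64, 1) and S = (None, 4, 4, 64),
+        # the asserts below check that each level regroups its points
+        # exactly: 256 * 4 == 1024, 64 * 4 == 256, and 1 * 64 == 64.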
+ # assert P[0] == num_points + + assert P[0] == P[1] * S[1] + assert P[1] == P[2] * S[2] + assert P[2] == P[3] * S[3] + + self.levels = 4 + + self.down = nn.ModuleDict( + { + "_1": PointNetSetAbstraction( + npoint=P[1], + radius=R[1], + nsample=S[1], + in_channel=D[0] + 3, + mlp=[D[1] // 2, D[1] // 2, D[1]], + group_all=False, + ), + "_2": PointNetSetAbstraction( + npoint=P[2], + radius=R[2], + nsample=S[2], + in_channel=D[1] + 3, + mlp=[D[1], D[1], D[2]], + group_all=False, + ), + "_3": PointNetSetAbstraction( + npoint=None, + radius=None, + nsample=None, + in_channel=D[2] + 3, + mlp=[D[2], D[3], D[3]], + group_all=True, + ), + } + ) + + i_final = self.levels - 1 + groups_h_final = 1 if D[i_final] * M[i_final] <= 2**16 else 4 + + self.h_a = nn.ModuleDict( + { + **{ + f"_{i}": nn.Sequential( + Reshape((D[i] + 3, P[i + 1] * S[i + 1])), + nn.Conv1d(D[i] + 3, M[i], 1), + Gain((M[i], 1), factor=GAIN), + ) + for i in range(self.levels - 1) + if M[i] > 0 + }, + f"_{i_final}": nn.Sequential( + Reshape((D[i_final], 1)), + nn.Conv1d(D[i_final], M[i_final], 1, groups=groups_h_final), + Interleave(groups=groups_h_final), + Gain((M[i_final], 1), factor=GAIN), + ), + } + ) + + self.h_s = nn.ModuleDict( + { + **{ + f"_{i}": nn.Sequential( + Gain((M[i], 1), factor=1 / GAIN), + nn.Conv1d(M[i], D[i] + 3, 1), + ) + for i in range(self.levels - 1) + if M[i] > 0 + }, + f"_{i_final}": nn.Sequential( + Gain((M[i_final], 1), factor=1 / GAIN), + nn.Conv1d(M[i_final], D[i_final], 1, groups=groups_h_final), + Interleave(groups=groups_h_final), + ), + } + ) + + self.up = nn.ModuleDict( + { + "_0": nn.Sequential( + nn.Conv1d(E[1] + D[0] + 3 * bool(M[0]), E[1], 1), + # nn.BatchNorm1d(E[1]), + nn.ReLU(inplace=True), + nn.Conv1d(E[1], E[0], 1), + Reshape((E[0], P[0])), + Transpose(-2, -1), + ), + "_1": UpsampleBlock(D, E, M, P, S, i=1, extra_in_ch=3, groups=(1, 4)), + "_2": UpsampleBlock(D, E, M, P, S, i=2, extra_in_ch=3, groups=(1, 4)), + "_3": UpsampleBlock(D, E, M, P, S, i=3, extra_in_ch=0, groups=(1, 4)), + } + ) + + self.latent_codec = nn.ModuleDict( + { + f"_{i}": EntropyBottleneckLatentCodec(channels=M[i], tail_mass=1e-4) + for i in range(self.levels) + if M[i] > 0 + } + ) + + def forward(self, input): + xyz, norm = self._get_inputs(input) + y_out_, u_, uu_ = self._compress(xyz, norm, mode="forward") + x_hat, y_hat_, v_ = self._decompress(y_out_, mode="forward") + + return { + "x_hat": x_hat, + "likelihoods": { + f"y_{i}": y_out_[i]["likelihoods"]["y"] + for i in range(self.levels) + if "likelihoods" in y_out_[i] + }, + # Additional outputs: + "debug_outputs": { + **{f"u_{i}": v for i, v in u_.items() if v is not None}, + **{f"uu_{i}": v for i, v in uu_.items()}, + **{f"y_hat_{i}": v for i, v in y_hat_.items()}, + **{f"v_{i}": v for i, v in v_.items() if v.numel() > 0}, + }, + } + + def compress(self, input): + xyz, norm = self._get_inputs(input) + y_out_, _, _ = self._compress(xyz, norm, mode="compress") + + return { + # "strings": {f"y_{i}": y_out_[i]["strings"] for i in range(self.levels)}, + # Flatten nested structure into list[list[str]]: + "strings": [ + ss for level in range(self.levels) for ss in y_out_[level]["strings"] + ], + "shape": {f"y_{i}": y_out_[i]["shape"] for i in range(self.levels)}, + } + + def decompress(self, strings, shape): + y_inputs_ = { + i: { + "strings": [strings[i]], + "shape": shape[f"y_{i}"], + } + for i in range(self.levels) + } + + x_hat, _, _ = self._decompress(y_inputs_, mode="decompress") + + return { + "x_hat": x_hat, + } + + def _get_inputs(self, input): + points = 
input["pos"].transpose(-2, -1) + if self.normal_channel: + xyz = points[:, :3, :] + norm = points[:, 3:, :] + else: + xyz = points + norm = None + return xyz, norm + + def _compress(self, xyz, norm, *, mode): + lc_func = {"forward": lambda lc: lc, "compress": lambda lc: lc.compress}[mode] + + B, _, _ = xyz.shape + + xyz_ = {0: xyz} + u_ = {0: norm} + uu_ = {} + y_ = {} + y_out_ = {} + + for i in range(1, self.levels): + down_out_i = self.down[f"_{i}"](xyz_[i - 1], u_[i - 1]) + xyz_[i] = down_out_i["new_xyz"] + u_[i] = down_out_i["new_features"] + uu_[i - 1] = down_out_i["grouped_features"] + + uu_[self.levels - 1] = u_[self.levels - 1][:, :, None, :] + + for i in reversed(range(0, self.levels)): + if self.M[i] == 0: + y_out_[i] = {"strings": [[b""] * B], "shape": ()} + continue + + y_[i] = self.h_a[f"_{i}"](uu_[i]) + # NOTE: Reshape 1D -> 2D since latent codecs expect 2D inputs. + y_out_[i] = lc_func(self.latent_codec[f"_{i}"])(y_[i][..., None]) + + return y_out_, u_, uu_ + + def _decompress(self, y_inputs_, *, mode): + y_hat_ = {} + y_out_ = {} + uu_hat_ = {} + v_ = {} + + for i in reversed(range(0, self.levels)): + if self.M[i] == 0: + continue + if mode == "forward": + y_out_[i] = y_inputs_[i] + elif mode == "decompress": + y_out_[i] = self.latent_codec[f"_{i}"].decompress( + y_inputs_[i]["strings"], shape=y_inputs_[i]["shape"] + ) + # NOTE: Reshape 2D -> 1D since latent codecs return 2D outputs. + y_hat_[i] = y_out_[i]["y_hat"].squeeze(-1) + uu_hat_[i] = self.h_s[f"_{i}"](y_hat_[i]) + + B, _, *tail = uu_hat_[self.levels - 1].shape + v_[self.levels] = uu_hat_[self.levels - 1].new_zeros((B, 0, *tail)) + + for i in reversed(range(0, self.levels)): + v_[i] = self.up[f"_{i}"]( + v_[i + 1] + if self.M[i] == 0 + else torch.cat([v_[i + 1], uu_hat_[i]], dim=1) + ) + + x_hat = v_[0] + + return x_hat, y_hat_, v_ diff --git a/docs/source/models.rst b/docs/source/models.rst index d509f96b..7819b399 100644 --- a/docs/source/models.rst +++ b/docs/source/models.rst @@ -66,3 +66,20 @@ ScaleSpaceFlow -------------- .. autoclass:: ScaleSpaceFlow + +.. currentmodule:: compressai.models.pointcloud + +DensityPreservingReconstructionPccModel +--------------------------------------- +.. autoclass:: DensityPreservingReconstructionPccModel + + +PointNetReconstructionPccModel +------------------------------ +.. autoclass:: PointNetReconstructionPccModel + + +PointNet2SsgReconstructionPccModel +---------------------------------- +.. autoclass:: PointNet2SsgReconstructionPccModel + From 4ca480b2102da2729ee2feaf3976ddbdf89e0d82 Mon Sep 17 00:00:00 2001 From: Mateen Ulhaq Date: Wed, 31 Jan 2024 19:59:11 -0800 Subject: [PATCH 17/27] fix: point cloud datasets optional dependencies --- compressai/datasets/pointcloud/modelnet.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/compressai/datasets/pointcloud/modelnet.py b/compressai/datasets/pointcloud/modelnet.py index 65554e63..dcb5d7ff 100644 --- a/compressai/datasets/pointcloud/modelnet.py +++ b/compressai/datasets/pointcloud/modelnet.py @@ -36,7 +36,10 @@ import numpy as np -from pyntcloud import PyntCloud +try: + from pyntcloud import PyntCloud +except ImportError: + pass # NOTE: Optional dependency. 
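+# (If pyntcloud is missing, using PyntCloud below raises a NameError.)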
from compressai.datasets.cache import CacheDataset from compressai.datasets.utils import download_url, hash_file From 16a4752879dc3e1d8efc5ff4d3d6527b2b8ec5db Mon Sep 17 00:00:00 2001 From: Mateen Ulhaq Date: Wed, 31 Jan 2024 19:59:29 -0800 Subject: [PATCH 18/27] fix: point cloud layers optional dependencies --- compressai/layers/pointcloud/hrtzxf2022.py | 6 +++++- compressai/layers/pointcloud/utils.py | 5 ++++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/compressai/layers/pointcloud/hrtzxf2022.py b/compressai/layers/pointcloud/hrtzxf2022.py index 293f6aa0..dd230fcd 100644 --- a/compressai/layers/pointcloud/hrtzxf2022.py +++ b/compressai/layers/pointcloud/hrtzxf2022.py @@ -37,7 +37,11 @@ import torch.nn.functional as F from einops import rearrange, repeat -from pointops.functions import pointops + +try: + from pointops.functions import pointops +except ImportError: + pass # NOTE: Optional dependency. from .utils import index_points diff --git a/compressai/layers/pointcloud/utils.py b/compressai/layers/pointcloud/utils.py index cb933837..63e9c764 100644 --- a/compressai/layers/pointcloud/utils.py +++ b/compressai/layers/pointcloud/utils.py @@ -31,7 +31,10 @@ import torch -from pointops.functions import pointops +try: + from pointops.functions import pointops +except ImportError: + pass # NOTE: Optional dependency. def index_points(xyzs, idx): From 214050ccdf4e0d899c35a8f200fd80a373eea61e Mon Sep 17 00:00:00 2001 From: Mateen Ulhaq Date: Wed, 31 Jan 2024 19:59:37 -0800 Subject: [PATCH 19/27] fix: point cloud losses optional dependencies --- compressai/losses/pointcloud/chamfer.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/compressai/losses/pointcloud/chamfer.py b/compressai/losses/pointcloud/chamfer.py index 10978afd..ed318edf 100644 --- a/compressai/losses/pointcloud/chamfer.py +++ b/compressai/losses/pointcloud/chamfer.py @@ -33,7 +33,11 @@ import torch.nn as nn from einops import rearrange -from pointops.functions import pointops + +try: + from pointops.functions import pointops +except ImportError: + pass # NOTE: Optional dependency. from compressai.layers.pointcloud.hrtzxf2022 import index_points from compressai.losses.utils import compute_rate_loss From 20838fc2ecc23c13d1409dfd96cdd12925f66cf0 Mon Sep 17 00:00:00 2001 From: Mateen Ulhaq Date: Wed, 31 Jan 2024 17:39:05 -0800 Subject: [PATCH 20/27] feat: zoo.pointcloud_models [placeholder] --- compressai/zoo/__init__.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/compressai/zoo/__init__.py b/compressai/zoo/__init__.py index a45512d5..6bd6b257 100644 --- a/compressai/zoo/__init__.py +++ b/compressai/zoo/__init__.py @@ -49,10 +49,18 @@ "cheng2020-attn": cheng2020_attn, } +# Not yet available. 
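+# (Each entry maps to None until pretrained weights are released.)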
+pointcloud_models = { + "hrtzxf2022-pcc-rec": None, + "sfu2023-pcc-rec-pointnet": None, + "sfu2024-pcc-rec-pointnet2-ssg": None, +} + video_models = { "ssf2020": ssf2020, } models = {} models.update(image_models) +models.update(pointcloud_models) models.update(video_models) From ec7c7faac146eca76abb78cc0485d3cb89f23598 Mon Sep 17 00:00:00 2001 From: Mateen Ulhaq Date: Wed, 31 Jan 2024 20:01:04 -0800 Subject: [PATCH 21/27] feat: examples/train_pointcloud.py ```bash python examples/train_pointcloud.py --cuda --dataset="datasets/modelnet40" ``` --- examples/train_pointcloud.py | 373 +++++++++++++++++++++++++++++++++++ 1 file changed, 373 insertions(+) create mode 100644 examples/train_pointcloud.py diff --git a/examples/train_pointcloud.py b/examples/train_pointcloud.py new file mode 100644 index 00000000..c909353e --- /dev/null +++ b/examples/train_pointcloud.py @@ -0,0 +1,373 @@ +# Copyright (c) 2021-2022, InterDigital Communications, Inc +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted (subject to the limitations in the disclaimer +# below) provided that the following conditions are met: + +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * Neither the name of InterDigital Communications, Inc nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. + +# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY +# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND +# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT +# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR +# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF +# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+
+import argparse
+import random
+import shutil
+import sys
+
+import torch
+import torch.nn as nn
+import torch.optim as optim
+
+from torch.utils.data import DataLoader
+from torchvision.transforms import Compose
+
+import compressai.transforms as transforms
+
+from compressai.datasets import ModelNetDataset
+from compressai.losses import ChamferPccRateDistortionLoss
+from compressai.optimizers import net_aux_optimizer
+from compressai.registry import MODELS
+from compressai.zoo import pointcloud_models
+
+
+class AverageMeter:
+    """Compute running average."""
+
+    def __init__(self):
+        self.val = 0
+        self.avg = 0
+        self.sum = 0
+        self.count = 0
+
+    def update(self, val, n=1):
+        self.val = val
+        self.sum += val * n
+        self.count += n
+        self.avg = self.sum / self.count
+
+
+class CustomDataParallel(nn.DataParallel):
+    """Custom DataParallel to access the module methods."""
+
+    def __getattr__(self, key):
+        try:
+            return super().__getattr__(key)
+        except AttributeError:
+            return getattr(self.module, key)
+
+
+def configure_optimizers(net, args):
+    """Separate parameters for the main optimizer and the auxiliary optimizer.
+    Return two optimizers."""
+    conf = {
+        "net": {"type": "Adam", "lr": args.learning_rate},
+        "aux": {"type": "Adam", "lr": args.aux_learning_rate},
+    }
+    optimizer = net_aux_optimizer(net, conf)
+    return optimizer["net"], optimizer["aux"]
+
+
+def train_one_epoch(
+    model, criterion, train_dataloader, optimizer, aux_optimizer, epoch, clip_max_norm
+):
+    model.train()
+    device = next(model.parameters()).device
+
+    for i, d in enumerate(train_dataloader):
+        d = {k: v.to(device) for k, v in d.items()}
+
+        optimizer.zero_grad()
+        aux_optimizer.zero_grad()
+
+        out_net = model(d)
+
+        out_criterion = criterion(out_net, d)
+        out_criterion["loss"].backward()
+        if clip_max_norm > 0:
+            torch.nn.utils.clip_grad_norm_(model.parameters(), clip_max_norm)
+        optimizer.step()
+
+        aux_loss = model.aux_loss()
+        aux_loss.backward()
+        aux_optimizer.step()
+
+        if i % 10 == 0:
+            # NOTE: d is a dict of batched tensors, so take the batch size
+            # from the "pos" entry (len(d) would count the dict keys).
+            batch_size = len(d["pos"])
+            print(
+                f"Train epoch {epoch}: ["
+                f"{i * batch_size}/{len(train_dataloader.dataset)} "
+                f"({100.
* i / len(train_dataloader):.0f}%)] " + f'Loss: {out_criterion["loss"].item():.3f} | ' + f'Bpp loss: {out_criterion["bpp_loss"].item():.3f} | ' + f'Rec loss: {out_criterion["rec_loss"].item():.4f} | ' + # f'Aux loss: {aux_loss.item():.0f} | ' + "\n" + ) + + +def test_epoch(epoch, test_dataloader, model, criterion): + model.eval() + model.update(force=True, update_quantiles=True) + device = next(model.parameters()).device + + meter_keys = ["loss", "bpp_loss", "rec_loss", "aux_loss"] + meters = {key: AverageMeter() for key in meter_keys} + + with torch.no_grad(): + for d in test_dataloader: + d = {k: v.to(device) for k, v in d.items()} + + out_net = model(d) + out_criterion = criterion(out_net, d) + out_criterion["aux_loss"] = model.aux_loss() + + for key in meters: + if key in out_criterion: + meters[key].update(out_criterion[key]) + + print( + f"Test epoch {epoch}: Average losses: " + f'Loss: {meters["loss"].avg:.3f} | ' + f'Bpp loss: {meters["bpp_loss"].avg:.3f} | ' + f'Rec loss: {meters["rec_loss"].avg:.4f} | ' + # f'Aux loss: {meters["aux_loss"].avg:.0f} | ' + "\n" + ) + + return meters["loss"].avg + + +def save_checkpoint(state, is_best, filename="checkpoint.pth.tar"): + torch.save(state, filename) + if is_best: + shutil.copyfile(filename, "checkpoint_best_loss.pth.tar") + + +def parse_args(argv): + parser = argparse.ArgumentParser(description="Example training script.") + parser.add_argument( + "-m", + "--model", + default="sfu2023-pcc-rec-pointnet", + choices=pointcloud_models.keys(), + help="Model architecture (default: %(default)s)", + ) + parser.add_argument( + "-d", "--dataset", type=str, required=True, help="Training dataset" + ) + parser.add_argument( + "-e", + "--epochs", + default=100, + type=int, + help="Number of epochs (default: %(default)s)", + ) + parser.add_argument( + "-lr", + "--learning-rate", + default=1e-4, + type=float, + help="Learning rate (default: %(default)s)", + ) + parser.add_argument( + "-n", + "--num-workers", + type=int, + default=4, + help="Dataloaders threads (default: %(default)s)", + ) + parser.add_argument( + "--lambda", + dest="lmbda", + type=float, + default=100, + help="Bit-rate distortion parameter (default: %(default)s)", + ) + parser.add_argument( + "--batch-size", type=int, default=16, help="Batch size (default: %(default)s)" + ) + parser.add_argument( + "--test-batch-size", + type=int, + default=64, + help="Test batch size (default: %(default)s)", + ) + parser.add_argument( + "--aux-learning-rate", + type=float, + default=1e-3, + help="Auxiliary loss learning rate (default: %(default)s)", + ) + parser.add_argument( + "--patch-size", + type=int, + nargs=2, + default=(256, 256), + help="Size of the patches to be cropped (default: %(default)s)", + ) + parser.add_argument("--cuda", action="store_true", help="Use cuda") + parser.add_argument( + "--save", action="store_true", default=True, help="Save model to disk" + ) + parser.add_argument("--seed", type=int, help="Set random seed for reproducibility") + parser.add_argument( + "--clip_max_norm", + default=1.0, + type=float, + help="gradient clipping max norm (default: %(default)s", + ) + parser.add_argument("--checkpoint", type=str, help="Path to a checkpoint") + args = parser.parse_args(argv) + return args + + +def main(argv): + args = parse_args(argv) + + if args.seed is not None: + torch.manual_seed(args.seed) + random.seed(args.seed) + + num_points = 1024 + + train_dataset = ModelNetDataset( + args.dataset, + split="train", + pre_transform=Compose( + [ + 
transforms.ToDict(wrapper="torch_geometric.data.Data"), + transforms.SamplePointsV2( + num=8192, remove_faces=True, include_normals=True, static_seed=1234 + ), + transforms.NormalizeScaleV2(center=True, scale_method="l2"), + transforms.ToDict(wrapper="dict"), + ] + ), + transform=Compose( + [ + transforms.ToDict(wrapper="torch_geometric.data.Data"), + transforms.RandomSample(num=num_points, attrs=["pos", "normal"]), + transforms.ToDict(wrapper="dict"), + ] + ), + ) + + test_dataset = ModelNetDataset( + args.dataset, + split="test", + pre_transform=Compose( + [ + transforms.ToDict(wrapper="torch_geometric.data.Data"), + transforms.SamplePointsV2( + num=8192, remove_faces=True, include_normals=True, static_seed=1234 + ), + transforms.NormalizeScaleV2(center=True, scale_method="l2"), + transforms.ToDict(wrapper="dict"), + ] + ), + transform=Compose( + [ + transforms.ToDict(wrapper="torch_geometric.data.Data"), + transforms.RandomSample( + num=num_points, attrs=["pos", "normal"], static_seed=1234 + ), + transforms.ToDict(wrapper="dict"), + ] + ), + ) + + device = "cuda" if args.cuda and torch.cuda.is_available() else "cpu" + + train_dataloader = DataLoader( + train_dataset, + batch_size=args.batch_size, + num_workers=args.num_workers, + shuffle=True, + pin_memory=(device == "cuda"), + ) + + test_dataloader = DataLoader( + test_dataset, + batch_size=args.test_batch_size, + num_workers=args.num_workers, + shuffle=False, + pin_memory=(device == "cuda"), + ) + + net = MODELS[args.model]() + net = net.to(device) + + if args.cuda and torch.cuda.device_count() > 1: + net = CustomDataParallel(net) + + optimizer, aux_optimizer = configure_optimizers(net, args) + lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min") + criterion = ChamferPccRateDistortionLoss(lmbda={"bpp": 1.0, "rec": args.lmbda}) + + last_epoch = 0 + if args.checkpoint: # load from previous checkpoint + print("Loading", args.checkpoint) + checkpoint = torch.load(args.checkpoint, map_location=device) + last_epoch = checkpoint["epoch"] + 1 + net.load_state_dict(checkpoint["state_dict"]) + optimizer.load_state_dict(checkpoint["optimizer"]) + aux_optimizer.load_state_dict(checkpoint["aux_optimizer"]) + lr_scheduler.load_state_dict(checkpoint["lr_scheduler"]) + + best_loss = float("inf") + for epoch in range(last_epoch, args.epochs): + print(f"Learning rate: {optimizer.param_groups[0]['lr']}") + train_one_epoch( + net, + criterion, + train_dataloader, + optimizer, + aux_optimizer, + epoch, + args.clip_max_norm, + ) + loss = test_epoch(epoch, test_dataloader, net, criterion) + lr_scheduler.step(loss) + + is_best = loss < best_loss + best_loss = min(loss, best_loss) + + if args.save: + save_checkpoint( + { + "epoch": epoch, + "state_dict": net.state_dict(), + "loss": loss, + "optimizer": optimizer.state_dict(), + "aux_optimizer": aux_optimizer.state_dict(), + "lr_scheduler": lr_scheduler.state_dict(), + }, + is_best, + ) + + +if __name__ == "__main__": + main(sys.argv[1:]) + + +# NOTE: A more complete trainer with experiment tracking, visualizations, etc +# that uses CompressAI Trainer can be found at: +# +# https://github.com/multimedialabsfu/learned-point-cloud-compression-for-classification From 50ab5ea0c7eb1cdb7bba331e93bff0edc05622b0 Mon Sep 17 00:00:00 2001 From: Mateen Ulhaq Date: Thu, 1 Feb 2024 18:42:03 -0800 Subject: [PATCH 22/27] chore(deps): python_requires>=3.7 Drop Python 3.6 support since `torch-geometric>=2.3.0` requires Python 3.7+. 
Python 3.6 (released 2016) went EOL in 2021, and Ubuntu 18.04 LTS went EOL in 2023. --- .github/workflows/pytest.yml | 4 ---- .github/workflows/python-package.yml | 2 +- .github/workflows/static-analysis.yml | 4 ---- .gitlab-ci.yml | 8 ++++---- Readme.md | 2 +- setup.py | 2 +- 6 files changed, 7 insertions(+), 15 deletions(-) diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index c96f095a..025f7bfd 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -8,15 +8,11 @@ jobs: strategy: matrix: python-version: - - "3.6" - "3.7" - "3.8" - "3.9" include: - os: "ubuntu-latest" - # no Python 3.6 in ubuntu>20.04. - - os: "ubuntu-20.04" - python-version: "3.6" steps: - uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 43ac98ad..0bc5d75e 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -38,7 +38,7 @@ jobs: runs-on: macos-latest strategy: matrix: - python-version: [3.6, 3.7, 3.8, 3.9] + python-version: [3.7, 3.8, 3.9] steps: - uses: actions/checkout@v2 - name: Set up Python ${{ matrix.python-version }} diff --git a/.github/workflows/static-analysis.yml b/.github/workflows/static-analysis.yml index 4d9467c1..0a5de6b6 100644 --- a/.github/workflows/static-analysis.yml +++ b/.github/workflows/static-analysis.yml @@ -8,15 +8,11 @@ jobs: strategy: matrix: python-version: - - "3.6" - "3.7" - "3.8" - "3.9" include: - os: "ubuntu-latest" - # no Python 3.6 in ubuntu>20.04. - - os: "ubuntu-20.04" - python-version: "3.6" steps: - uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 61a6fd4b..e84de40d 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -24,12 +24,12 @@ wheel: expire_in: 1 day parallel: matrix: - - PYTHON_VERSION: ['3.6', '3.7', '3.8', '3.9'] + - PYTHON_VERSION: ['3.7', '3.8', '3.9'] tags: - docker sdist: - image: python:3.6-buster + image: python:3.7-buster stage: build before_script: - pip install build @@ -51,7 +51,7 @@ flake8: black: stage: static-analysis - image: python:3.6-buster + image: python:3.7-buster before_script: - python --version - pip install black @@ -62,7 +62,7 @@ black: isort: stage: static-analysis - image: python:3.6-buster + image: python:3.7-buster before_script: - python --version - pip install . diff --git a/Readme.md b/Readme.md index 048e098f..b6fd69b9 100644 --- a/Readme.md +++ b/Readme.md @@ -24,7 +24,7 @@ CompressAI currently provides: ## Installation -CompressAI supports python 3.6+ and PyTorch 1.7+. +CompressAI supports python 3.7+ and PyTorch 1.7+. 
**pip**: diff --git a/setup.py b/setup.py index c97958d5..dd6785a1 100644 --- a/setup.py +++ b/setup.py @@ -131,7 +131,7 @@ def get_extra_requirements(): author_email="compressai@interdigital.com", packages=find_packages(exclude=("tests",)), zip_safe=False, - python_requires=">=3.6", + python_requires=">=3.7", install_requires=[ "einops", "numpy", From 115d91f1ab7c12e6cc9a02616ad463f6acc4b9f8 Mon Sep 17 00:00:00 2001 From: Fabien Racape Date: Sat, 3 Feb 2024 01:11:21 -0800 Subject: [PATCH 23/27] update github workflow and gitlab-ci --- .github/workflows/python-package.yml | 14 +++++++------- .gitlab-ci.yml | 1 - 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 0bc5d75e..4bfeb35e 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -9,13 +9,13 @@ jobs: sdist: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - name: Set up Python 3.8 - uses: actions/setup-python@v2 + uses: actions/setup-python@v3 with: python-version: 3.8 - name: Cache pip - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: ~/.cache/pip key: ${{ runner.os }}-pip-${{ hashFiles('requirements.txt') }} @@ -40,13 +40,13 @@ jobs: matrix: python-version: [3.7, 3.8, 3.9] steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 + uses: actions/setup-python@v3 with: python-version: ${{ matrix.python-version }} - name: Cache pip - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: ~/.cache/pip key: ${{ runner.os }}-pip-${{ hashFiles('requirements.txt') }} @@ -72,7 +72,7 @@ jobs: matrix: python-version: [cp36-cp36m, cp37-cp37m, cp38-cp38, cp39-cp39] steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - name: Install dependencies run: /opt/python/${{ matrix.python-version }}/bin/python -m pip install build twine - name: Build wheel diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index e84de40d..0c93aa02 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -79,7 +79,6 @@ test: before_script: - python --version - pip install -e . 
-    - pip install -r requirements.txt
     - pip install pytest pytest-cov plotly
   script:
     - >

From 2a6876e16bf163650b69130ac7d8d726cfccadee Mon Sep 17 00:00:00 2001
From: Fabien Racape
Date: Sat, 3 Feb 2024 18:30:13 -0800
Subject: [PATCH 24/27] update github workflow and gitlab-ci

---
 .gitlab-ci.yml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index 0c93aa02..225d3588 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -92,7 +92,7 @@ test:
     - PYTORCH_IMAGE:
       - "1.9.0-cuda11.1-cudnn8-devel"
       - "1.8.1-cuda11.1-cudnn8-devel"
-      - "1.7.1-cuda11.0-cudnn8-devel"
+      # - "1.7.1-cuda11.0-cudnn8-devel"
 
   tags:
     - docker

From cb07223695409b662d9d5e632da117759bda2d17 Mon Sep 17 00:00:00 2001
From: Fabien Racape
Date: Sat, 3 Feb 2024 21:49:16 -0800
Subject: [PATCH 25/27] update numpy version for point cloud deps (pandas)

---
 setup.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/setup.py b/setup.py
index dd6785a1..9ac6e9c6 100644
--- a/setup.py
+++ b/setup.py
@@ -134,7 +134,7 @@ def get_extra_requirements():
     python_requires=">=3.7",
     install_requires=[
         "einops",
-        "numpy",
+        "numpy>=1.21.0",
         "pandas",
         "scipy",
         "matplotlib",
@@ -151,7 +151,6 @@ def get_extra_requirements():
         "Intended Audience :: Developers",
         "Intended Audience :: Science/Research",
         "License :: OSI Approved :: BSD License",
-        "Programming Language :: Python :: 3.6",
         "Programming Language :: Python :: 3.7",
         "Programming Language :: Python :: 3.8",
         "Programming Language :: Python :: 3.9",

From ce8b9cfc082ab9ac6f3acd4b5a12855c18d79543 Mon Sep 17 00:00:00 2001
From: Fabien Racape
Date: Sat, 3 Feb 2024 23:53:52 -0800
Subject: [PATCH 26/27] [chores] update github actions

---
 .github/workflows/pytest.yml          |  4 ++--
 .github/workflows/python-package.yml  | 10 +++++-----
 .github/workflows/static-analysis.yml |  4 ++--
 3 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml
index 025f7bfd..bb7ed283 100644
--- a/.github/workflows/pytest.yml
+++ b/.github/workflows/pytest.yml
@@ -14,9 +14,9 @@ jobs:
         include:
           - os: "ubuntu-latest"
     steps:
-      - uses: actions/checkout@v3
+      - uses: actions/checkout@v4
       - name: Set up Python ${{ matrix.python-version }}
-        uses: actions/setup-python@v3
+        uses: actions/setup-python@v4
         with:
           python-version: ${{ matrix.python-version }}
       - name: Install Python dependencies
diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml
index 4bfeb35e..edcd99c8 100644
--- a/.github/workflows/python-package.yml
+++ b/.github/workflows/python-package.yml
@@ -9,9 +9,9 @@ jobs:
   sdist:
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v3
+      - uses: actions/checkout@v4
       - name: Set up Python 3.8
-        uses: actions/setup-python@v3
+        uses: actions/setup-python@v4
         with:
           python-version: 3.8
       - name: Cache pip
@@ -40,9 +40,9 @@ jobs:
       matrix:
         python-version: [3.7, 3.8, 3.9]
     steps:
-      - uses: actions/checkout@v3
+      - uses: actions/checkout@v4
       - name: Set up Python ${{ matrix.python-version }}
-        uses: actions/setup-python@v3
+        uses: actions/setup-python@v4
         with:
           python-version: ${{ matrix.python-version }}
       - name: Cache pip
@@ -72,7 +72,7 @@ jobs:
       matrix:
         python-version: [cp36-cp36m, cp37-cp37m, cp38-cp38, cp39-cp39]
     steps:
-      - uses: actions/checkout@v3
+      - uses: actions/checkout@v4
       - name: Install dependencies
         run: /opt/python/${{ matrix.python-version }}/bin/python -m pip install build twine
       - name: Build wheel
diff --git a/.github/workflows/static-analysis.yml b/.github/workflows/static-analysis.yml
index 0a5de6b6..f237c2bb 100644
--- a/.github/workflows/static-analysis.yml
+++ b/.github/workflows/static-analysis.yml
@@ -14,9 +14,9 @@ jobs:
         include:
           - os: "ubuntu-latest"
     steps:
-      - uses: actions/checkout@v3
+      - uses: actions/checkout@v4
       - name: Set up Python ${{ matrix.python-version }}
-        uses: actions/setup-python@v3
+        uses: actions/setup-python@v4
         with:
           python-version: ${{ matrix.python-version }}
       - name: Install Python dependencies

From 9c9b37b999ba97b9f0626f247a4fbaecd24797a6 Mon Sep 17 00:00:00 2001
From: Mateen Ulhaq
Date: Sun, 4 Feb 2024 16:28:25 -0800
Subject: [PATCH 27/27] refactor: rename transforms.point -> transforms.pointcloud

---
 compressai/transforms/__init__.py                             | 4 ++--
 compressai/transforms/{point => pointcloud}/__init__.py       | 0
 .../{point => pointcloud}/generate_position_normals.py        | 0
 .../transforms/{point => pointcloud}/normalize_scale_v2.py    | 0
 .../transforms/{point => pointcloud}/random_permutation.py    | 0
 .../transforms/{point => pointcloud}/random_rotate_full.py    | 0
 compressai/transforms/{point => pointcloud}/random_sample.py  | 0
 .../transforms/{point => pointcloud}/sample_points_v2.py      | 0
 compressai/transforms/{point => pointcloud}/to_dict.py        | 0
 docs/source/transforms.rst                                    | 2 +-
 10 files changed, 3 insertions(+), 3 deletions(-)
 rename compressai/transforms/{point => pointcloud}/__init__.py (100%)
 rename compressai/transforms/{point => pointcloud}/generate_position_normals.py (100%)
 rename compressai/transforms/{point => pointcloud}/normalize_scale_v2.py (100%)
 rename compressai/transforms/{point => pointcloud}/random_permutation.py (100%)
 rename compressai/transforms/{point => pointcloud}/random_rotate_full.py (100%)
 rename compressai/transforms/{point => pointcloud}/random_sample.py (100%)
 rename compressai/transforms/{point => pointcloud}/sample_points_v2.py (100%)
 rename compressai/transforms/{point => pointcloud}/to_dict.py (100%)

diff --git a/compressai/transforms/__init__.py b/compressai/transforms/__init__.py
index 3cc90328..c3bb1bd1 100644
--- a/compressai/transforms/__init__.py
+++ b/compressai/transforms/__init__.py
@@ -27,6 +27,6 @@
 # OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
 # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
-from . import point as point
-from .point import *
+from . import pointcloud as pointcloud
+from .pointcloud import *
 from .transforms import *
diff --git a/compressai/transforms/point/__init__.py b/compressai/transforms/pointcloud/__init__.py
similarity index 100%
rename from compressai/transforms/point/__init__.py
rename to compressai/transforms/pointcloud/__init__.py
diff --git a/compressai/transforms/point/generate_position_normals.py b/compressai/transforms/pointcloud/generate_position_normals.py
similarity index 100%
rename from compressai/transforms/point/generate_position_normals.py
rename to compressai/transforms/pointcloud/generate_position_normals.py
diff --git a/compressai/transforms/point/normalize_scale_v2.py b/compressai/transforms/pointcloud/normalize_scale_v2.py
similarity index 100%
rename from compressai/transforms/point/normalize_scale_v2.py
rename to compressai/transforms/pointcloud/normalize_scale_v2.py
diff --git a/compressai/transforms/point/random_permutation.py b/compressai/transforms/pointcloud/random_permutation.py
similarity index 100%
rename from compressai/transforms/point/random_permutation.py
rename to compressai/transforms/pointcloud/random_permutation.py
diff --git a/compressai/transforms/point/random_rotate_full.py b/compressai/transforms/pointcloud/random_rotate_full.py
similarity index 100%
rename from compressai/transforms/point/random_rotate_full.py
rename to compressai/transforms/pointcloud/random_rotate_full.py
diff --git a/compressai/transforms/point/random_sample.py b/compressai/transforms/pointcloud/random_sample.py
similarity index 100%
rename from compressai/transforms/point/random_sample.py
rename to compressai/transforms/pointcloud/random_sample.py
diff --git a/compressai/transforms/point/sample_points_v2.py b/compressai/transforms/pointcloud/sample_points_v2.py
similarity index 100%
rename from compressai/transforms/point/sample_points_v2.py
rename to compressai/transforms/pointcloud/sample_points_v2.py
diff --git a/compressai/transforms/point/to_dict.py b/compressai/transforms/pointcloud/to_dict.py
similarity index 100%
rename from compressai/transforms/point/to_dict.py
rename to compressai/transforms/pointcloud/to_dict.py
diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst
index cd4ced50..9737b299 100644
--- a/docs/source/transforms.rst
+++ b/docs/source/transforms.rst
@@ -28,5 +28,5 @@ Functional transforms can be used to define custom transform classes.
 Point Cloud Transforms
 ----------------------
 
-.. automodule:: compressai.transforms.point
+.. automodule:: compressai.transforms.pointcloud
    :members:
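
Note on [PATCH 27/27]: downstream code that imports the point cloud transforms by
their pre-rename path breaks against this revision, since the patch changes only
the subpackage name (every module is a 100% similarity rename). A minimal
migration sketch; the alias `pc_transforms` and the `RandomSample` usage in the
trailing comment are illustrative assumptions, not part of the patch:

    # Compatibility shim for code that must run against CompressAI checkouts
    # from both before and after this rename.
    try:
        from compressai.transforms import pointcloud as pc_transforms  # new path
    except ImportError:
        from compressai.transforms import point as pc_transforms  # pre-rename path

    # Hypothetical usage -- concrete transform class names are not shown in this
    # patch, so `RandomSample` is only inferred from random_sample.py:
    # sampler = pc_transforms.RandomSample(num_points=2048)

Symbols re-exported through `compressai.transforms` itself (via the star import
in `compressai/transforms/__init__.py`) remain reachable from the package root,
so only imports that name the subpackage explicitly need this shim.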