diff --git a/test/test_plot.py b/test/test_plot.py index 8af794c90..90e568483 100644 --- a/test/test_plot.py +++ b/test/test_plot.py @@ -103,7 +103,25 @@ def test_to_raster(gridpath): mesh_path = gridpath("mpas", "QU", "oQU480.231010.nc") uxds = ux.open_dataset(mesh_path, mesh_path) - raster = uxds['bottomDepth'].to_raster(ax=ax) + with pytest.warns(UserWarning, match=r"Axes extent was default"): + raster = uxds['bottomDepth'].to_raster(ax=ax) + + assert isinstance(raster, np.ndarray) + + +def test_to_raster_with_extra_dims(gridpath): + fig, ax = plt.subplots( + subplot_kw={'projection': ccrs.Robinson()}, + constrained_layout=True, + figsize=(10, 5), + ) + + mesh_path = gridpath("mpas", "QU", "oQU480.231010.nc") + uxds = ux.open_dataset(mesh_path, mesh_path) + + da = uxds['bottomDepth'].expand_dims(time=[0]) + with pytest.warns(UserWarning, match=r"Axes extent was default"): + raster = da.to_raster(ax=ax) assert isinstance(raster, np.ndarray) @@ -121,9 +139,10 @@ def test_to_raster_reuse_mapping(gridpath, tmpdir): uxds = ux.open_dataset(mesh_path, mesh_path) # Returning - raster1, pixel_mapping = uxds['bottomDepth'].to_raster( - ax=ax, pixel_ratio=0.5, return_pixel_mapping=True - ) + with pytest.warns(UserWarning, match=r"Axes extent was default"): + raster1, pixel_mapping = uxds['bottomDepth'].to_raster( + ax=ax, pixel_ratio=0.5, return_pixel_mapping=True + ) assert isinstance(raster1, np.ndarray) assert isinstance(pixel_mapping, xr.DataArray) diff --git a/uxarray/core/dataarray.py b/uxarray/core/dataarray.py index e4c802ce8..2c52710d1 100644 --- a/uxarray/core/dataarray.py +++ b/uxarray/core/dataarray.py @@ -375,7 +375,7 @@ def to_raster( _RasterAxAttrs, ) - _ensure_dimensions(self) + data = _ensure_dimensions(self) if not isinstance(ax, GeoAxes): raise TypeError("`ax` must be an instance of cartopy.mpl.geoaxes.GeoAxes") @@ -383,8 +383,8 @@ def to_raster( pixel_ratio_set = pixel_ratio is not None if not pixel_ratio_set: pixel_ratio = 1.0 - input_ax_attrs = _RasterAxAttrs.from_ax(ax, pixel_ratio=pixel_ratio) if pixel_mapping is not None: + input_ax_attrs = _RasterAxAttrs.from_ax(ax, pixel_ratio=pixel_ratio) if isinstance(pixel_mapping, xr.DataArray): pixel_ratio_input = pixel_ratio pixel_ratio = pixel_mapping.attrs["pixel_ratio"] @@ -403,9 +403,43 @@ def to_raster( + input_ax_attrs._value_comparison_message(pm_ax_attrs) ) pixel_mapping = np.asarray(pixel_mapping, dtype=INT_DTYPE) + else: + + def _is_default_extent() -> bool: + # Default extents are indicated by xlim/ylim being (0, 1) + # when autoscale is still on (no extent has been explicitly set) + if not ax.get_autoscale_on(): + return False + xlim, ylim = ax.get_xlim(), ax.get_ylim() + return np.allclose(xlim, (0.0, 1.0)) and np.allclose(ylim, (0.0, 1.0)) + + if _is_default_extent(): + try: + import cartopy.crs as ccrs + + lon_min = float(self.uxgrid.node_lon.min(skipna=True).values) + lon_max = float(self.uxgrid.node_lon.max(skipna=True).values) + lat_min = float(self.uxgrid.node_lat.min(skipna=True).values) + lat_max = float(self.uxgrid.node_lat.max(skipna=True).values) + ax.set_extent( + (lon_min, lon_max, lat_min, lat_max), + crs=ccrs.PlateCarree(), + ) + warn( + "Axes extent was default; auto-setting from grid lon/lat bounds for rasterization. " + "Set the extent explicitly to control this, e.g. via ax.set_global(), " + "ax.set_extent(...), or ax.set_xlim(...) + ax.set_ylim(...).", + stacklevel=2, + ) + except Exception as e: + warn( + f"Failed to auto-set extent from grid bounds: {e}", + stacklevel=2, + ) + input_ax_attrs = _RasterAxAttrs.from_ax(ax, pixel_ratio=pixel_ratio) raster, pixel_mapping_np = _nearest_neighbor_resample( - self, + data, ax, pixel_ratio=pixel_ratio, pixel_mapping=pixel_mapping, diff --git a/uxarray/plot/matplotlib.py b/uxarray/plot/matplotlib.py index adb0a8697..5e43a30af 100644 --- a/uxarray/plot/matplotlib.py +++ b/uxarray/plot/matplotlib.py @@ -22,17 +22,21 @@ def _ensure_dimensions(data: UxDataArray) -> UxDataArray: ValueError If the sole dimension is not named "n_face". """ - # Check dimensionality - if data.ndim != 1: + # Allow extra singleton dimensions as long as there's exactly one non-singleton dim + non_trivial_dims = [dim for dim, size in zip(data.dims, data.shape) if size != 1] + + if len(non_trivial_dims) != 1: raise ValueError( - f"Expected a 1D DataArray over 'n_face', but got {data.ndim} dimensions: {data.dims}" + "Expected data with a single dimension (other axes may be length 1), " + f"but got dims {data.dims} with shape {data.shape}" ) - # Check dimension name - if data.dims[0] != "n_face": - raise ValueError(f"Expected dimension 'n_face', but got '{data.dims[0]}'") + sole_dim = non_trivial_dims[0] + if sole_dim != "n_face": + raise ValueError(f"Expected dimension 'n_face', but got '{sole_dim}'") - return data + # Squeeze any singleton axes to ensure we return a true 1D array over n_face + return data.squeeze() class _RasterAxAttrs(NamedTuple): diff --git a/uxarray/utils/computing.py b/uxarray/utils/computing.py index 88699f3db..ca5ca5183 100644 --- a/uxarray/utils/computing.py +++ b/uxarray/utils/computing.py @@ -102,7 +102,6 @@ def dot_fma(v1, v2): ---------- S. Graillat, Ph. Langlois, and N. Louvet. "Accurate dot products with FMA." Presented at RNC 7, 2007, Nancy, France. DALI-LP2A Laboratory, University of Perpignan, France. - [Poster](https://www-pequan.lip6.fr/~graillat/papers/posterRNC7.pdf) """ if len(v1) != len(v2): raise ValueError("Input vectors must be of the same length")