From f3bbaf6d1182e8b6a3b225f323328543b0b0a76d Mon Sep 17 00:00:00 2001
From: David Stuebe
Date: Tue, 27 Feb 2024 09:31:12 +0000
Subject: [PATCH 1/4] Add parallel chunk_getitems for grib reader

---
 zarr/core.py | 40 ++++++++++++++++++++++++++++++++++++++++
 1 file changed, 40 insertions(+)

diff --git a/zarr/core.py b/zarr/core.py
index d22a9d79c3..7a44ebdbaa 100644
--- a/zarr/core.py
+++ b/zarr/core.py
@@ -10,6 +10,10 @@
 import numpy as np
 from numcodecs.compat import ensure_bytes
 
+import tempfile
+import logging
+logger = logging.getLogger(__name__)
+
 from zarr._storage.store import _prefix_to_attrs_key, assert_zarr_v3_api_available
 from zarr.attrs import Attributes
 from zarr.codecs import AsType, get_codec
@@ -2163,6 +2167,42 @@ def _chunk_getitems(
             partial_read_decode = False
             values = self.chunk_store.get_partial_values([(ckey, (0, None)) for ckey in ckeys])
             cdatas = {key: value for key, value in zip(ckeys, values) if value is not None}
+
+
+        elif "GRIBCodec" in list(map(lambda x: str(x.__class__.__name__), self.filters or [])):
+            # Start parallel grib hack
+            # Make this really specific to GRIBCodec for now - we can make this more general later?
+            from joblib import Parallel, delayed
+
+            def parallel_io_method(instance, c_key, c_select, out_sel, my_out):
+                try:
+                    cdata = instance.chunk_store[c_key]
+                    chunk = instance._decode_chunk(cdata)
+                    tmp = chunk[c_select]
+                    if drop_axes:
+                        tmp = np.squeeze(tmp, axis=drop_axes)
+                    my_out[out_sel] = tmp
+
+                except Exception:
+                    # TODO: get more context from the mapper about what chunk failed!
+                    logger.exception("Error reading chunk %s", c_key)
+                    my_out[out_sel] = instance._fill_value
+
+            with tempfile.NamedTemporaryFile(mode="w+b", prefix="zarr_memmap") as f:
+                logger.warning("Creating memmap array of shape %s - this could oom", out.shape)
+                output = np.memmap(f, dtype=out.dtype, shape=out.shape, mode='w+')
+
+                # Just setting mmap_mode to w+ doesn't seem to copy the data back to out...
+                Parallel()(
+                    delayed(parallel_io_method)(self, ckey, chunk_select, out_select, output)
+                    for ckey, chunk_select, out_select in zip(ckeys, lchunk_selection, lout_selection)
+                )
+
+                out[:] = output[:]
+
+            return
+            # End parallel grib hack
+
         else:
             partial_read_decode = False
             contexts = {}

From dd75eb85974e12d869cb477d064265de595db8ba Mon Sep 17 00:00:00 2001
From: David Stuebe
Date: Tue, 27 Feb 2024 09:42:37 +0000
Subject: [PATCH 2/4] Only use parallel for multiple chunks

---
 zarr/core.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/zarr/core.py b/zarr/core.py
index 7a44ebdbaa..62d56b00f3 100644
--- a/zarr/core.py
+++ b/zarr/core.py
@@ -2169,7 +2169,7 @@ def _chunk_getitems(
             cdatas = {key: value for key, value in zip(ckeys, values) if value is not None}
 
 
-        elif "GRIBCodec" in list(map(lambda x: str(x.__class__.__name__), self.filters or [])):
+        elif "GRIBCodec" in list(map(lambda x: str(x.__class__.__name__), self.filters or [])) and len(ckeys) > 1:
             # Start parallel grib hack
             # Make this really specific to GRIBCodec for now - we can make this more general later?
             from joblib import Parallel, delayed
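For reviewers, a minimal standalone sketch of the pattern these first two patches rely on, assuming only numpy and joblib: joblib's default loky backend pickles an np.memmap by file reference, so each worker process reopens the same scratch file and non-overlapping slice writes are visible to the parent. read_chunk and worker below are hypothetical stand-ins for the zarr decode path, not real zarr API.

    import tempfile

    import numpy as np
    from joblib import Parallel, delayed

    def read_chunk(i):
        # Stand-in for chunk_store[ckey] + _decode_chunk(cdata)
        return np.full(10, i, dtype="f8")

    def worker(i, shared):
        # Each task fills its own non-overlapping slice of the shared memmap
        shared[i * 10:(i + 1) * 10] = read_chunk(i)

    out = np.empty(40, dtype="f8")
    with tempfile.NamedTemporaryFile(mode="w+b", prefix="zarr_memmap") as f:
        shared = np.memmap(f, dtype=out.dtype, shape=out.shape, mode="w+")
        Parallel(n_jobs=4)(delayed(worker)(i, shared) for i in range(4))
        out[:] = shared[:]  # copy results back into the caller's array

The copy-back at the end is the same workaround the patch comments mention: the parent's array is not updated by the workers directly, only the file-backed memmap is.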
From b2c1f624232d636b39e2ad152ba4f2c2f734a6eb Mon Sep 17 00:00:00 2001
From: David Stuebe
Date: Fri, 1 Mar 2024 04:08:23 +0000
Subject: [PATCH 3/4] Force /dev/shm

---
 zarr/core.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/zarr/core.py b/zarr/core.py
index 62d56b00f3..9eaf050b74 100644
--- a/zarr/core.py
+++ b/zarr/core.py
@@ -2188,7 +2188,7 @@ def parallel_io_method(instance, c_key, c_select, out_sel, my_out):
                     logger.exception("Error reading chunk %s", c_key)
                     my_out[out_sel] = instance._fill_value
 
-            with tempfile.NamedTemporaryFile(mode="w+b", prefix="zarr_memmap") as f:
+            with tempfile.NamedTemporaryFile(mode="w+b", prefix="zarr_memmap", dir="/dev/shm") as f:
                 logger.warning("Creating memmap array of shape %s - this could oom", out.shape)
                 output = np.memmap(f, dtype=out.dtype, shape=out.shape, mode='w+')
 
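The /dev/shm change assumes a Linux tmpfs mount, so the "file" behind the memmap stays in RAM instead of hitting disk. A short sketch of the same idea with a fallback for hosts that lack /dev/shm (the fallback is an addition for illustration, not part of the patch):

    import os
    import tempfile

    import numpy as np

    # On Linux, /dev/shm is tmpfs; writes to a file there stay in memory.
    shm_dir = "/dev/shm" if os.path.isdir("/dev/shm") else None  # None -> default temp dir
    with tempfile.NamedTemporaryFile(mode="w+b", prefix="zarr_memmap", dir=shm_dir) as f:
        scratch = np.memmap(f, dtype="f8", shape=(1024,), mode="w+")
        scratch[:] = 0.0  # writes land in tmpfs pages, not on disk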
From 3425b8da85962903feb37be2c1780d2261474895 Mon Sep 17 00:00:00 2001
From: David Stuebe
Date: Fri, 1 Mar 2024 17:27:19 +0000
Subject: [PATCH 4/4] Use some heuristics for batch size

---
 zarr/core.py | 57 ++++++++++++++++++++++++++++++++++------------------
 1 file changed, 37 insertions(+), 20 deletions(-)

diff --git a/zarr/core.py b/zarr/core.py
index 9eaf050b74..4c288b3521 100644
--- a/zarr/core.py
+++ b/zarr/core.py
@@ -64,8 +64,30 @@
     ensure_ndarray_like,
 )
 
+
+from joblib import Parallel, delayed
+
 __all__ = ["Array"]
 
+# Number of ckeys required to trigger parallelism in _chunk_getitems
+PARALLEL_THRESHOLD = 12
+# Number of chunks to batch per parallel worker task
+PARALLEL_BATCH_SIZE = 8
+
+def parallel_io_method(instance, c_key, c_select, out_sel, drop_axes, my_out):
+    try:
+        cdata = instance.chunk_store[c_key]
+        chunk = instance._decode_chunk(cdata)
+        tmp = chunk[c_select]
+        if drop_axes:
+            tmp = np.squeeze(tmp, axis=drop_axes)
+        my_out[out_sel] = tmp
+
+    except Exception:
+        # If the read/parse failed, the file name will be in the exception.
+        # If the key is not present, there is no more info to log.
+        logger.exception("Error reading chunk %s", c_key)
+        my_out[out_sel] = instance._fill_value
 
 # noinspection PyUnresolvedReferences
 class Array:
@@ -2169,32 +2191,27 @@ def _chunk_getitems(
             cdatas = {key: value for key, value in zip(ckeys, values) if value is not None}
 
 
-        elif "GRIBCodec" in list(map(lambda x: str(x.__class__.__name__), self.filters or [])) and len(ckeys) > 1:
+        elif "GRIBCodec" in list(map(lambda x: str(x.__class__.__name__), self.filters or [])):
             # Start parallel grib hack
             # Make this really specific to GRIBCodec for now - we can make this more general later?
-            from joblib import Parallel, delayed
-
-            def parallel_io_method(instance, c_key, c_select, out_sel, my_out):
-                try:
-                    cdata = instance.chunk_store[c_key]
-                    chunk = instance._decode_chunk(cdata)
-                    tmp = chunk[c_select]
-                    if drop_axes:
-                        tmp = np.squeeze(tmp, axis=drop_axes)
-                    my_out[out_sel] = tmp
-
-                except Exception:
-                    # TODO: get more context from the mapper about what chunk failed!
-                    logger.exception("Error reading chunk %s", c_key)
-                    my_out[out_sel] = instance._fill_value
+            # Can we pass parameters for the heuristic behavior thresholds? Use module constants for now
+            key_count = len(ckeys)
+            if key_count <= PARALLEL_THRESHOLD:
+                logger.info("Chunk Count %s <= Parallel Threshold %s: Using Serial Chunk GetItems", key_count, PARALLEL_THRESHOLD)
+                for ckey, chunk_select, out_select in zip(ckeys, lchunk_selection, lout_selection):
+                    parallel_io_method(self, ckey, chunk_select, out_select, drop_axes, out)
+                return
 
+            logger.info("Chunk Count %s greater than Parallel Threshold %s: Using Parallel Chunk GetItems with Parallel Batch Size: %s", key_count, PARALLEL_THRESHOLD, PARALLEL_BATCH_SIZE)
+            # Explicitly use /dev/shm to ensure we are working in memory
             with tempfile.NamedTemporaryFile(mode="w+b", prefix="zarr_memmap", dir="/dev/shm") as f:
-                logger.warning("Creating memmap array of shape %s - this could oom", out.shape)
+                logger.warning("Creating memmap array of shape %s, size %s - this could oom or exceed the size of /dev/shm", out.shape, out.nbytes)
                 output = np.memmap(f, dtype=out.dtype, shape=out.shape, mode='w+')
-
                 # Just setting mmap_mode to w+ doesn't seem to copy the data back to out...
-                Parallel()(
-                    delayed(parallel_io_method)(self, ckey, chunk_select, out_select, output)
+                # Hard to know batch_size without n_jobs. Use a const here too
+                Parallel(pre_dispatch="2*n_jobs", batch_size=PARALLEL_BATCH_SIZE)(
+                    delayed(parallel_io_method)(self, ckey, chunk_select, out_select, drop_axes, output)
                     for ckey, chunk_select, out_select in zip(ckeys, lchunk_selection, lout_selection)
                 )
 
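To summarize the heuristic this last patch introduces, a condensed sketch; read_one is a hypothetical stand-in for parallel_io_method, and the constants mirror the module-level values, which are tuning guesses rather than measured optima:

    from joblib import Parallel, delayed

    PARALLEL_THRESHOLD = 12   # below this, pool startup overhead beats the speedup
    PARALLEL_BATCH_SIZE = 8   # tasks handed to each worker per dispatch

    def read_one(key):
        ...  # fetch and decode one chunk, writing into the output

    def get_items(keys):
        if len(keys) <= PARALLEL_THRESHOLD:
            for key in keys:  # serial path: no worker pool, no scratch memmap
                read_one(key)
            return
        # pre_dispatch bounds how many tasks are queued ahead of the workers;
        # batch_size amortizes per-task dispatch overhead across several chunks.
        Parallel(pre_dispatch="2*n_jobs", batch_size=PARALLEL_BATCH_SIZE)(
            delayed(read_one)(key) for key in keys
        )

A fixed batch_size trades adaptivity for predictability here: joblib's "auto" batching tunes itself to task duration, but with n_jobs unknown at this call site a constant keeps the behavior easy to reason about.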