Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 114 additions & 12 deletions tensorstore/driver/zarr/driver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@
#include "absl/status/status.h"
#include "absl/strings/cord.h"
#include <nlohmann/json_fwd.hpp>
#include "riegeli/bytes/cord_reader.h"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

missing BUILD deps for these.

#include "riegeli/bytes/cord_writer.h"
#include "riegeli/bytes/read_all.h"
#include "riegeli/bytes/write.h"
#include "tensorstore/array.h"
#include "tensorstore/array_storage_statistics.h"
#include "tensorstore/box.h"
Expand All @@ -55,6 +59,7 @@
#include "tensorstore/internal/chunk_grid_specification.h"
#include "tensorstore/internal/grid_storage_statistics.h"
#include "tensorstore/internal/intrusive_ptr.h"
#include "tensorstore/internal/riegeli/array_endian_codec.h"
#include "tensorstore/internal/json_binding/bindable.h"
#include "tensorstore/internal/json_binding/json_binding.h"
#include "tensorstore/internal/uri_utils.h"
Expand Down Expand Up @@ -137,7 +142,8 @@ absl::Status ZarrDriverSpec::ApplyOptions(SpecOptions&& options) {
}

Result<SpecRankAndFieldInfo> ZarrDriverSpec::GetSpecInfo() const {
return GetSpecRankAndFieldInfo(partial_metadata, selected_field, schema);
return GetSpecRankAndFieldInfo(partial_metadata, selected_field, schema,
open_as_void);
}

TENSORSTORE_DEFINE_JSON_DEFAULT_BINDER(
Expand Down Expand Up @@ -171,7 +177,16 @@ TENSORSTORE_DEFINE_JSON_DEFAULT_BINDER(
jb::Member("field", jb::Projection<&ZarrDriverSpec::selected_field>(
jb::DefaultValue<jb::kNeverIncludeDefaults>(
[](auto* obj) { *obj = std::string{}; }))),
jb::Member("open_as_void",
jb::Projection<&ZarrDriverSpec::open_as_void>(
jb::DefaultValue<jb::kNeverIncludeDefaults>(
[](auto* v) { *v = false; }))),
jb::Initialize([](auto* obj) {
// Validate that field and open_as_void are mutually exclusive
if (obj->open_as_void && !obj->selected_field.empty()) {
return absl::InvalidArgumentError(
"\"field\" and \"open_as_void\" are mutually exclusive");
}
TENSORSTORE_ASSIGN_OR_RETURN(auto info, obj->GetSpecInfo());
if (info.full_rank != dynamic_rank) {
TENSORSTORE_RETURN_IF_ERROR(
Expand Down Expand Up @@ -209,9 +224,19 @@ Result<SharedArray<const void>> ZarrDriverSpec::GetFillValue(

const auto& metadata = partial_metadata;
if (metadata.dtype && metadata.fill_value) {
TENSORSTORE_ASSIGN_OR_RETURN(
size_t field_index, GetFieldIndex(*metadata.dtype, selected_field));
fill_value = (*metadata.fill_value)[field_index];
// For void access, synthesize a byte-level fill value
if (open_as_void) {
const Index nbytes = metadata.dtype->bytes_per_outer_element;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would expect this to get the fill_value from the metadata. More like:

size_t field_index = 0; // open_as_void has a single field.
if (!open_as_void) {
TENSORSTORE_ASSIGN_OR_RETURN(
field_index,
GetFieldIndex(*metadata.dtype, selected_field));
}
fill_value = (*metadata.fill_value)[field_index];

That might require CreateVoidMetadata to set a proper fill value?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes CreateVoidMetadata should set a proper fill value.

auto byte_arr = AllocateArray(
span<const Index, 1>({nbytes}), c_order, value_init,
dtype_v<tensorstore::dtypes::byte_t>);
fill_value = byte_arr;
} else {
TENSORSTORE_ASSIGN_OR_RETURN(
size_t field_index,
GetFieldIndex(*metadata.dtype, selected_field));
fill_value = (*metadata.fill_value)[field_index];
}
}

if (!fill_value.valid() || !transform.valid()) {
Expand Down Expand Up @@ -356,6 +381,7 @@ absl::Status DataCache::GetBoundSpecData(
const auto& metadata = *static_cast<const ZarrMetadata*>(metadata_ptr);
spec.selected_field = EncodeSelectedField(component_index, metadata.dtype);
spec.metadata_key = metadata_key_;
spec.open_as_void = false;
auto& pm = spec.partial_metadata;
pm.rank = metadata.rank;
pm.zarr_format = metadata.zarr_format;
Expand All @@ -382,6 +408,58 @@ Result<ChunkLayout> DataCache::GetChunkLayoutFromMetadata(
}

std::string DataCache::GetBaseKvstorePath() { return key_prefix_; }

// VoidDataCache implementation
// Uses inherited DataCache constructor and encode/decode methods.
// The void metadata (with dtype containing only the void field) is created
// in GetDataCache and passed via the initializer, so standard encode/decode
// paths work correctly.

absl::Status VoidDataCache::ValidateMetadataCompatibility(
const void* existing_metadata_ptr, const void* new_metadata_ptr) {
assert(existing_metadata_ptr);
assert(new_metadata_ptr);
const auto& existing_metadata =
*static_cast<const ZarrMetadata*>(existing_metadata_ptr);
const auto& new_metadata =
*static_cast<const ZarrMetadata*>(new_metadata_ptr);

// For void access, we only require that bytes_per_outer_element matches,
// since we're treating the data as raw bytes regardless of the actual dtype.
// Shape is allowed to differ (handled by base class for resizing).
// Other fields like compressor, order, chunks must still match.
if (existing_metadata.dtype.bytes_per_outer_element !=
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we could just rely on the normal validate but applied to the void metadata.

new_metadata.dtype.bytes_per_outer_element) {
return absl::FailedPreconditionError(tensorstore::StrCat(
"Void access metadata bytes_per_outer_element mismatch: existing=",
existing_metadata.dtype.bytes_per_outer_element,
", new=", new_metadata.dtype.bytes_per_outer_element));
}

// Check that other critical fields match (same as base, but ignoring dtype)
if (existing_metadata.chunks != new_metadata.chunks) {
return absl::FailedPreconditionError("Chunk shape mismatch");
}
if (existing_metadata.order != new_metadata.order) {
return absl::FailedPreconditionError("Order mismatch");
}
if (existing_metadata.compressor != new_metadata.compressor) {
return absl::FailedPreconditionError("Compressor mismatch");
}

return absl::OkStatus();
}

absl::Status VoidDataCache::GetBoundSpecData(
internal_kvs_backed_chunk_driver::KvsDriverSpec& spec_base,
const void* metadata_ptr, size_t component_index) {
TENSORSTORE_RETURN_IF_ERROR(
DataCache::GetBoundSpecData(spec_base, metadata_ptr, component_index));
auto& spec = static_cast<ZarrDriverSpec&>(spec_base);
spec.open_as_void = true;
return absl::OkStatus();
}

Result<CodecSpec> ZarrDriver::GetCodec() {
return internal_zarr::GetCodecSpecFromMetadata(metadata());
}
Expand Down Expand Up @@ -416,6 +494,10 @@ Result<std::string> ZarrDriverSpec::ToUrl() const {
return absl::InvalidArgumentError(
"zarr2 URL syntax not supported with selected_field specified");
}
if (open_as_void) {
return absl::InvalidArgumentError(
"zarr2 URL syntax not supported with open_as_void specified");
}
TENSORSTORE_ASSIGN_OR_RETURN(auto base_url, store.ToUrl());
return tensorstore::StrCat(base_url, "|", kUrlScheme, ":");
}
Expand Down Expand Up @@ -451,7 +533,7 @@ Future<ArrayStorageStatistics> ZarrDriver::GetStorageStatistics(
/*chunk_shape=*/grid.chunk_shape,
/*shape=*/metadata->shape,
/*dimension_separator=*/
GetDimensionSeparatorChar(cache->dimension_separator_),
GetDimensionSeparatorChar(cache->dimension_separator()),
staleness_bound, request.options));
}),
std::move(promise), std::move(metadata_future));
Expand Down Expand Up @@ -483,7 +565,8 @@ class ZarrDriver::OpenState : public ZarrDriver::OpenStateBase {
TENSORSTORE_ASSIGN_OR_RETURN(
auto metadata,
internal_zarr::GetNewMetadata(spec().partial_metadata,
spec().selected_field, spec().schema),
spec().selected_field, spec().schema,
spec().open_as_void),
tensorstore::MaybeAnnotateStatus(
_, "Cannot create using specified \"metadata\" and schema"));
return metadata;
Expand All @@ -496,17 +579,28 @@ class ZarrDriver::OpenState : public ZarrDriver::OpenStateBase {
internal::EncodeCacheKey(
&result, spec.store.path,
GetDimensionSeparator(spec.partial_metadata, zarr_metadata),
zarr_metadata, spec.metadata_key);
zarr_metadata, spec.metadata_key,
spec.open_as_void ? "void" : "normal");
return result;
}

std::unique_ptr<internal_kvs_backed_chunk_driver::DataCacheBase> GetDataCache(
DataCache::Initializer&& initializer) override {
const auto& metadata =
const auto& original_metadata =
*static_cast<const ZarrMetadata*>(initializer.metadata.get());
auto dim_sep = GetDimensionSeparator(spec().partial_metadata, original_metadata);
if (spec().open_as_void) {
// Create void metadata from the original. This modifies the dtype to
// contain only the void field, allowing standard encode/decode to work.
// CreateVoidMetadata uses the same chunks and bytes_per_outer_element as
// the original validated metadata, so it should never fail.
initializer.metadata = CreateVoidMetadata(original_metadata).value();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably would be better to have the original metadata cache the pointer to the void metadata and initialize it on first access.

return std::make_unique<VoidDataCache>(
std::move(initializer), spec().store.path, dim_sep,
spec().metadata_key);
}
return std::make_unique<DataCache>(
std::move(initializer), spec().store.path,
GetDimensionSeparator(spec().partial_metadata, metadata),
std::move(initializer), spec().store.path, dim_sep,
spec().metadata_key);
}

Expand All @@ -515,8 +609,16 @@ class ZarrDriver::OpenState : public ZarrDriver::OpenStateBase {
const auto& metadata = *static_cast<const ZarrMetadata*>(metadata_ptr);
TENSORSTORE_RETURN_IF_ERROR(
ValidateMetadata(metadata, spec().partial_metadata));
TENSORSTORE_ASSIGN_OR_RETURN(
auto field_index, GetFieldIndex(metadata.dtype, spec().selected_field));
// For void access, use component index 0 since we create a special
// component for raw byte access
size_t field_index;
if (spec().open_as_void) {
field_index = 0;
} else {
TENSORSTORE_ASSIGN_OR_RETURN(
field_index,
GetFieldIndex(metadata.dtype, spec().selected_field));
}
TENSORSTORE_RETURN_IF_ERROR(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we need to make sure to validate the schema against the void metadata, not the regular metadata. Note: The partial_metadata validation is still against the regular metadata, though.

ValidateMetadataSchema(metadata, field_index, spec().schema));
return field_index;
Expand Down
28 changes: 27 additions & 1 deletion tensorstore/driver/zarr/driver_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,11 @@ class ZarrDriverSpec
ZarrPartialMetadata partial_metadata;
SelectedField selected_field;
std::string metadata_key;
bool open_as_void = false;

constexpr static auto ApplyMembers = [](auto& x, auto f) {
return f(internal::BaseCast<KvsDriverSpec>(x), x.partial_metadata,
x.selected_field, x.metadata_key);
x.selected_field, x.metadata_key, x.open_as_void);
};
absl::Status ApplyOptions(SpecOptions&& options) override;

Expand Down Expand Up @@ -137,11 +138,36 @@ class DataCache : public internal_kvs_backed_chunk_driver::DataCache {

std::string GetBaseKvstorePath() override;

DimensionSeparator dimension_separator() const { return dimension_separator_; }

protected:
std::string key_prefix_;
DimensionSeparator dimension_separator_;
std::string metadata_key_;
};

/// Derived DataCache for open_as_void mode that provides raw byte access.
///
/// The void metadata (created via CreateVoidMetadata) has dtype.fields
/// containing only the void field, so inherited encode/decode methods
/// work correctly for raw byte access. GetBoundSpecData is overridden
/// to set open_as_void=true in the spec, and ValidateMetadataCompatibility
/// is overridden to allow different dtypes with the same bytes_per_outer_element.
class VoidDataCache : public DataCache {
public:
using DataCache::DataCache;

/// For void access, metadata is compatible if bytes_per_outer_element matches,
/// regardless of the actual dtype (since we treat everything as raw bytes).
absl::Status ValidateMetadataCompatibility(
const void* existing_metadata_ptr,
const void* new_metadata_ptr) override;

absl::Status GetBoundSpecData(
internal_kvs_backed_chunk_driver::KvsDriverSpec& spec_base,
const void* metadata_ptr, size_t component_index) override;
};

class ZarrDriver;
using ZarrDriverBase = internal_kvs_backed_chunk_driver::RegisteredKvsDriver<
ZarrDriver, ZarrDriverSpec, DataCache,
Expand Down
Loading