diff --git a/xarray_tensorstore.py b/xarray_tensorstore.py index a0390cb..629717e 100644 --- a/xarray_tensorstore.py +++ b/xarray_tensorstore.py @@ -24,6 +24,7 @@ import tensorstore import xarray from xarray.core import indexing +import zarr __version__ = '0.1.5' # keep in sync with setup.py @@ -176,12 +177,12 @@ def read(xarraydata: XarrayData, /) -> XarrayData: _DEFAULT_STORAGE_DRIVER = 'file' -def _zarr_spec_from_path(path: str) -> ...: +def _zarr_spec_from_path(path: str, zarr_format: int) -> ...: if re.match(r'\w+\://', path): # path is a URI kv_store = path else: kv_store = {'driver': _DEFAULT_STORAGE_DRIVER, 'path': path} - return {'driver': 'zarr', 'kvstore': kv_store} + return {'driver': f'zarr{zarr_format}', 'kvstore': kv_store} def _raise_if_mask_and_scale_used_for_data_vars(ds: xarray.Dataset): @@ -207,6 +208,14 @@ def _raise_if_mask_and_scale_used_for_data_vars(ds: xarray.Dataset): ) +def _get_zarr_format(path: str) -> int: + """Returns the Zarr format of the given path.""" + if zarr.__version__ >= '3.0.0': + return zarr.open_group(path, mode='r').metadata.zarr_format + else: + return 2 + + def open_zarr( path: str, *, @@ -271,7 +280,10 @@ def open_zarr( # incorrect data values. _raise_if_mask_and_scale_used_for_data_vars(ds) - specs = {k: _zarr_spec_from_path(os.path.join(path, k)) for k in ds} + zarr_format = _get_zarr_format(path) + specs = { + k: _zarr_spec_from_path(os.path.join(path, k), zarr_format) for k in ds + } array_futures = { k: tensorstore.open(spec, read=True, write=write, context=context) for k, spec in specs.items() diff --git a/xarray_tensorstore_test.py b/xarray_tensorstore_test.py index c40ab9b..93c48ce 100644 --- a/xarray_tensorstore_test.py +++ b/xarray_tensorstore_test.py @@ -20,6 +20,7 @@ import xarray from xarray.core import indexing import xarray_tensorstore +import zarr class XarrayTensorstoreTest(parameterized.TestCase): @@ -145,13 +146,19 @@ def test_open_zarr_from_uri(self): opened = xarray_tensorstore.open_zarr('file://' + path) xarray.testing.assert_identical(source, opened) - def test_read_dataset(self): + @parameterized.parameters( + {'zarr_format': 2}, + {'zarr_format': 3}, + ) + def test_read_dataset(self, zarr_format): + if zarr.__version__ < '3.0.0' and zarr_format == 3: + self.skipTest('zarr format 3 is not supported in zarr < 3.0.0') source = xarray.Dataset( {'baz': (('x', 'y', 'z'), np.arange(24).reshape(2, 3, 4))}, coords={'x': np.arange(2)}, ) path = self.create_tempdir().full_path - source.chunk().to_zarr(path) + source.chunk().to_zarr(path, zarr_format=zarr_format) opened = xarray_tensorstore.open_zarr(path) read = xarray_tensorstore.read(opened) @@ -160,7 +167,13 @@ def test_read_dataset(self): self.assertIsNotNone(read.variables['baz']._data.future) xarray.testing.assert_identical(read, source) - def test_read_dataarray(self): + @parameterized.parameters( + {'zarr_format': 2}, + {'zarr_format': 3}, + ) + def test_read_dataarray(self, zarr_format): + if zarr.__version__ < '3.0.0' and zarr_format == 3: + self.skipTest('zarr format 3 is not supported in zarr < 3.0.0') source = xarray.DataArray( np.arange(24).reshape(2, 3, 4), dims=('x', 'y', 'z'), @@ -168,7 +181,7 @@ def test_read_dataarray(self): coords={'x': np.arange(2)}, ) path = self.create_tempdir().full_path - source.to_dataset().chunk().to_zarr(path) + source.to_dataset().chunk().to_zarr(path, zarr_format=zarr_format) opened = xarray_tensorstore.open_zarr(path)['baz'] read = xarray_tensorstore.read(opened)