Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create better xarray from common MIKE Dfs0 output #650

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 77 additions & 6 deletions mikeio/dataset/_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,11 @@ class Dataset:

def __init__(
self,
data: Mapping[str, DataArray]
| Sequence[DataArray]
| Sequence[NDArray[np.floating]],
data: (
Mapping[str, DataArray]
| Sequence[DataArray]
| Sequence[NDArray[np.floating]]
),
time: pd.DatetimeIndex | None = None,
items: Sequence[ItemInfo] | None = None,
geometry: Any = None,
Expand Down Expand Up @@ -1883,11 +1885,80 @@ def _to_dfsu(self, filename: str | Path) -> None:

_write_dfsu(filename, self)

def to_xarray(self, add_location_dim: bool = False) -> "xarray.Dataset":
    """Export to xarray.Dataset

    Parameters
    ----------
    add_location_dim: bool, optional
        If True, interpret every item name as ``"<location>: <variable>"``
        and return one data variable per variable name with dimensions
        ``(location, time)``, by default False

    Returns
    -------
    xarray.Dataset

    Raises
    ------
    ValueError
        If ``add_location_dim`` is True and an item name does not split
        into exactly one location and one variable name on ``": "``.

    Notes
    -----
    With ``add_location_dim=True`` the locations and variables are ordered
    by first appearance among the items, so the result is reproducible.
    A location that lacks one of the variables gets a row of NaN for it.

    Examples
    --------
    >>> import mikeio
    >>> ds = mikeio.read("tests/testdata/waves.dfs2")
    >>> ds.to_xarray()
    <xarray.Dataset>
    Dimensions:              (time: 3, y: 31, x: 31)
    Coordinates:
      * time                 (time) datetime64[ns] 2004-01-01 2004-01-02 2004-01-03
      * y                    (y) float64 25.0 75.0 125.0 ... 1.475e+03 1.525e+03
      * x                    (x) float64 25.0 75.0 125.0 ... 1.475e+03 1.525e+03
    Data variables:
        Sign. Wave Height    (time, y, x) float32 nan nan nan nan ... nan nan nan
        Peak Wave Period     (time, y, x) float32 nan nan nan nan ... nan nan nan
        Mean Wave Direction  (time, y, x) float32 nan nan nan nan ... nan nan nan
    >>> ds = mikeio.read("tests/testdata/sw_points.dfs0")
    >>> ds.items[0]
    Buoy 2: Sign. Wave Height <Significant wave height> (meter)
    >>> xr_ds = ds.to_xarray(add_location_dim=True)
    >>> xr_ds.sizes["location"]
    4
    >>> xr_ds.sizes["time"]
    11
    """
    import xarray

    if not add_location_dim:
        return xarray.Dataset({da.name: da.to_xarray() for da in self})

    name_parts = [item.name.split(": ") for item in self.items]
    if not all(len(parts) == 2 for parts in name_parts):
        raise ValueError(
            "All items must have a location and variable name separated by ': '."
        )

    # dict.fromkeys de-duplicates while keeping first-seen order; a set would
    # make the order of locations and data variables vary between runs.
    locations = list(dict.fromkeys(loc for loc, _ in name_parts))
    variables = list(dict.fromkeys(var for _, var in name_parts))

    # Not every location necessarily carries every variable.
    available = {item.name for item in self.items}
    n_time = len(self.time)

    data = {}
    for var in variables:
        rows = []
        for loc in locations:
            item_name = f"{loc}: {var}"
            if item_name in available:
                rows.append(self[item_name].to_numpy())
            else:
                # Missing location/variable combination -> all-NaN series
                rows.append(np.full(n_time, np.nan))
        data[var] = xarray.DataArray(
            np.array(rows),
            dims=("location", "time"),
            coords=[locations, self.time],
        )
    return xarray.Dataset(data)

# ===============================================
Expand Down
15 changes: 15 additions & 0 deletions tests/test_dfs0.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,3 +584,18 @@ def test_non_equidistant_time_can_read_correctly_with_open(tmp_path):
ds = dfs.read()

assert all(dfs.time == ds.time)


def test_locations_in_item_name_to_xarray():
    """Item names of the form '<location>: <variable>' become a location dim."""

    ds = mikeio.read("tests/testdata/sw_points.dfs0")
    assert ds.n_timesteps == 11

    xr_ds = ds.to_xarray(add_location_dim=True)
    assert xr_ds.sizes["time"] == 11

    wave_height = xr_ds["Sign. Wave Height"]
    assert wave_height.sel(location="Buoy 2").isel(time=0).values == pytest.approx(
        2.09833
    )
    # This comparison previously lacked `assert`, so it could never fail.
    assert wave_height.sel(
        location="Buoy 2", time="2017-01-01 10:00:00"
    ).values == pytest.approx(1.68875253)
Loading