From 32485ae79c013009b8265cb34b371e02bdabf6de Mon Sep 17 00:00:00 2001 From: Henrik Andersson Date: Thu, 9 Jan 2025 18:08:03 +0100 Subject: [PATCH] Experiment with duckdb for persistence --- modelskill/comparison/_comparison.py | 104 ++++++++++++++++++--------- pyproject.toml | 1 + tests/test_comparer.py | 11 +++ 3 files changed, 81 insertions(+), 35 deletions(-) diff --git a/modelskill/comparison/_comparison.py b/modelskill/comparison/_comparison.py index 79e55d84..366cfcbd 100644 --- a/modelskill/comparison/_comparison.py +++ b/modelskill/comparison/_comparison.py @@ -1250,35 +1250,42 @@ def to_dataframe(self) -> pd.DataFrame: raise NotImplementedError(f"Unknown gtype: {self.gtype}") def save(self, filename: Union[str, Path]) -> None: - """Save to netcdf file + """Save to duckdb file Parameters ---------- filename : str or Path filename """ + ext = Path(filename).suffix ds = self.data + if ext == ".db": + import duckdb + + con = duckdb.connect(filename) + # TODO figure out how to save the x, y, z coordinates and other attributes later + df = ds.to_dataframe().drop(columns=["x", "y", "z"]).reset_index() # noqa + duckdb.sql("CREATE TABLE data AS SELECT * FROM df", connection=con) + con.close() + elif ext == ".nc": + if self.gtype == "point": + ds = self.data.copy() # copy needed to avoid modifying self.data - # add self.raw_mod_data to ds with prefix 'raw_' to avoid name conflicts - # an alternative strategy would be to use NetCDF groups - # https://docs.xarray.dev/en/stable/user-guide/io.html#groups - - # There is no need to save raw data for track data, since it is identical to the matched data - if self.gtype == "point": - ds = self.data.copy() # copy needed to avoid modifying self.data + for key, ts_mod in self.raw_mod_data.items(): + ts_mod = ts_mod.copy() + # rename time to unique name + ts_mod.data = ts_mod.data.rename({"time": "_time_raw_" + key}) + # da = ds_mod.to_xarray()[key] + ds["_raw_" + key] = ts_mod.data[key] - for key, ts_mod in self.raw_mod_data.items(): - ts_mod = ts_mod.copy() - # rename time to unique name - ts_mod.data = ts_mod.data.rename({"time": "_time_raw_" + key}) - # da = ds_mod.to_xarray()[key] - ds["_raw_" + key] = ts_mod.data[key] + ds.to_netcdf(filename) - ds.to_netcdf(filename) + else: + raise NotImplementedError(f"Unknown extension: {ext}") @staticmethod def load(filename: Union[str, Path]) -> "Comparer": - """Load from netcdf file + """Load from duckdb file Parameters ---------- @@ -1289,30 +1296,57 @@ def load(filename: Union[str, Path]) -> "Comparer": ------- Comparer """ - with xr.open_dataset(filename) as ds: - data = ds.load() - if data.gtype == "track": - return Comparer(matched_data=data) + # get extension + ext = Path(filename).suffix - if data.gtype == "point": - raw_mod_data: Dict[str, PointModelResult] = {} + if ext == ".db": + import duckdb - for var in data.data_vars: - var_name = str(var) - if var_name[:5] == "_raw_": - new_key = var_name[5:] # remove prefix '_raw_' - ds = data[[var_name]].rename( - {"_time_raw_" + new_key: "time", var_name: new_key} - ) - ts = PointModelResult(data=ds, name=new_key) + con = duckdb.connect(filename) + df = duckdb.sql("SELECT * FROM data", connection=con).df().set_index("time") + + # convert pandas dataframe to xarray dataset + ds = xr.Dataset.from_dataframe(df) + + # set observation attribute + ds.Observation.attrs["kind"] = "observation" - raw_mod_data[new_key] = ts + # set model attributes + for key in ds.data_vars: + if key != "Observation": + ds[key].attrs["kind"] = "model" - # filter variables, only keep the ones with a 'time' dimension - data = data[[v for v in data.data_vars if "time" in data[v].dims]] + # TODO figure out aux variables - return Comparer(matched_data=data, raw_mod_data=raw_mod_data) + return Comparer(matched_data=ds) + elif ext == ".nc": + with xr.open_dataset(filename) as ds: + data = ds.load() + if data.gtype == "track": + return Comparer(matched_data=data) + + if data.gtype == "point": + raw_mod_data: Dict[str, PointModelResult] = {} + + for var in data.data_vars: + var_name = str(var) + if var_name[:5] == "_raw_": + new_key = var_name[5:] # remove prefix '_raw_' + ds = data[[var_name]].rename( + {"_time_raw_" + new_key: "time", var_name: new_key} + ) + ts = PointModelResult(data=ds, name=new_key) + + raw_mod_data[new_key] = ts + + # filter variables, only keep the ones with a 'time' dimension + data = data[[v for v in data.data_vars if "time" in data[v].dims]] + + return Comparer(matched_data=data, raw_mod_data=raw_mod_data) + + else: + raise NotImplementedError(f"Unknown gtype: {data.gtype}") else: - raise NotImplementedError(f"Unknown gtype: {data.gtype}") + raise NotImplementedError(f"Unknown extension: {ext}") diff --git a/pyproject.toml b/pyproject.toml index 1f16fa2e..162bdcaf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,6 +22,7 @@ dependencies = [ "netCDF4", "scipy", "jinja2", + "duckdb>=1.1.3", ] authors = [ diff --git a/tests/test_comparer.py b/tests/test_comparer.py index 2d8691d0..81bfbd90 100644 --- a/tests/test_comparer.py +++ b/tests/test_comparer.py @@ -951,3 +951,14 @@ def test_from_matched_non_scalar_xy_fails(): x=df.lon, y=df.lat, ) + + +def test_save_load(pc, tmp_path) -> None: + fp = tmp_path / "test.db" + + pc.save(fp) + pc2 = Comparer.load(fp) + + assert "m1" in pc2.mod_names + assert "m2" in pc2.mod_names + assert pc2.n_points == 5