From 1a7d6acac049f758f0f8538f095e1bb54968ad48 Mon Sep 17 00:00:00 2001 From: Henrik Andersson Date: Thu, 9 Jan 2025 21:03:27 +0100 Subject: [PATCH] Extract methods --- modelskill/comparison/_comparison.py | 129 +++++++++++++++++---------- pyproject.toml | 2 +- tests/test_comparer.py | 8 +- 3 files changed, 90 insertions(+), 49 deletions(-) diff --git a/modelskill/comparison/_comparison.py b/modelskill/comparison/_comparison.py index 26be14ce..d0ab081b 100644 --- a/modelskill/comparison/_comparison.py +++ b/modelskill/comparison/_comparison.py @@ -1258,38 +1258,56 @@ def save(self, filename: Union[str, Path]) -> None: filename """ ext = Path(filename).suffix + + match ext: + case ".db": + self._save_to_duckdb(filename) + case ".nc": + self._save_to_netcdf(filename) + case _: + raise NotImplementedError(f"Unknown extension: {ext}") + + def _save_to_netcdf(self, filename) -> None: ds = self.data - if ext == ".db": - import duckdb + if self.gtype == "point": + ds = self.data.copy() # copy needed to avoid modifying self.data - con = duckdb.connect(filename) - # TODO figure out how to save the x, y, z coordinates and other attributes later - df = ds.to_dataframe().drop(columns=["x", "y", "z"]).reset_index() # noqa - duckdb.sql("CREATE TABLE matched_data AS SELECT * FROM df", connection=con) + for key, ts_mod in self.raw_mod_data.items(): + ts_mod = ts_mod.copy() + # rename time to unique name + ts_mod.data = ts_mod.data.rename({"time": "_time_raw_" + key}) + # da = ds_mod.to_xarray()[key] + ds["_raw_" + key] = ts_mod.data[key] - attr_dict = {key: str(ds[key].attrs) for key in ds.data_vars} - attr_df = pd.DataFrame(attr_dict.items(), columns=["key", "value"]) # noqa + ds.to_netcdf(filename) - # attr_df["global", "key"] = str(ds.attrs) + def _save_to_duckdb( + self, + filename: Union[str, Path], + ) -> None: + import duckdb - duckdb.sql("CREATE TABLE attrs AS SELECT * FROM attr_df", connection=con) + ds = self.data - con.close() - elif ext == ".nc": - if self.gtype == "point": - ds = self.data.copy() # copy needed to avoid modifying self.data + con = duckdb.connect(filename) + # TODO figure out how to save the x, y, z coordinates + df = ds.to_dataframe().drop(columns=["x", "y", "z"]).reset_index() # noqa - for key, ts_mod in self.raw_mod_data.items(): - ts_mod = ts_mod.copy() - # rename time to unique name - ts_mod.data = ts_mod.data.rename({"time": "_time_raw_" + key}) - # da = ds_mod.to_xarray()[key] - ds["_raw_" + key] = ts_mod.data[key] + # TODO use self.name as prefix in table name to allow for multiple Comparers in the same db + duckdb.sql("CREATE TABLE matched_data AS SELECT * FROM df", connection=con) - ds.to_netcdf(filename) + attr_dict = {key: str(ds[key].attrs) for key in ds.data_vars} + attr_dict["global"] = str(ds.attrs) + attr_df = pd.DataFrame(attr_dict.items(), columns=["key", "value"]) # noqa + duckdb.sql("CREATE TABLE attrs AS SELECT * FROM attr_df", connection=con) - else: - raise NotImplementedError(f"Unknown extension: {ext}") + for key, value in self.raw_mod_data.items(): + rdf = ( # noqa + value.data.to_dataframe().reset_index().drop(columns=["x", "y", "z"]) + ) + duckdb.sql(f"CREATE TABLE raw_{key} AS SELECT * FROM rdf", connection=con) + + con.close() @staticmethod def load(filename: Union[str, Path]) -> "Comparer": @@ -1308,31 +1326,55 @@ def load(filename: Union[str, Path]) -> "Comparer": # get extension ext = Path(filename).suffix - if ext == ".db": - import duckdb + match ext: + case ".db": + return Comparer._load_from_duckdb(filename) + case ".nc": + return Comparer._load_from_netcdf(filename) + case _: + raise NotImplementedError(f"Unknown extension: {ext}") - con = duckdb.connect(filename) - df = ( - duckdb.sql("SELECT * FROM matched_data", connection=con) - .df() - .set_index("time") - ) + @staticmethod + def _load_from_duckdb(filename) -> "Comparer": + import duckdb + + con = duckdb.connect(filename) + df = ( + duckdb.sql("SELECT * FROM matched_data", connection=con) + .df() + .set_index("time") + ) - # convert pandas dataframe to xarray dataset - ds = xr.Dataset.from_dataframe(df) + # convert pandas dataframe to xarray dataset + ds = xr.Dataset.from_dataframe(df) - attrs = duckdb.sql("SELECT * FROM attrs", connection=con).df() - for row in attrs.iterrows(): - key = row[1]["key"] - value = row[1]["value"] + attrs = duckdb.sql("SELECT * FROM attrs", connection=con).df() + for row in attrs.iterrows(): + key = row[1]["key"] + value = row[1]["value"] + if key == "global": + ds.attrs = eval(value) + else: ds[key].attrs = eval(value) - # TODO figure out aux variables + raw_mod_data = {} + table_names = con.sql("SHOW TABLES").df()["name"].to_list() + raw_tables = [t for t in table_names if t[:4] == "raw_"] + for table in raw_tables: + key = table[4:] + rdf = ( + duckdb.sql(f"SELECT * FROM {table}", connection=con) + .df() + .set_index("time") + ) + raw_mod_data[key] = PointModelResult(data=rdf) + + return Comparer(matched_data=ds, raw_mod_data=raw_mod_data) - return Comparer(matched_data=ds) - elif ext == ".nc": - with xr.open_dataset(filename) as ds: - data = ds.load() + @staticmethod + def _load_from_netcdf(filename) -> "Comparer": + with xr.open_dataset(filename) as ds: + data = ds.load() if data.gtype == "track": return Comparer(matched_data=data) @@ -1355,8 +1397,5 @@ def load(filename: Union[str, Path]) -> "Comparer": data = data[[v for v in data.data_vars if "time" in data[v].dims]] return Comparer(matched_data=data, raw_mod_data=raw_mod_data) - else: raise NotImplementedError(f"Unknown gtype: {data.gtype}") - else: - raise NotImplementedError(f"Unknown extension: {ext}") diff --git a/pyproject.toml b/pyproject.toml index 162bdcaf..0d6fde4e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -79,7 +79,7 @@ ignore = ["E501"] select = ["E4", "E7", "E9", "F", "D200", "D205"] [tool.mypy] -python_version = "3.9" +python_version = "3.10" ignore_missing_imports = true warn_unreachable = false no_implicit_optional = true diff --git a/tests/test_comparer.py b/tests/test_comparer.py index 7f5b15ab..714ac7ce 100644 --- a/tests/test_comparer.py +++ b/tests/test_comparer.py @@ -6,6 +6,7 @@ from modelskill.comparison import Comparer from modelskill import __version__ import modelskill as ms +from modelskill.model.point import PointModelResult @pytest.fixture @@ -966,6 +967,7 @@ def test_save_load(pc, tmp_path) -> None: assert pc2.data.m2.attrs["kind"] == "model" assert pc2.data.Observation.attrs["kind"] == "observation" - # TODO global attrs - - # TODO raw_mod_data + assert pc2.name == "fake point obs" + assert pc2.gtype == "point" + assert len(pc2.raw_mod_data["m1"]) == 6 + assert isinstance(pc2.raw_mod_data["m2"], PointModelResult)