Extract methods
ecomodeller committed Jan 9, 2025
1 parent e6ea96f commit 1a7d6ac
Showing 3 changed files with 90 additions and 49 deletions.
129 changes: 84 additions & 45 deletions modelskill/comparison/_comparison.py
@@ -1258,38 +1258,56 @@ def save(self, filename: Union[str, Path]) -> None:
         filename
         """
         ext = Path(filename).suffix
-        ds = self.data
-        if ext == ".db":
-            import duckdb
-
-            con = duckdb.connect(filename)
-            # TODO figure out how to save the x, y, z coordinates and other attributes later
-            df = ds.to_dataframe().drop(columns=["x", "y", "z"]).reset_index()  # noqa
-            duckdb.sql("CREATE TABLE matched_data AS SELECT * FROM df", connection=con)
-
-            attr_dict = {key: str(ds[key].attrs) for key in ds.data_vars}
-            attr_df = pd.DataFrame(attr_dict.items(), columns=["key", "value"])  # noqa
-
-            # attr_df["global", "key"] = str(ds.attrs)
-
-            duckdb.sql("CREATE TABLE attrs AS SELECT * FROM attr_df", connection=con)
-
-            con.close()
-        elif ext == ".nc":
-            if self.gtype == "point":
-                ds = self.data.copy()  # copy needed to avoid modifying self.data
-
-                for key, ts_mod in self.raw_mod_data.items():
-                    ts_mod = ts_mod.copy()
-                    # rename time to unique name
-                    ts_mod.data = ts_mod.data.rename({"time": "_time_raw_" + key})
-                    # da = ds_mod.to_xarray()[key]
-                    ds["_raw_" + key] = ts_mod.data[key]
-
-            ds.to_netcdf(filename)
-
-        else:
-            raise NotImplementedError(f"Unknown extension: {ext}")
+
+        match ext:
+            case ".db":
+                self._save_to_duckdb(filename)
+            case ".nc":
+                self._save_to_netcdf(filename)
+            case _:
+                raise NotImplementedError(f"Unknown extension: {ext}")
+
+    def _save_to_netcdf(self, filename) -> None:
+        ds = self.data
+        if self.gtype == "point":
+            ds = self.data.copy()  # copy needed to avoid modifying self.data
+
+            for key, ts_mod in self.raw_mod_data.items():
+                ts_mod = ts_mod.copy()
+                # rename time to unique name
+                ts_mod.data = ts_mod.data.rename({"time": "_time_raw_" + key})
+                # da = ds_mod.to_xarray()[key]
+                ds["_raw_" + key] = ts_mod.data[key]
+
+        ds.to_netcdf(filename)
+
+    def _save_to_duckdb(
+        self,
+        filename: Union[str, Path],
+    ) -> None:
+        import duckdb
+
+        ds = self.data
+
+        con = duckdb.connect(filename)
+        # TODO figure out how to save the x, y, z coordinates
+        df = ds.to_dataframe().drop(columns=["x", "y", "z"]).reset_index()  # noqa
+
+        # TODO use self.name as prefix in table name to allow for multiple Comparers in the same db
+        duckdb.sql("CREATE TABLE matched_data AS SELECT * FROM df", connection=con)
+
+        attr_dict = {key: str(ds[key].attrs) for key in ds.data_vars}
+        attr_dict["global"] = str(ds.attrs)
+        attr_df = pd.DataFrame(attr_dict.items(), columns=["key", "value"])  # noqa
+        duckdb.sql("CREATE TABLE attrs AS SELECT * FROM attr_df", connection=con)
+
+        for key, value in self.raw_mod_data.items():
+            rdf = (  # noqa
+                value.data.to_dataframe().reset_index().drop(columns=["x", "y", "z"])
+            )
+            duckdb.sql(f"CREATE TABLE raw_{key} AS SELECT * FROM rdf", connection=con)
+
+        con.close()

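For reference, the extracted _save_to_duckdb writes one matched_data table, one attrs table of stringified xarray attributes, and one raw_<model> table per entry in raw_mod_data. A small sketch of inspecting such a file with the duckdb Python API (the file name and its contents are illustrative, not part of this commit):

    import duckdb

    # open a file previously written by Comparer.save("comparison.db") -- illustrative name
    con = duckdb.connect("comparison.db")

    # tables created by _save_to_duckdb: matched_data, attrs, raw_<model>
    print(con.sql("SHOW TABLES").df()["name"].to_list())

    # the matched data comes back as an ordinary dataframe indexed by time
    df = con.sql("SELECT * FROM matched_data").df().set_index("time")
    print(df.head())

    con.close()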
@@ -1308,31 +1326,55 @@ def load(filename: Union[str, Path]) -> "Comparer":
         # get extension
         ext = Path(filename).suffix
 
-        if ext == ".db":
-            import duckdb
-
-            con = duckdb.connect(filename)
-            df = (
-                duckdb.sql("SELECT * FROM matched_data", connection=con)
-                .df()
-                .set_index("time")
-            )
-
-            # convert pandas dataframe to xarray dataset
-            ds = xr.Dataset.from_dataframe(df)
-
-            attrs = duckdb.sql("SELECT * FROM attrs", connection=con).df()
-            for row in attrs.iterrows():
-                key = row[1]["key"]
-                value = row[1]["value"]
-                ds[key].attrs = eval(value)
-
-            return Comparer(matched_data=ds)
-        elif ext == ".nc":
-            with xr.open_dataset(filename) as ds:
-                data = ds.load()
-
-            if data.gtype == "track":
-                return Comparer(matched_data=data)
+        match ext:
+            case ".db":
+                return Comparer._load_from_duckdb(filename)
+            case ".nc":
+                return Comparer._load_from_netcdf(filename)
+            case _:
+                raise NotImplementedError(f"Unknown extension: {ext}")
+
+    @staticmethod
+    def _load_from_duckdb(filename) -> "Comparer":
+        import duckdb
+
+        con = duckdb.connect(filename)
+        df = (
+            duckdb.sql("SELECT * FROM matched_data", connection=con)
+            .df()
+            .set_index("time")
+        )
+
+        # convert pandas dataframe to xarray dataset
+        ds = xr.Dataset.from_dataframe(df)
+
+        attrs = duckdb.sql("SELECT * FROM attrs", connection=con).df()
+        for row in attrs.iterrows():
+            key = row[1]["key"]
+            value = row[1]["value"]
+            if key == "global":
+                ds.attrs = eval(value)
+            else:
+                ds[key].attrs = eval(value)
+
+        # TODO figure out aux variables
+        raw_mod_data = {}
+        table_names = con.sql("SHOW TABLES").df()["name"].to_list()
+        raw_tables = [t for t in table_names if t[:4] == "raw_"]
+        for table in raw_tables:
+            key = table[4:]
+            rdf = (
+                duckdb.sql(f"SELECT * FROM {table}", connection=con)
+                .df()
+                .set_index("time")
+            )
+            raw_mod_data[key] = PointModelResult(data=rdf)
+
+        return Comparer(matched_data=ds, raw_mod_data=raw_mod_data)
+
+    @staticmethod
+    def _load_from_netcdf(filename) -> "Comparer":
+        with xr.open_dataset(filename) as ds:
+            data = ds.load()
+
+        if data.gtype == "track":
+            return Comparer(matched_data=data)
@@ -1355,8 +1397,5 @@ def load(filename: Union[str, Path]) -> "Comparer":
             data = data[[v for v in data.data_vars if "time" in data[v].dims]]
 
             return Comparer(matched_data=data, raw_mod_data=raw_mod_data)
-
-            else:
-                raise NotImplementedError(f"Unknown gtype: {data.gtype}")
-        else:
-            raise NotImplementedError(f"Unknown extension: {ext}")
+        else:
+            raise NotImplementedError(f"Unknown gtype: {data.gtype}")
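With both formats behind a common dispatcher, a save/load round trip is driven purely by the file extension. A minimal usage sketch, assuming a Comparer produced by the usual modelskill matching workflow (file names, observation, and model inputs are illustrative, not part of this commit):

    import modelskill as ms
    from modelskill.comparison import Comparer

    # a Comparer from the standard matching workflow (illustrative inputs)
    cmp = ms.match(
        obs=ms.PointObservation("obs.dfs0", item=0, x=4.0, y=54.0, name="Observation"),
        mod=ms.model_result("model.dfs0", item=0, name="m1"),
    )

    cmp.save("comparison.nc")  # dispatches to _save_to_netcdf
    cmp.save("comparison.db")  # dispatches to _save_to_duckdb

    cmp_nc = Comparer.load("comparison.nc")  # -> _load_from_netcdf
    cmp_db = Comparer.load("comparison.db")  # -> _load_from_duckdb, rebuilds raw_mod_data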
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -79,7 +79,7 @@ ignore = ["E501"]
 select = ["E4", "E7", "E9", "F", "D200", "D205"]
 
 [tool.mypy]
-python_version = "3.9"
+python_version = "3.10"
 ignore_missing_imports = true
 warn_unreachable = false
 no_implicit_optional = true
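The mypy target is raised because the refactor introduces structural pattern matching (match/case), which requires Python 3.10 and is flagged by mypy when python_version is still "3.9". A stand-alone illustration of the dispatch idiom (the function name is made up for this example):

    from pathlib import Path

    def dispatch(filename: str) -> str:
        # match/case parses only on Python >= 3.10, hence the bumped python_version
        match Path(filename).suffix:
            case ".db":
                return "duckdb"
            case ".nc":
                return "netcdf"
            case _:
                raise NotImplementedError(f"Unknown extension: {Path(filename).suffix}")

    print(dispatch("comparison.db"))  # -> duckdb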
8 changes: 5 additions & 3 deletions tests/test_comparer.py
@@ -6,6 +6,7 @@
 from modelskill.comparison import Comparer
 from modelskill import __version__
 import modelskill as ms
+from modelskill.model.point import PointModelResult
 
 
 @pytest.fixture
@@ -966,6 +967,7 @@ def test_save_load(pc, tmp_path) -> None:
     assert pc2.data.m2.attrs["kind"] == "model"
     assert pc2.data.Observation.attrs["kind"] == "observation"
 
-    # TODO global attrs
-
-    # TODO raw_mod_data
+    assert pc2.name == "fake point obs"
+    assert pc2.gtype == "point"
+    assert len(pc2.raw_mod_data["m1"]) == 6
+    assert isinstance(pc2.raw_mod_data["m2"], PointModelResult)

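A hedged sketch of how the DuckDB path could additionally be covered by a dedicated round-trip test, reusing the existing pc and tmp_path fixtures (the test name and file name are new, not part of this commit):

    def test_save_load_duckdb_roundtrip(pc, tmp_path) -> None:
        # save to DuckDB and read it back through the new extracted helpers
        fn = str(tmp_path / "comparison.db")
        pc.save(fn)
        pc2 = Comparer.load(fn)

        # raw model data should be rebuilt from the raw_<model> tables
        assert pc2.gtype == "point"
        assert isinstance(pc2.raw_mod_data["m1"], PointModelResult)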