Skip to content

Commit

Permalink
Experiment with duckdb for persistence
Browse files Browse the repository at this point in the history
  • Loading branch information
ecomodeller committed Jan 9, 2025
1 parent 5e11449 commit 32485ae
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 35 deletions.
104 changes: 69 additions & 35 deletions modelskill/comparison/_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -1250,35 +1250,42 @@ def to_dataframe(self) -> pd.DataFrame:
raise NotImplementedError(f"Unknown gtype: {self.gtype}")

def save(self, filename: Union[str, Path]) -> None:
"""Save to netcdf file
"""Save to duckdb file
Parameters
----------
filename : str or Path
filename
"""
ext = Path(filename).suffix
ds = self.data
if ext == ".db":
import duckdb

con = duckdb.connect(filename)
# TODO figure out how to save the x, y, z coordinates and other attributes later
df = ds.to_dataframe().drop(columns=["x", "y", "z"]).reset_index() # noqa
duckdb.sql("CREATE TABLE data AS SELECT * FROM df", connection=con)
con.close()
elif ext == ".nc":
if self.gtype == "point":
ds = self.data.copy() # copy needed to avoid modifying self.data

# add self.raw_mod_data to ds with prefix 'raw_' to avoid name conflicts
# an alternative strategy would be to use NetCDF groups
# https://docs.xarray.dev/en/stable/user-guide/io.html#groups

# There is no need to save raw data for track data, since it is identical to the matched data
if self.gtype == "point":
ds = self.data.copy() # copy needed to avoid modifying self.data
for key, ts_mod in self.raw_mod_data.items():
ts_mod = ts_mod.copy()
# rename time to unique name
ts_mod.data = ts_mod.data.rename({"time": "_time_raw_" + key})
# da = ds_mod.to_xarray()[key]
ds["_raw_" + key] = ts_mod.data[key]

for key, ts_mod in self.raw_mod_data.items():
ts_mod = ts_mod.copy()
# rename time to unique name
ts_mod.data = ts_mod.data.rename({"time": "_time_raw_" + key})
# da = ds_mod.to_xarray()[key]
ds["_raw_" + key] = ts_mod.data[key]
ds.to_netcdf(filename)

ds.to_netcdf(filename)
else:
raise NotImplementedError(f"Unknown extension: {ext}")

@staticmethod
def load(filename: Union[str, Path]) -> "Comparer":
"""Load from netcdf file
"""Load from duckdb file
Parameters
----------
Expand All @@ -1289,30 +1296,57 @@ def load(filename: Union[str, Path]) -> "Comparer":
-------
Comparer
"""
with xr.open_dataset(filename) as ds:
data = ds.load()

if data.gtype == "track":
return Comparer(matched_data=data)
# get extension
ext = Path(filename).suffix

if data.gtype == "point":
raw_mod_data: Dict[str, PointModelResult] = {}
if ext == ".db":
import duckdb

for var in data.data_vars:
var_name = str(var)
if var_name[:5] == "_raw_":
new_key = var_name[5:] # remove prefix '_raw_'
ds = data[[var_name]].rename(
{"_time_raw_" + new_key: "time", var_name: new_key}
)
ts = PointModelResult(data=ds, name=new_key)
con = duckdb.connect(filename)
df = duckdb.sql("SELECT * FROM data", connection=con).df().set_index("time")

# convert pandas dataframe to xarray dataset
ds = xr.Dataset.from_dataframe(df)

# set observation attribute
ds.Observation.attrs["kind"] = "observation"

raw_mod_data[new_key] = ts
# set model attributes
for key in ds.data_vars:
if key != "Observation":
ds[key].attrs["kind"] = "model"

# filter variables, only keep the ones with a 'time' dimension
data = data[[v for v in data.data_vars if "time" in data[v].dims]]
# TODO figure out aux variables

return Comparer(matched_data=data, raw_mod_data=raw_mod_data)
return Comparer(matched_data=ds)
elif ext == ".nc":
with xr.open_dataset(filename) as ds:
data = ds.load()

if data.gtype == "track":
return Comparer(matched_data=data)

if data.gtype == "point":
raw_mod_data: Dict[str, PointModelResult] = {}

for var in data.data_vars:
var_name = str(var)
if var_name[:5] == "_raw_":
new_key = var_name[5:] # remove prefix '_raw_'
ds = data[[var_name]].rename(
{"_time_raw_" + new_key: "time", var_name: new_key}
)
ts = PointModelResult(data=ds, name=new_key)

raw_mod_data[new_key] = ts

# filter variables, only keep the ones with a 'time' dimension
data = data[[v for v in data.data_vars if "time" in data[v].dims]]

return Comparer(matched_data=data, raw_mod_data=raw_mod_data)

else:
raise NotImplementedError(f"Unknown gtype: {data.gtype}")
else:
raise NotImplementedError(f"Unknown gtype: {data.gtype}")
raise NotImplementedError(f"Unknown extension: {ext}")
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ dependencies = [
"netCDF4",
"scipy",
"jinja2",
"duckdb>=1.1.3",
]

authors = [
Expand Down
11 changes: 11 additions & 0 deletions tests/test_comparer.py
Original file line number Diff line number Diff line change
Expand Up @@ -951,3 +951,14 @@ def test_from_matched_non_scalar_xy_fails():
x=df.lon,
y=df.lat,
)


def test_save_load(pc, tmp_path) -> None:
fp = tmp_path / "test.db"

pc.save(fp)
pc2 = Comparer.load(fp)

assert "m1" in pc2.mod_names
assert "m2" in pc2.mod_names
assert pc2.n_points == 5

0 comments on commit 32485ae

Please sign in to comment.