Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add temporal bounds and center times for group_average() API #717

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 33 additions & 6 deletions xcdat/temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from xcdat import bounds # noqa: F401
from xcdat._logger import _setup_custom_logger
from xcdat.axis import get_dim_coords
from xcdat.axis import center_times, get_dim_coords
from xcdat.dataset import _get_data_var

logger = _setup_custom_logger(__name__)
Expand Down Expand Up @@ -876,7 +876,7 @@ def _averager(
if self._mode == "average":
dv_avg = self._average(ds, data_var)
elif self._mode in ["group_average", "climatology", "departures"]:
dv_avg = self._group_average(ds, data_var)
dv_avg, time_bnds = self._group_average(ds, data_var)

# The original time dimension is dropped from the dataset because
# it becomes obsolete after the data variable is averaged. When the
Expand All @@ -885,6 +885,12 @@ def _averager(
ds = ds.drop_dims(self.dim)
ds[dv_avg.name] = dv_avg

if self._mode in ["group_average", "climatology", "departures"]:
ds[time_bnds.name] = time_bnds
# FIXME: This does not work when the time bounds are datetime objects
# and the time coordinates are cftime objects.
ds = center_times(ds)

if keep_weights:
ds = self._keep_weights(ds)

Expand Down Expand Up @@ -1475,7 +1481,9 @@ def _average(self, ds: xr.Dataset, data_var: str) -> xr.DataArray:

return dv

def _group_average(self, ds: xr.Dataset, data_var: str) -> xr.DataArray:
def _group_average(
self, ds: xr.Dataset, data_var: str
) -> Tuple[xr.DataArray, xr.DataArray]:
"""Averages a data variable by time group.

Parameters
Expand All @@ -1487,7 +1495,7 @@ def _group_average(self, ds: xr.Dataset, data_var: str) -> xr.DataArray:

Returns
-------
xr.DataArray
Tuple[xr.DataArray, xr.DataArray]
The data variable averaged by time group, and the time bounds for each time group.
"""
dv = _get_data_var(ds, data_var)
Expand All @@ -1496,9 +1504,9 @@ def _group_average(self, ds: xr.Dataset, data_var: str) -> xr.DataArray:
# values.
self._labeled_time = self._label_time_coords(dv[self.dim])
dv = dv.assign_coords({self.dim: self._labeled_time})
time_bounds = ds.bounds.get_bounds("T", var_key=data_var)

if self._weighted:
time_bounds = ds.bounds.get_bounds("T", var_key=data_var)
self._weights = self._get_weights(time_bounds)

# Weight the data variable.
Expand All @@ -1522,6 +1530,25 @@ def _group_average(self, ds: xr.Dataset, data_var: str) -> xr.DataArray:
else:
dv = self._group_data(dv).mean()

"""I think we'll need to collect the bounds for each group (e.g., group_bounds_array = [("2000-01-01 00:00", "2000-01-02 00:00"), ("2000-01-02 00:00", "2000-01-03 00:00"), ..., ("2000-01-31 00:00", "2000-02-01 00:00")]) and then take the min of the lower bounds and the max of the upper bounds (i.e., group_bnd = [np.min(group_bounds_array[:, 0]), np.max(group_bounds_array[:, 1])]).
"""
# Create time bounds for each group
time_bounds_grouped = self._group_data(time_bounds)
group_bounds = []

for _, group_data in time_bounds_grouped:
group_times = group_data.values
group_bnds = (np.min(group_times[:, 0]), np.max(group_times[:, 1]))
group_bounds.append(group_bnds)

# Convert group bounds to DataArray
da_bnds = xr.DataArray(
data=np.array(group_bounds),
dims=[self.dim, "bnds"],
coords={self.dim: dv[self.dim].values},
name=f"{self.dim}_bnds",
)

# After grouping and aggregating, the grouped time dimension's
# attributes are removed. Xarray's `keep_attrs=True` option only keeps
# attributes for data variables and not their coordinates, so the
Expand All @@ -1531,7 +1558,7 @@ def _group_average(self, ds: xr.Dataset, data_var: str) -> xr.DataArray:

dv = self._add_operation_attrs(dv)

return dv
return dv, da_bnds

def _get_weights(self, time_bounds: xr.DataArray) -> xr.DataArray:
"""Calculates weights for a data variable using time bounds.
Expand Down
Loading