diff --git a/dabench/dacycler/_dacycler.py b/dabench/dacycler/_dacycler.py index 9aa4047..9f24a91 100644 --- a/dabench/dacycler/_dacycler.py +++ b/dabench/dacycler/_dacycler.py @@ -53,52 +53,78 @@ def cycle(self, input_state, start_time, obs_vector, - timesteps, + n_cycles, analysis_window, - analysis_time_in_window=None): + analysis_time_in_window=None, + return_forecast=False): """Perform DA cycle repeatedly, including analysis and forecast Args: input_state (vector.StateVector): Input state. start_time (float or datetime-like): Starting time. obs_vector (vector.ObsVector): Observations vector. - timesteps (int): Number of timesteps, in model time. + n_cycles (int): Number of analysis cycles to run, each of length + analysis_window. analysis_window (float): Time window from which to gather observations for DA Cycle. analysis_time_in_window (float): Where within analysis_window to perform analysis. For example, 0.0 is the start of the window. Default is None, which selects the middle of the window. + return_forecast (bool): If True, returns forecast at each model + timestep. If False, returns only analyses, one per analysis + cycle. Default is False. Returns: vector.StateVector of analyses and times. """ + # If don't specify analysis_time_in_window, is assumed to be middle if analysis_time_in_window is None: analysis_time_in_window = analysis_window/2 + # Time offset from middle of time window, for gathering observations + _time_offset = (analysis_window/2) - analysis_time_in_window + + # Number of model steps to run per window + steps_per_window = round(analysis_window/self.delta_t) + 1 + # For storing outputs - all_analyses = [] + all_output_states = [] all_times = [] - cur_time = start_time + analysis_time_in_window + cur_time = start_time cur_state = input_state - for i in range(timesteps): - # 1. Filter observations to plus/minus 0.1 from that time + for i in range(n_cycles): + # 1. 
Filter observations to inside analysis window + window_middle = cur_time + _time_offset + window_start = window_middle - analysis_window/2 + window_end = window_middle + analysis_window/2 obs_vec_timefilt = obs_vector.filter_times( - cur_time - analysis_window/2, cur_time + analysis_window/2) + window_start, window_end + ) if obs_vec_timefilt.values.shape[0] > 0: # 2. Calculate analysis analysis, kh = self._step_cycle(cur_state, obs_vec_timefilt) - # 3. Forecast next timestep - cur_state = self._step_forecast(analysis) + # 3. Forecast through analysis window + forecast_states = self._step_forecast(analysis, + n_steps=steps_per_window) # 4. Save outputs - all_analyses.append(analysis.values) - all_times.append(cur_time) - - cur_time += self.delta_t - - return vector.StateVector(values=np.stack(all_analyses), - times=np.array(all_times)) + if return_forecast: + # Append forecast to current state, excluding last step + all_output_states.append(forecast_states.values[:-1]) + all_times.append( + np.arange(steps_per_window-1)*self.delta_t + cur_time + ) + else: + all_output_states.append(analysis.values[np.newaxis]) + all_times.append([cur_time]) + + # Starting point for next cycle is last step of forecast + cur_state = forecast_states[-1] + cur_time += analysis_window + + return vector.StateVector(values=np.concatenate(all_output_states), + times=np.concatenate(all_times)) diff --git a/dabench/dacycler/_etkf.py b/dabench/dacycler/_etkf.py index fa712a8..ed509e6 100644 --- a/dabench/dacycler/_etkf.py +++ b/dabench/dacycler/_etkf.py @@ -7,6 +7,7 @@ from jax.scipy import linalg from dabench import dacycler, vector +import dabench.dacycler._utils as dac_utils class ETKF(dacycler.DACycler): @@ -60,18 +61,20 @@ def __init__(self, ensemble=True, B=B, R=R, H=H, h=h) - def _step_cycle(self, xb, yo, H=None, h=None, R=None, B=None): + def _step_cycle(self, xb, yo, obs_time_mask, obs_loc_mask, + H=None, h=None, R=None, B=None): if H is not None or h is None: vals, kh = 
self._cycle_obsop( - xb.values, yo.values, yo.location_indices, yo.error_sd, - H, R, B) + xb.values, yo.values, yo.location_indices, yo.error_sd, obs_time_mask, + obs_loc_mask, H, R, B) return vector.StateVector(values=vals, store_as_jax=True), kh else: return self._cycle_general_obsop(xb, yo, h, R, B) def _calc_default_H(self, obs_values, obs_loc_indices): H = jnp.zeros((obs_values.flatten().shape[0], self.system_dim)) - H = H.at[jnp.arange(H.shape[0]), obs_loc_indices.flatten() + H = H.at[jnp.arange(H.shape[0]), + obs_loc_indices.flatten() ].set(1) return H @@ -82,6 +85,7 @@ def _calc_default_B(self): return jnp.identity(self.system_dim) def _cycle_obsop(self, Xbt, obs_values, obs_loc_indices, obs_error_sd, + obs_time_mask, obs_loc_mask, H=None, h=None, R=None, B=None): if H is None and h is None: if self.H is None: @@ -107,6 +111,10 @@ def _cycle_obsop(self, Xbt, obs_values, obs_loc_indices, obs_error_sd, 'cycle:: model_forecast must have dimension {}x{}').format( self.ensemble_dim, self.system_dim) + # Apply obs masks to H + H = jnp.where(obs_time_mask, H.T, 0).T + H = jnp.where(obs_loc_mask.flatten(), H.T, 0).T + # Analysis cycles over all obs in data_obs Xa = self._compute_analysis(Xb=Xbt.T, y=obs_values, @@ -117,15 +125,17 @@ def _cycle_obsop(self, Xbt, obs_values, obs_loc_indices, obs_error_sd, return Xa.T, 0 - def _step_forecast(self, xa): + def _step_forecast(self, xa, n_steps): data_forecast = [] for i in range(self.ensemble_dim): new_vals = self.model_obj.forecast( - vector.StateVector(values=xa.values[i], store_as_jax=True) + vector.StateVector(values=xa.values[i], store_as_jax=True), + n_steps=n_steps ).values data_forecast.append(new_vals) - return vector.StateVector(values=jnp.stack(data_forecast), + out_vals = jnp.moveaxis(jnp.stack(data_forecast), [0,1,2],[1,0,2]) + return vector.StateVector(values=out_vals, store_as_jax=True) def _apply_obsop(self, Xb, H, h): @@ -182,11 +192,11 @@ def _compute_analysis(self, Xb, y, H, h, R, rho=1.0, Yb=None): # 
Compute the analysis if len(R) > 0: - Rinv = jnp.linalg.pinv(R, rcond=1e-15) + Rinv = jnp.linalg.pinv(R, rtol=1e-15) Pa_ens = jnp.linalg.pinv((ensemble_dim-1)/rho*I + Yb_pert.T @ Rinv @ Yb_pert, - rcond=1e-15) + rtol=1e-15) Wa = linalg.sqrtm((ensemble_dim-1) * Pa_ens) Wa = Wa.real else: @@ -206,49 +216,63 @@ def _compute_analysis(self, Xb, y, H, h, R, rho=1.0, Yb=None): return Xa def _cycle_and_forecast(self, state_obs_tuple, filtered_idx): + # 1. Get data cur_state_vals = state_obs_tuple[0] obs_vals = state_obs_tuple[1] obs_times = state_obs_tuple[2] obs_loc_indices = state_obs_tuple[3] - obs_error_sd = state_obs_tuple[4] + obs_loc_masks = state_obs_tuple[4] + obs_error_sd = state_obs_tuple[5] + + # 1-b. Calculate obs_time_mask and restore filtered_idx to original values + obs_time_mask = jnp.repeat(filtered_idx > 0, obs_loc_indices.shape[1]) + filtered_idx = filtered_idx - 1 # 2. Calculate analysis - new_obs_vals = jax.lax.dynamic_slice_in_dim(obs_vals, filtered_idx[0], - len(filtered_idx)) - new_obs_loc_indices = jax.lax.dynamic_slice_in_dim( - obs_loc_indices, filtered_idx[0], len(filtered_idx)) + new_obs_vals = obs_vals[filtered_idx] + new_obs_loc_indices = obs_loc_indices[filtered_idx] + new_obs_loc_mask = obs_loc_masks[filtered_idx] analysis, kh = self._step_cycle( vector.StateVector(values=cur_state_vals, store_as_jax=True), vector.ObsVector(values=new_obs_vals, location_indices=new_obs_loc_indices, - error_sd=obs_error_sd, store_as_jax=True) + error_sd=obs_error_sd, store_as_jax=True), + obs_loc_mask=new_obs_loc_mask, + obs_time_mask=obs_time_mask ) # 3. 
Forecast next timestep - cur_state = self._step_forecast(analysis) + forecast_states = self._step_forecast(analysis, n_steps=self.steps_per_window) + next_state = forecast_states.values[-1] - return (cur_state.values, obs_vals, obs_times, obs_loc_indices, - obs_error_sd), cur_state.values + return (next_state, obs_vals, obs_times, obs_loc_indices, + obs_loc_masks, obs_error_sd), forecast_states.values[:-1] def cycle(self, input_state, start_time, obs_vector, - timesteps, + n_cycles, obs_error_sd=None, - analysis_window=0.2): + analysis_window=0.2, + analysis_time_in_window=None, + return_forecast=False): """Perform DA cycle repeatedly, including analysis and forecast Args: input_state (vector.StateVector): Input state. start_time (float or datetime-like): Starting time. obs_vector (vector.ObsVector): Observations vector. - timesteps (int): Number of timesteps, in model time. + n_cycles (int): Number of analysis cycles to run, each of length + analysis_window. analysis_window (float): Time window from which to gather observations for DA Cycle. analysis_time_in_window (float): Where within analysis_window to perform analysis. For example, 0.0 is the start of the window. Default is None, which selects the middle of the window. + return_forecast (bool): If True, returns forecast at each model + timestep. If False, returns only analyses, one per analysis + cycle. Default is False. Returns: vector.StateVector of analyses and times. 
@@ -257,26 +281,62 @@ def cycle(self, if obs_error_sd is None: obs_error_sd = obs_vector.error_sd self.analysis_window = analysis_window - all_times = (jnp.repeat(start_time, timesteps) - + (jnp.arange(0, timesteps)*self.delta_t)) + + # If don't specify analysis_time_in_window, is assumed to be middle + if analysis_time_in_window is None: + analysis_time_in_window = analysis_window/2 + + # Steps per window + 1 to include start + self.steps_per_window = round(analysis_window/self.delta_t) + 1 + + # Time offset from middle of time window, for gathering observations + _time_offset = (analysis_window/2) - analysis_time_in_window + + # Set up for jax.lax.scan, which is very fast + all_times = dac_utils._get_all_times( + start_time, + analysis_window, + n_cycles) + + # Get the obs vectors for each analysis window - all_filtered_idx = jnp.stack([jnp.where( - # Greater than start of window - (obs_vector.times > cur_time - analysis_window/2) - # AND Less than end of window - * (obs_vector.times < cur_time + analysis_window/2) - # AND not equal to end of window - * (1-jnp.isclose(obs_vector.times, cur_time + analysis_window/2, - rtol=0)) - # OR Equal to start of window - + jnp.isclose(obs_vector.times, cur_time - analysis_window/2, - rtol=0) - )[0] for cur_time in all_times]) - cur_state, all_values = jax.lax.scan( - self._cycle_and_forecast, - (input_state.values, obs_vector.values, obs_vector.times, - obs_vector.location_indices, obs_error_sd), - all_filtered_idx) - - return vector.StateVector(values=jnp.stack(all_values), - times=all_times) + all_filtered_idx = dac_utils._get_obs_indices( + obs_times=obs_vector.times, + analysis_times=all_times+_time_offset, + start_inclusive=True, + end_inclusive=False, + analysis_window=analysis_window + ) + + all_filtered_padded = dac_utils._pad_time_indices(all_filtered_idx, add_one=True) + + # Padding observations + if obs_vector.stationary_observers: + obs_loc_masks = jnp.ones(obs_vector.values.shape, dtype=bool) + cur_state, 
all_values = jax.lax.scan( + self._cycle_and_forecast, + (input_state.values, obs_vector.values, obs_vector.times, + obs_vector.location_indices, obs_loc_masks, obs_error_sd), + all_filtered_padded) + else: + obs_vals, obs_locs, obs_loc_masks = dac_utils._pad_obs_locs(obs_vector) + cur_state, all_values = jax.lax.scan( + self._cycle_and_forecast, + (input_state.values, obs_vals, obs_vector.times, + obs_locs, obs_loc_masks, obs_error_sd), + all_filtered_padded) + + + if return_forecast: + all_times_forecast = jnp.arange( + 0, + n_cycles*analysis_window, + self.delta_t + ) + start_time + return vector.StateVector(values=jnp.concatenate(all_values), + times=all_times_forecast) + else: + return vector.StateVector(values=jnp.vstack([ + forecast[0][jnp.newaxis] for forecast in all_values] + ), + times=all_times) diff --git a/dabench/dacycler/_utils.py b/dabench/dacycler/_utils.py new file mode 100644 index 0000000..147eec3 --- /dev/null +++ b/dabench/dacycler/_utils.py @@ -0,0 +1,143 @@ +"""Utils for data assimilation cyclers""" + +import jax.numpy as jnp +import numpy as np + + +def _get_all_times( + start_time, + analysis_window, + analysis_cycles, + ): + """Calculate times of the centers of all analysis windows. + + Args: + start_time (float): Start time of DA experiment in model time units. + analysis_window (float): Length of analysis window, in model time + units. + analysis_cycles (int): Number of analysis cycles to perform. + + Returns: + array of all analysis window center-times. + + + """ + all_times = ( + jnp.repeat(start_time, analysis_cycles) + + jnp.arange(0, analysis_cycles*analysis_window, + analysis_window) + ) + + return all_times + + +def _get_obs_indices( + analysis_times, + obs_times, + analysis_window, + start_inclusive=True, + end_inclusive=False + ): + """Get indices of obs times for each analysis cycle to pass to jax.lax.scan + + Args: + analysis_times (list): List of times for all analysis window, centered + in middle of time window. 
Output of _get_all_times(). + obs_times (list): List of times for all observations. + analysis_window (float): Length of analysis window. + start_inclusive (bool): Include obs times equal to beginning of + analysis window. Default is True + end_inclusive (bool): Include obs times equal to end of + analysis window. Default is False. + + Returns: + list with each element containing array of obs indices for the + corresponding analysis cycle. + """ + # Get the obs vectors for each analysis window + all_filtered_idx = [jnp.where( + # Greater than start of window + (obs_times > cur_time - analysis_window/2) + # AND Less than end of window + * (obs_times < cur_time + analysis_window/2) + # AND not equal to start of window + * (1-(1-start_inclusive)*jnp.isclose(obs_times, cur_time - analysis_window/2, + rtol=0)) + # AND not equal to end of window + * (1-(1-end_inclusive)*jnp.isclose(obs_times, cur_time + analysis_window/2, + rtol=0)) + # OR Equal to start of window + + start_inclusive*jnp.isclose(obs_times, cur_time - analysis_window/2, + rtol=0) + # OR Equal to end of window + + end_inclusive*jnp.isclose(obs_times, cur_time + analysis_window/2, + rtol=0) + )[0] for cur_time in analysis_times] + + return all_filtered_idx + + +def _pad_time_indices( + obs_indices, + add_one=True + ): + """Pad observation indices for each analysis window. + + Args: + obs_indices (list): List of arrays where each array contains + obs indices for an analysis cycle. Result of _get_obs_indices. + add_one (bool): If True, will add one to all index values to encode + indices to be masked out for DA (i.e. zeros represent indices to + be masked out). Default is True. + + Returns: + padded_indices (array): Array of padded obs_indices, with shape: + (num_analysis_cycles, max_obs_per_cycle). 
+ """ + + def resize(row, size, add_one): + new = np.array(row) + add_one + new.resize(size) + return new + + # find longest row length + row_length = max(obs_indices, key=len).__len__() + padded_indices = np.array([resize(row, row_length, add_one) for row in obs_indices]) + + return padded_indices + + +def _pad_obs_locs(obs_vec): + """Pad observation location indices to equal spacing + + Args: + obs_vec (dabench.vector.ObsVector): Observation vector + object containing times, locations, and values of obs. + + Returns: + (vals, locs, masks): Tuple containing padded arrays of obs + values and locations, and binary array masks where 1 is + a valid observation value/location and 0 is not. + """ + + def resize(row, size): + new_vals_locs = np.array(np.stack(row), order='F') + new_vals_locs.resize((new_vals_locs.shape[0], size)) + mask = np.ones_like(new_vals_locs[0]).astype(int) + if size > len(row[0]): + mask[-(size-len(row[0])):] = 0 + return np.vstack([new_vals_locs, mask]).T + + # Find longest row length + row_length = max(obs_vec.values, key=len).__len__() + padded_arrays_masks = np.array([resize(row, row_length) for row in + np.stack([obs_vec.values, + obs_vec.location_indices], + axis=1)], dtype=float) + vals, locs, masks = (padded_arrays_masks[...,0], + padded_arrays_masks[...,1:-1].astype(int), + padded_arrays_masks[...,2].astype(bool)) + if locs.shape[-1] == 1: + locs = locs[..., 0] + + return vals, locs, masks \ No newline at end of file diff --git a/dabench/dacycler/_var3d.py b/dabench/dacycler/_var3d.py index 3f4a25d..f84fd87 100644 --- a/dabench/dacycler/_var3d.py +++ b/dabench/dacycler/_var3d.py @@ -51,12 +51,6 @@ def __init__(self, def _step_cycle(self, xb, yo, H=None, h=None, R=None, B=None): """Perform one step of DA Cycle - Args: - xb: - yo: - H - - Returns: vector.StateVector containing analysis results @@ -127,8 +121,8 @@ def _cycle_linear_obsop(self, forecast, obs_vec, H=None, R=None, HBHtPlusR_inv = jnp.linalg.inv(H @ BHt + R) KH = BHt @ 
HBHtPlusR_inv @ H - return vector.StateVector(values=xa, store_as_jax=True), KH + return vector.StateVector(values=xa.T[0], store_as_jax=True), KH - def _step_forecast(self, xa): - """One step of the forecast.""" - return self.model_obj.forecast(xa) + def _step_forecast(self, xa, n_steps): + """n_steps forward of model forecast""" + return self.model_obj.forecast(xa, n_steps=n_steps) diff --git a/dabench/dacycler/_var4d.py b/dabench/dacycler/_var4d.py index 29781bc..7dd6cee 100644 --- a/dabench/dacycler/_var4d.py +++ b/dabench/dacycler/_var4d.py @@ -1,6 +1,7 @@ """Class for Var 4D Data Assimilation Cycler object""" import inspect +import warnings import numpy as np import jax.numpy as jnp @@ -12,6 +13,7 @@ from functools import partial from dabench import dacycler, vector +import dabench.dacycler._utils as dac_utils class Var4D(dacycler.DACycler): @@ -40,11 +42,14 @@ class Var4D(dacycler.DACycler): 4DVar. Increasing this may result in higher accuracy but slower performance. Default is 1. steps_per_window (int): Number of timesteps per analysis window. + If None (default), will calculate automatically based on delta_t + and .cycle() analysis_window length. obs_window_indices (list): Timestep indices where observations fall within each analysis window. For example, if analysis window is 0 - 0.05 with delta_t = 0.01 and observations fall at 0, 0.01, 0.02, 0.03, 0.04, and 0.05, obs_window_indices = - [0, 1, 2, 3, 4, 5]. + [0, 1, 2, 3, 4, 5]. If None (default), will calculate + automatically. 
""" def __init__(self, @@ -58,7 +63,7 @@ def __init__(self, solver='bicgstab', n_outer_loops=1, steps_per_window=1, - obs_window_indices=[0], + obs_window_indices=None, **kwargs ): @@ -67,6 +72,10 @@ def __init__(self, self.n_outer_loops = n_outer_loops self.solver = solver + # Var4D requires H to be a JAX array + if H is not None: + H = jnp.array(H) + super().__init__(system_dim=system_dim, delta_t=delta_t, model_obj=model_obj, @@ -74,20 +83,24 @@ def __init__(self, ensemble=False, B=B, R=R, H=H, h=h) - def _calc_default_H(self, obs_values, obs_loc_indices): - H = jnp.zeros((obs_values[0].shape[0], self.system_dim)) - H = H.at[jnp.arange(H.shape[0]), obs_loc_indices[0] - ].set(1) - return H + def _calc_default_H(self, obs_loc_indices): + Hs = jnp.zeros((obs_loc_indices.shape[0], obs_loc_indices.shape[1], + self.system_dim), + dtype=int) + for i in range(Hs.shape[0]): + Hs = Hs.at[i, jnp.arange(Hs.shape[1]), obs_loc_indices + ].set(1) + return Hs + def _calc_default_R(self, obs_values, obs_error_sd): - return jnp.identity(obs_values.flatten().shape[0])*(obs_error_sd**2) + return jnp.identity(obs_values[0].shape[0])*(obs_error_sd**2) def _calc_default_B(self): return jnp.identity(self.system_dim) - def _make_outerloop_4d(self, xb0, H, B, Rinv, - obs_values, obs_window_indices, + def _make_outerloop_4d(self, xb0, Hs, B, Rinv, + obs_values, obs_window_indices, obs_time_mask, n_steps): def _outerloop_4d(x0, _): @@ -100,25 +113,42 @@ def _outerloop_4d(x0, _): ) # 4D-Var inner loop - x0 = self._innerloop_4d(self.system_dim, x, xb0, - obs_values, H, B, - Rinv, M, obs_window_indices) + x0 = self._innerloop_4d(self.system_dim, + x, xb0, obs_values, + Hs, B, Rinv, M, + obs_window_indices, + obs_time_mask) return x0, x0 return _outerloop_4d - def _cycle_obsop(self, xb0, obs_values, obs_loc_indices, - obs_error_sd, H=None, h=None, R=None, B=None, - obs_window_indices=None, n_steps=1): + def _cycle_obsop(self, xb0, obs_values, obs_loc_indices, obs_error_sd, + 
obs_window_indices, obs_time_mask, obs_loc_mask, + H=None, h=None, R=None, B=None, + n_steps=1): if H is None and h is None: if self.H is None: if self.h is None: - H = self._calc_default_H(obs_values, obs_loc_indices) + H = self._calc_default_H(obs_loc_indices) + # Apply obs loc mask + # NOTE: nonstationary observer case runs MUCH slower. Not sure why + # Ideally, this conditional would not be necessary, but this is a + # workaround to prevent slowing down stationary observer case. + Hs = jax.lax.cond( + self._obs_vector.stationary_observers, + lambda: H, + lambda: (obs_loc_mask[:, :, jnp.newaxis] * H)) else: h = self.h else: - H = self.H + # Assumes self.H is for a single timestep + H = self.H[jnp.newaxis] + Hs = jax.lax.cond( + self._obs_vector.stationary_observers, + lambda: jnp.repeat(H, obs_values.shape[0], axis=0), + lambda: (obs_loc_mask[:, :, jnp.newaxis] * H)) + if R is None: if self.R is None: R = self._calc_default_R(obs_values, obs_error_sd) @@ -137,7 +167,8 @@ def _cycle_obsop(self, xb0, obs_values, obs_loc_indices, x0 = deepcopy(xb0) outerloop_4d_func = self._make_outerloop_4d( - xb0, H, B, Rinv, obs_values, obs_window_indices, n_steps) + xb0, Hs, B, Rinv, obs_values, obs_window_indices, + obs_time_mask, n_steps) x0, all_x0s = jax.lax.scan(outerloop_4d_func, init=x0, xs=None, length=self.n_outer_loops) @@ -148,21 +179,23 @@ def _cycle_obsop(self, xb0, obs_values, obs_loc_indices, x0=vector.StateVector(values=x0, store_as_jax=True) ).values - return x, None + return x - def step_cycle(self, x0, yo, obs_window_indices=[0], - H=None, h=None, R=None, B=None, + def step_cycle(self, x0, yo, obs_time_mask, obs_loc_mask, + obs_window_indices, H=None, h=None, R=None, B=None, n_steps=1): """Perform one step of DA Cycle""" if H is not None or h is None: return self._cycle_obsop( x0.values, yo.values, yo.location_indices, yo.error_sd, - H, R, B, obs_window_indices=obs_window_indices, + obs_loc_mask=obs_loc_mask, obs_time_mask=obs_time_mask, + 
obs_window_indices=obs_window_indices, + H=H, R=R, B=B, n_steps=n_steps) else: return self._cycle_obsop( - x0.values, yo.values, yo.location_indices, yo.error_sd, h, - R, B, obs_window_indices=obs_window_indices, + x0.values, yo.values, yo.location_indices, yo.error_sd, h=h, + R=R, B=B, obs_window_indices=obs_window_indices, n_steps=n_steps) def step_forecast(self, x0, n_steps=1): @@ -180,51 +213,38 @@ def step_forecast(self, x0, n_steps=1): out.append(xi) return vector.StateVector(jnp.vstack(xi), store_as_jax=True) - @partial(jax.jit, static_argnums=[0,1]) - def _innerloop_4d(self, system_dim, x, xb0, y, H, B, Rinv, M, - obs_window_indices=[0]): - """4DVar innerloop - Args: - system_dim (int): The dimension of the system state. - x (ndarray): Current best guess for trajectory. Updated each outer - loop. (time_dim, system_dim) - xb0 (ndarray): Initial background estimate for initial conditions. - Stays constant throughout analysis cycle. Shape: (system_dim,) - y (ndarray): Time array of observation. Shape: (num_obs, obs_dim) - H (ndarray): Observation operator matrix. Shape: - (obs_dim, system_dim) - B (ndarray): Background/forecast error covariance matrix. Shape: - (system_dim, system_dim) - Rinv (ndarray): Inverted observation error covariance matrix. Shape: - (obs_dim, obs_dim)] - M (ndarray): List of TLMs for each model timestep. Shape: - (time_dim,system_dim, system_dim) - obs_window_indices (ndarray): Indices of observations w.r.t. model - timesteps in analysis window. + def _calc_J_term(self, H, M, Rinv, y, x): + # The Jb Term (A) + HM = H @ M + MtHtRinv = HM.T @ Rinv - Returns: - xa0 (ndarray): inner loop estimate of optimal initial conditions. 
- Shape: (system_dim,) + # The Jo Term (b) + D = (y - (H @ x)) + return MtHtRinv @ HM, MtHtRinv @ D[:, None] - """ + + @partial(jax.jit, static_argnums=[0, 1]) + def _innerloop_4d(self, system_dim, x, xb0, obs_vals, Hs, B, Rinv, M, + obs_window_indices, obs_time_mask): + """4DVar innerloop""" x0_last = x[0] # Set up Variables - SumMtHtRinvD = jnp.zeros((system_dim, 1)) # b input SumMtHtRinvHM = jnp.zeros_like(B) # A input + SumMtHtRinvD = jnp.zeros((system_dim, 1)) # b input # Loop over observations for i, j in enumerate(obs_window_indices): - # The Jb Term (A) - HM = H @ M[j, :, :] - MtHtRinv = HM.T @ Rinv - SumMtHtRinvHM += MtHtRinv @ HM - - # The Jo Term (b) - D = y[i] - (H @ x[j]) - SumMtHtRinvD += MtHtRinv @ D[:, None] - + Jb, Jo = jax.lax.cond( + obs_time_mask.at[i].get(mode='fill', fill_value=0), + lambda: self._calc_J_term(Hs.at[i].get(mode='clip'), M[j], + Rinv, obs_vals[i], x[j]), + lambda: (jnp.zeros_like(SumMtHtRinvHM), + jnp.zeros_like(SumMtHtRinvD)) + ) + SumMtHtRinvHM += Jb + SumMtHtRinvD += Jo # Compute initial departure db0 = xb0 - x0_last @@ -266,86 +286,142 @@ def _solve(self, db0, SumMtHtRinvHM, SumMtHtRinvD, B): return dx0 - @partial(jax.jit, static_argnums=0) - def _cycle_and_forecast(self, cur_state_vals, filtered_idx): - obs_vals = self._obs_vector.values - obs_loc_indices = self._obs_vector.location_indices + def _cycle_and_forecast(self, cur_state_vals_time_tuple, filtered_idx): + cur_state_vals, cur_time = cur_state_vals_time_tuple obs_error_sd = self._obs_error_sd - cur_obs_vals = jax.lax.dynamic_slice_in_dim(obs_vals, filtered_idx[0], - len(filtered_idx)) - cur_obs_loc_indices = jax.lax.dynamic_slice_in_dim(obs_loc_indices, - filtered_idx[0], - len(filtered_idx)) - analysis, kh = self.step_cycle( + # Calculate obs_time_mask and restore filtered_idx to original values + obs_time_mask = filtered_idx > 0 + filtered_idx = filtered_idx - 1 + + cur_obs_vals = jnp.array(self._obs_vector.values).at[filtered_idx].get() + cur_obs_loc_indices = 
jnp.array(self._obs_vector.location_indices).at[filtered_idx].get() + cur_obs_times = jnp.array(self._obs_vector.times).at[filtered_idx].get() + cur_obs_loc_mask = jnp.array(self._obs_loc_masks).at[filtered_idx].get().astype(bool) + + # Calculate obs window indices: closest model timesteps that match obs + obs_window_indices = jax.lax.cond( + self.obs_window_indices is None, + lambda: jnp.array([ + jnp.argmin( + jnp.abs(obs_time - (cur_time + self._model_timesteps)) + ) for obs_time in cur_obs_times + ]), + lambda: jnp.array(self.obs_window_indices) + ) + + analysis = self.step_cycle( vector.StateVector(values=cur_state_vals, store_as_jax=True), vector.ObsVector(values=cur_obs_vals, location_indices=cur_obs_loc_indices, error_sd=obs_error_sd, store_as_jax=True), + obs_time_mask=obs_time_mask, + obs_loc_mask=cur_obs_loc_mask, n_steps=self.steps_per_window, - obs_window_indices=self.obs_window_indices) - + obs_window_indices=obs_window_indices) + new_time = cur_time + self.analysis_window - return analysis[-1], analysis[:-1] + return (analysis[-1], new_time), analysis[:-1] def cycle(self, input_state, start_time, obs_vector, obs_error_sd, - timesteps, + n_cycles, analysis_window, - analysis_time_in_window=None): + analysis_time_in_window=0, + return_forecast=False): """Perform DA cycle repeatedly, including analysis and forecast Args: input_state (vector.StateVector): Input state. + start_time (float or datetime-like): Starting time. obs_vector (vector.ObsVector): Observations vector. obs_error_sd (float): Standard deviation of observation error. Typically not known, so provide a best-guess. - start_time (float or datetime-like): Starting time. - timesteps (int): Number of timesteps, in model time. - analysis_window (float): Time window from which to gather - observations for DA Cycle. - analysis_time_in_window (float): Where within analysis_window + n_cycles (int): Number of analysis cycles to run, each of length + analysis_window. 
+ analysis_window (float): Length of time window from which to gather + observations for each DA Cycle, in model time units. + analysis_time_in_window (float): At what time within analysis_window to perform analysis. For example, 0.0 is the start of the - window. Default is None, which selects the middle of the - window. + window. Default is 0, the start of the window. + return_forecast (bool): If True, returns forecast at each model + timestep. If False, returns only analyses, one per analysis + cycle. Default is False. Returns: vector.StateVector of analyses and times. """ + if (not obs_vector.stationary_observers and + (self.H is not None or self.h is not None)): + warnings.warn( + "Provided obs vector has nonstationary observers. When" + " providing a custom obs operator (H/h), the Var4DBackprop" + "DA cycler may not function properly. If you encounter " + "errors, try again with an observer where" + "stationary_observers=True or without specifying H or h (a " + "default H matrix will be used to map observations to system " + "space)." 
+ ) + self.analysis_window = analysis_window + # If don't specify analysis_time_in_window, is assumed to be middle if analysis_time_in_window is None: - analysis_time_in_window = analysis_window/2 + analysis_time_in_window = self.analysis_window/2 + + # Time offset from middle of time window, for gathering observations + _time_offset = (analysis_window/2) - analysis_time_in_window # Set up for jax.lax.scan, which is very fast - all_times = ( - jnp.repeat(start_time + analysis_time_in_window, timesteps) - + jnp.arange(0, timesteps*analysis_window, - analysis_window) - ) + all_times = dac_utils._get_all_times(start_time, analysis_window, + n_cycles) + + if self.steps_per_window is None: + self.steps_per_window = round(analysis_window/self.delta_t) + 1 + self._model_timesteps = jnp.arange(self.steps_per_window)*self.delta_t + # Get the obs vectors for each analysis window - all_filtered_idx = jnp.stack([jnp.where( - # Greater than start of window - (obs_vector.times > cur_time - analysis_window/2) - # AND Less than end of window - * (obs_vector.times < cur_time + analysis_window/2) - # OR Equal to start of window end - + jnp.isclose(obs_vector.times, cur_time - analysis_window/2, - rtol=0) - # OR Equal to end of window - + jnp.isclose(obs_vector.times, cur_time + analysis_window/2, - rtol=0) - )[0] for cur_time in all_times]) + all_filtered_idx = dac_utils._get_obs_indices( + obs_times=obs_vector.times, + analysis_times=all_times+_time_offset, + start_inclusive=True, + end_inclusive=True, + analysis_window=analysis_window + ) + + all_filtered_padded = dac_utils._pad_time_indices(all_filtered_idx) self._obs_vector = obs_vector self._obs_error_sd = obs_error_sd + + # Padding observations + if obs_vector.stationary_observers: + self._obs_loc_masks = jnp.ones(obs_vector.values.shape, dtype=bool) + else: + obs_vals, obs_locs, obs_loc_masks = dac_utils._pad_obs_locs( + obs_vector) + self._obs_vector.values = obs_vals + self._obs_vector.location_indices = obs_locs + 
self._obs_loc_masks = jnp.array(obs_loc_masks) + cur_state, all_values = jax.lax.scan( self._cycle_and_forecast, - init=input_state.values, - xs=all_filtered_idx) - - return vector.StateVector(values=jnp.vstack(all_values), - store_as_jax=True) + init=(input_state.values, start_time), + xs=all_filtered_padded) + + if return_forecast: + all_times_forecast = jnp.arange( + 0, + n_cycles*analysis_window, + self.delta_t + ) + start_time + return vector.StateVector(values=jnp.concatenate(all_values), + times=all_times_forecast) + else: + return vector.StateVector(values=jnp.vstack([ + forecast[0] for forecast in all_values] + ), + times=all_times) diff --git a/dabench/dacycler/_var4d_backprop.py b/dabench/dacycler/_var4d_backprop.py index 9b4d4a0..161f66a 100644 --- a/dabench/dacycler/_var4d_backprop.py +++ b/dabench/dacycler/_var4d_backprop.py @@ -1,6 +1,7 @@ """Class for Var 4D Backpropagation Data Assimilation Cycler object""" import inspect +import warnings import numpy as np import jax.numpy as jnp @@ -9,8 +10,10 @@ from jax.scipy import optimize import jax import optax +from functools import partial from dabench import dacycler, vector +import dabench.dacycler._utils as dac_utils class Var4DBackprop(dacycler.DACycler): @@ -37,15 +40,22 @@ class Var4DBackprop(dacycler.DACycler): num_iters (int): Number of iterations for backpropagation per analysis cycle. Default is 3. steps_per_window (int): Number of timesteps per analysis window. - learning_rate (float): LR for backpropogation. Default is 1e-5, but + If None (default), will calculate automatically based on delta_t + and .cycle() analysis_window length. + learning_rate (float): LR for backpropagation. Default is 0.5, but DA results can be quite sensitive to this parameter. lr_decay (float): Exponential learning rate decay. If set to 1, - no decay. Default is 1. + no decay. Default is 0.5. obs_window_indices (list): Timestep indices where observations fall within each analysis window. 
For example, if analysis window is 0 - 0.05 with delta_t = 0.01 and observations fall at 0, 0.01, 0.02, 0.03, 0.04, and 0.05, obs_window_indices = - [0, 1, 2, 3, 4, 5]. + [0, 1, 2, 3, 4, 5]. If None (default), will calculate + automatically. + loss_growth_limit (float): If loss grows by more than this factor + during one analysis cycle, JAX will cut off computation and + return an error. This prevents it from hanging indefinitely + when loss grows exponentially. Default is 10. """ def __init__(self, @@ -56,11 +66,11 @@ def __init__(self, R=None, H=None, h=None, - learning_rate=1e-5, - lr_decay=1.0, + learning_rate=0.5, + lr_decay=0.5, num_iters=3, - steps_per_window=1, - obs_window_indices=[0], + steps_per_window=None, + obs_window_indices=None, loss_growth_limit=10, **kwargs ): @@ -72,6 +82,10 @@ def __init__(self, self.obs_window_indices = obs_window_indices self.loss_growth_limit = loss_growth_limit + # Var4D Backprop requires H to be a JAX array + if H is not None: + H = jnp.array(H) + super().__init__(system_dim=system_dim, delta_t=delta_t, model_obj=model_obj, @@ -79,14 +93,19 @@ def __init__(self, ensemble=False, B=B, R=R, H=H, h=h) - def _calc_default_H(self, obs_values, obs_loc_indices): - H = jnp.zeros((obs_values[0].shape[0], self.system_dim)) - H = H.at[jnp.arange(H.shape[0]), obs_loc_indices[0] - ].set(1) - return H + + def _calc_default_H(self, obs_loc_indices): + Hs = jnp.zeros((obs_loc_indices.shape[0], obs_loc_indices.shape[1], + self.system_dim), + dtype=int) + for i in range(Hs.shape[0]): + Hs = Hs.at[i, jnp.arange(Hs.shape[1]), obs_loc_indices + ].set(1) + + return Hs def _calc_default_R(self, obs_values, obs_error_sd): - return jnp.identity(obs_values.flatten().shape[0])*(obs_error_sd**2) + return jnp.identity(obs_values[0].shape[0])*(obs_error_sd**2) def _calc_default_B(self): return jnp.identity(self.system_dim) @@ -101,7 +120,16 @@ def _callback_raise_error(self, error_method, loss_val): jax.debug.callback(error_method) return loss_val - 
def _make_loss(self, xb0, obs_vals, Ht, Binv, Rinv, obs_window_indices, n_steps): + @partial(jax.jit, static_argnums=[0]) + def _calc_obs_term(self, pred_x, obs_vals, Ht, Rinv): + pred_obs = pred_x @ Ht + resid = pred_obs.ravel() - obs_vals.ravel() + + return jnp.sum(resid.T @ Rinv @ resid) + + def _make_loss(self, xb0, obs_vals, Hs, Binv, Rinv, + obs_window_indices, + obs_time_mask, n_steps): """Define loss function based on 4dvar cost""" @jax.jit @@ -117,10 +145,13 @@ def loss_4dvarcost(x0): # Calculate observation term of J_0 obs_term = 0 for i, j in enumerate(obs_window_indices): - pred_obs = pred_x[j] @ Ht - resid = pred_obs.ravel() - obs_vals[i].ravel() - - obs_term += np.sum(resid.T @ Rinv @ resid) + obs_term += jax.lax.cond( + obs_time_mask.at[i].get(mode='fill', fill_value=0), + lambda: self._calc_obs_term(pred_x[j], obs_vals[i], + Hs.at[i].get(mode='clip').T, + Rinv), + lambda: 0.0 + ) # Calculate initial departure term of J_0 based on original x0 initial_term = (db0.T @ Binv @ db0) @@ -129,50 +160,64 @@ def loss_4dvarcost(x0): loss_val = initial_term + obs_term return jax.lax.cond( jnp.isnan(loss_val), - lambda: self._callback_raise_error(self._raise_nan_error, loss_val), + lambda: self._callback_raise_error(self._raise_nan_error, + loss_val), lambda: loss_val) return loss_4dvarcost - def _make_backprop_epoch(self, loss_func, optimizer, hessian_inv): loss_value_grad = value_and_grad(loss_func, argnums=0) @jax.jit - def _backprop_epoch(epoch_state_tuple, _): - x0, xb0, init_loss, i, opt_state = epoch_state_tuple + def _backprop_epoch(epoch_state_tuple, i): + x0, init_loss, opt_state = epoch_state_tuple loss_val, dx0 = loss_value_grad(x0) dx0_hess = hessian_inv @ dx0 init_loss = jax.lax.cond( - i==0, + i == 0, lambda: loss_val, lambda: init_loss) loss_val = jax.lax.cond( loss_val/init_loss > self.loss_growth_limit, - lambda: self._callback_raise_error(self._raise_loss_growth_error, loss_val), + lambda: self._callback_raise_error( + 
self._raise_loss_growth_error, loss_val), lambda: loss_val) - + updates, opt_state = optimizer.update(dx0_hess, opt_state) x0_new = optax.apply_updates(x0, updates) - return (x0_new, x0, init_loss, i+1, opt_state), loss_val + return (x0_new, init_loss, opt_state), loss_val return _backprop_epoch def _cycle_obsop(self, x0, obs_values, obs_loc_indices, obs_error_sd, + obs_time_mask, obs_loc_mask, H=None, h=None, R=None, B=None, obs_window_indices=None, n_steps=1): if H is None and h is None: if self.H is None: if self.h is None: - H = self._calc_default_H(obs_values, obs_loc_indices) + H = self._calc_default_H(obs_loc_indices) + # Apply obs loc mask + # NOTE: nonstationary observer case runs MUCH slower. Not sure why + # Ideally, this conditional would not be necessary, but this is a + # workaround to prevent slowing down stationary observer case. + Hs = jax.lax.cond( + self._obs_vector.stationary_observers, + lambda: H, + lambda: (obs_loc_mask[:, :, jnp.newaxis] * H)) else: h = self.h else: - H = self.H - Ht = H.T + # Assumes self.H is for a single timestep + H = self.H[jnp.newaxis] + Hs = jax.lax.cond( + self._obs_vector.stationary_observers, + lambda: jnp.repeat(H, obs_values.shape[0], axis=0), + lambda: (obs_loc_mask[:, :, jnp.newaxis] * H)) if R is None: if self.R is None: @@ -191,17 +236,17 @@ def _cycle_obsop(self, x0, obs_values, obs_loc_indices, obs_error_sd, Binv = jscipy.linalg.inv(B) # Compute Hessian - hessian_inv = jscipy.linalg.inv(Binv + Ht @ Rinv @ H) - - # Get initial observations and jacobian + hessian_inv = jscipy.linalg.inv( + Binv + Hs.at[0].get().T @ Rinv @ Hs.at[0].get()) loss_func = self._make_loss( x0, obs_values, - Ht, + Hs, Binv, Rinv, obs_window_indices, + obs_time_mask, n_steps=n_steps) lr = optax.exponential_decay( @@ -212,12 +257,13 @@ def _cycle_obsop(self, x0, obs_values, obs_loc_indices, obs_error_sd, opt_state = optimizer.init(x0) # Make initial forecast and calculate loss - backprop_epoch_func = 
self._make_backprop_epoch(loss_func, optimizer, hessian_inv) + backprop_epoch_func = self._make_backprop_epoch(loss_func, optimizer, + hessian_inv) epoch_state_tuple, loss_vals = jax.lax.scan( - backprop_epoch_func, init=(x0, x0, 0., 0, opt_state), - xs=None, length=self.num_iters) + backprop_epoch_func, init=(x0, 0., opt_state), + xs=jnp.arange(self.num_iters)) - x0, xb0, init_loss, i, opt_state = epoch_state_tuple + x0, init_loss, opt_state = epoch_state_tuple xa = self.step_forecast( vector.StateVector(values=x0, store_as_jax=True), @@ -225,13 +271,16 @@ def _cycle_obsop(self, x0, obs_values, obs_loc_indices, obs_error_sd, return xa, loss_vals - def step_cycle(self, xb, yo, H=None, h=None, R=None, B=None, n_steps=1, - obs_window_indices=[0]): + def step_cycle(self, xb, yo, obs_time_mask, obs_loc_mask, + obs_window_indices, H=None, h=None, R=None, B=None, + n_steps=1): """Perform one step of DA Cycle""" if H is not None or h is None: return self._cycle_obsop( xb.values, yo.values, yo.location_indices, yo.error_sd, - H, R, B, obs_window_indices=obs_window_indices, n_steps=n_steps) + obs_time_mask=obs_time_mask, obs_loc_mask=obs_loc_mask, + H=H, R=R, B=B, + obs_window_indices=obs_window_indices, n_steps=n_steps) else: return self._cycle_obsop( xb, yo, h, R, B, obs_window_indices=obs_window_indices, @@ -252,88 +301,144 @@ def step_forecast(self, xa, n_steps=1): out.append(xi) return vector.StateVector(jnp.vstack(xi), store_as_jax=True) - def _cycle_and_forecast(self, cur_state_vals, filtered_idx): - obs_vals = self._obs_vector.values - obs_loc_indices = self._obs_vector.location_indices + def _cycle_and_forecast(self, cur_state_vals_time_tuple, filtered_idx): + cur_state_vals, cur_time = cur_state_vals_time_tuple obs_error_sd = self._obs_error_sd - cur_obs_vals = jax.lax.dynamic_slice_in_dim(obs_vals, filtered_idx[0], - len(filtered_idx)) - cur_obs_loc_indices = jax.lax.dynamic_slice_in_dim(obs_loc_indices, - filtered_idx[0], - len(filtered_idx)) + # Calculate 
obs_time_mask and restore filtered_idx to original values + obs_time_mask = filtered_idx > 0 + filtered_idx = filtered_idx - 1 + + cur_obs_vals = jnp.array(self._obs_vector.values).at[filtered_idx].get() + cur_obs_loc_indices = jnp.array(self._obs_vector.location_indices).at[filtered_idx].get() + cur_obs_times = jnp.array(self._obs_vector.times).at[filtered_idx].get() + cur_obs_loc_mask = jnp.array(self._obs_loc_masks).at[filtered_idx].get().astype(bool) + + # Calculate obs window indices: closest model timesteps that match obs + obs_window_indices = jax.lax.cond( + self.obs_window_indices is None, + lambda: jnp.array([ + jnp.argmin( + jnp.abs(obs_time - (cur_time + self._model_timesteps)) + ) for obs_time in cur_obs_times + ]), + lambda: jnp.array(self.obs_window_indices) + ) + analysis, loss_vals = self.step_cycle( vector.StateVector(values=cur_state_vals, store_as_jax=True), vector.ObsVector(values=cur_obs_vals, location_indices=cur_obs_loc_indices, error_sd=obs_error_sd, store_as_jax=True), + obs_time_mask=obs_time_mask, + obs_loc_mask=cur_obs_loc_mask, n_steps=self.steps_per_window, - obs_window_indices=self.obs_window_indices) + obs_window_indices=obs_window_indices) + new_time = cur_time + self.analysis_window - return analysis.values[-1], (analysis.values[:-1], loss_vals) + return (analysis.values[-1], new_time), (analysis.values[:-1], loss_vals) def cycle(self, input_state, start_time, obs_vector, obs_error_sd, - timesteps, + n_cycles, analysis_window, - analysis_time_in_window=None): + analysis_time_in_window=0, + return_forecast=False): """Perform DA cycle repeatedly, including analysis and forecast Args: input_state (vector.StateVector): Input state. + start_time (float or datetime-like): Starting time. obs_vector (vector.ObsVector): Observations vector. obs_error_sd (float): Standard deviation of observation error. Typically not known, so provide a best-guess. - start_time (float or datetime-like): Starting time. 
- timesteps (int): Number of timesteps, in model time. - analysis_window (float): Time window from which to gather - observations for DA Cycle. - analysis_time_in_window (float): Where within analysis_window + n_cycles (int): Number of analysis cycles to run, each of length + analysis_window. + analysis_window (float): Length of time window from which to gather + observations for each DA Cycle, in model time units. + analysis_time_in_window (float): At what time within analysis_window to perform analysis. For example, 0.0 is the start of the - window. Default is None, which selects the middle of the - window. + window. Default is 0, the start of the window. + return_forecast (bool): If True, returns forecast at each model + timestep. If False, returns only analyses, one per analysis + cycle. Default is False. Returns: vector.StateVector of analyses and times. """ - + if (not obs_vector.stationary_observers and + (self.H is not None or self.h is not None)): + warnings.warn( + "Provided obs vector has nonstationary observers. When" + " providing a custom obs operator (H/h), the Var4DBackprop" + "DA cycler may not function properly. If you encounter " + "errors, try again with an observer where" + "stationary_observers=True or without specifying H or h (a " + "default H matrix will be used to map observations to system " + "space)." 
+ ) + self.analysis_window = analysis_window + + # If don't specify analysis_time_in_window, is assumed to be middle if analysis_time_in_window is None: - analysis_time_in_window = analysis_window/2 + analysis_time_in_window = self.analysis_window/2 + + # Time offset from middle of time window, for gathering observations + _time_offset = (analysis_window/2) - analysis_time_in_window # Set up for jax.lax.scan, which is very fast - all_times = ( - jnp.repeat(start_time + analysis_time_in_window, timesteps) - + jnp.arange(0, timesteps*analysis_window, - analysis_window) - ) + all_times = dac_utils._get_all_times(start_time, analysis_window, + n_cycles) + + if self.steps_per_window is None: + self.steps_per_window = round(analysis_window/self.delta_t) + 1 + self._model_timesteps = jnp.arange(self.steps_per_window)*self.delta_t + # Get the obs vectors for each analysis window - all_filtered_idx = jnp.stack([jnp.where( - # Greater than start of window - (obs_vector.times > cur_time - analysis_window/2) - # AND Less than end of window - * (obs_vector.times < cur_time + analysis_window/2) - # OR Equal to start of window end - + jnp.isclose(obs_vector.times, cur_time - analysis_window/2, - rtol=0) - # OR Equal to end of window - + jnp.isclose(obs_vector.times, cur_time + analysis_window/2, - rtol=0) - )[0] for cur_time in all_times]) + all_filtered_idx = dac_utils._get_obs_indices( + obs_times=obs_vector.times, + analysis_times=all_times+_time_offset, + start_inclusive=True, + end_inclusive=True, + analysis_window=analysis_window + ) + + all_filtered_padded = dac_utils._pad_time_indices(all_filtered_idx) self._obs_vector = obs_vector self._obs_error_sd = obs_error_sd - cur_state, all_values = jax.lax.scan( + + # Padding observations + if obs_vector.stationary_observers: + self._obs_loc_masks = jnp.ones(obs_vector.values.shape, dtype=bool) + else: + obs_vals, obs_locs, obs_loc_masks = dac_utils._pad_obs_locs( + obs_vector) + self._obs_vector.values = obs_vals + 
self._obs_vector.location_indices = obs_locs + self._obs_loc_masks = jnp.array(obs_loc_masks) + + cur_state, all_results = jax.lax.scan( self._cycle_and_forecast, - init=input_state.values, - xs=all_filtered_idx) - all_losses = all_values[1] - print(all_losses[:, -3:]) - all_values = all_values[0] - - return vector.StateVector( - values=jnp.vstack(all_values), - store_as_jax=True) + init=(input_state.values, start_time), + xs=all_filtered_padded) + self.loss_values = all_results[1] + all_values = all_results[0] + + if return_forecast: + all_times_forecast = jnp.arange( + 0, + n_cycles*analysis_window, + self.delta_t + ) + start_time + return vector.StateVector(values=jnp.concatenate(all_values), + times=all_times_forecast) + else: + return vector.StateVector(values=jnp.vstack([ + forecast[0] for forecast in all_values] + ), + times=all_times) diff --git a/dabench/dacycler/_var4d_backprop_exacthessian.py b/dabench/dacycler/_var4d_backprop_exacthessian.py index ca29b3f..3d7858f 100644 --- a/dabench/dacycler/_var4d_backprop_exacthessian.py +++ b/dabench/dacycler/_var4d_backprop_exacthessian.py @@ -1,6 +1,7 @@ """Class for Var 4D Backpropagation Data Assimilation Cycler object""" import inspect +import warnings import numpy as np import jax.numpy as jnp @@ -9,12 +10,14 @@ from jax.scipy import optimize import jax import optax +from functools import partial from dabench import dacycler, vector +import dabench.dacycler._utils as dac_utils class Var4DBackpropExactHessian(dacycler.DACycler): - """Class for building Backpropagation 4D DA Cycler + """Class for building Backpropagation 4D DA Cycler with exact hessian Attributes: system_dim (int): System dimension. @@ -37,15 +40,21 @@ class Var4DBackpropExactHessian(dacycler.DACycler): num_iters (int): Number of iterations for backpropagation per analysis cycle. Default is 3. steps_per_window (int): Number of timesteps per analysis window. - learning_rate (float): LR for backpropogation. 
Default is 1e-5, but - DA results can be quite sensitive to this parameter. + If None (default), will calculate automatically based on delta_t + and .cycle() analysis_window length. + learning_rate (float): LR for backpropogation. Default is 1. lr_decay (float): Exponential learning rate decay. If set to 1, no decay. Default is 1. obs_window_indices (list): Timestep indices where observations fall within each analysis window. For example, if analysis window is 0 - 0.05 with delta_t = 0.01 and observations fall at 0, 0.01, 0.02, 0.03, 0.04, and 0.05, obs_window_indices = - [0, 1, 2, 3, 4, 5]. + [0, 1, 2, 3, 4, 5]. If None (default), will calculate + automatically. + loss_growth_limit (float): If loss grows by more than this factor + during one analysis cycle, JAX will cut off computation and + return an error. This prevents it from hanging indefinitely + when loss grows exponentionally. Default is 10. """ def __init__(self, @@ -56,11 +65,11 @@ def __init__(self, R=None, H=None, h=None, - learning_rate=1e-5, + learning_rate=1.0, lr_decay=1.0, num_iters=3, - steps_per_window=1, - obs_window_indices=[0], + steps_per_window=None, + obs_window_indices=None, loss_growth_limit=10, **kwargs ): @@ -72,6 +81,10 @@ def __init__(self, self.obs_window_indices = obs_window_indices self.loss_growth_limit = loss_growth_limit + # Var4D Backprop requires H to be a JAX array + if H is not None: + H = jnp.array(H) + super().__init__(system_dim=system_dim, delta_t=delta_t, model_obj=model_obj, @@ -79,14 +92,19 @@ def __init__(self, ensemble=False, B=B, R=R, H=H, h=h) - def _calc_default_H(self, obs_values, obs_loc_indices): - H = jnp.zeros((obs_values[0].shape[0], self.system_dim)) - H = H.at[jnp.arange(H.shape[0]), obs_loc_indices[0] - ].set(1) - return H + + def _calc_default_H(self, obs_loc_indices): + Hs = jnp.zeros((obs_loc_indices.shape[0], obs_loc_indices.shape[1], + self.system_dim), + dtype=int) + for i in range(Hs.shape[0]): + Hs = Hs.at[i, jnp.arange(Hs.shape[1]), 
obs_loc_indices + ].set(1) + + return Hs def _calc_default_R(self, obs_values, obs_error_sd): - return jnp.identity(obs_values.flatten().shape[0])*(obs_error_sd**2) + return jnp.identity(obs_values[0].shape[0])*(obs_error_sd**2) def _calc_default_B(self): return jnp.identity(self.system_dim) @@ -101,7 +119,16 @@ def _callback_raise_error(self, error_method, loss_val): jax.debug.callback(error_method) return loss_val - def _make_loss(self, xb0, obs_vals, Ht, Binv, Rinv, obs_window_indices, n_steps): + @partial(jax.jit, static_argnums=[0]) + def _calc_obs_term(self, pred_x, obs_vals, Ht, Rinv): + pred_obs = pred_x @ Ht + resid = pred_obs.ravel() - obs_vals.ravel() + + return jnp.sum(resid.T @ Rinv @ resid) + + def _make_loss(self, xb0, obs_vals, Hs, Binv, Rinv, + obs_window_indices, + obs_time_mask, n_steps): """Define loss function based on 4dvar cost""" @jax.jit @@ -117,10 +144,13 @@ def loss_4dvarcost(x0): # Calculate observation term of J_0 obs_term = 0 for i, j in enumerate(obs_window_indices): - pred_obs = pred_x[j] @ Ht - resid = pred_obs.ravel() - obs_vals[i].ravel() - - obs_term += np.sum(resid.T @ Rinv @ resid) + obs_term += jax.lax.cond( + obs_time_mask.at[i].get(mode='fill', fill_value=0), + lambda: self._calc_obs_term(pred_x[j], obs_vals[i], + Hs.at[i].get(mode='clip').T, + Rinv), + lambda: 0.0 + ) # Calculate initial departure term of J_0 based on original x0 initial_term = (db0.T @ Binv @ db0) @@ -129,52 +159,66 @@ def loss_4dvarcost(x0): loss_val = initial_term + obs_term return jax.lax.cond( jnp.isnan(loss_val), - lambda: self._callback_raise_error(self._raise_nan_error, loss_val), + lambda: self._callback_raise_error(self._raise_nan_error, + loss_val), lambda: loss_val) return loss_4dvarcost - def _make_backprop_epoch(self, loss_func, optimizer): loss_value_grad = value_and_grad(loss_func, argnums=0) hessian = jax.hessian(loss_func, argnums=0) @jax.jit - def _backprop_epoch(epoch_state_tuple, _): - x0, xb0, init_loss, i, opt_state = 
epoch_state_tuple + def _backprop_epoch(epoch_state_tuple, i): + x0, init_loss, opt_state = epoch_state_tuple loss_val, dx0 = loss_value_grad(x0) hessian_inv = jscipy.linalg.inv(hessian(x0)) dx0_hess = hessian_inv @ dx0 init_loss = jax.lax.cond( - i==0, + i == 0, lambda: loss_val, lambda: init_loss) loss_val = jax.lax.cond( loss_val/init_loss > self.loss_growth_limit, - lambda: self._callback_raise_error(self._raise_loss_growth_error, loss_val), + lambda: self._callback_raise_error( + self._raise_loss_growth_error, loss_val), lambda: loss_val) - + updates, opt_state = optimizer.update(dx0_hess, opt_state) x0_new = optax.apply_updates(x0, updates) - return (x0_new, x0, init_loss, i+1, opt_state), loss_val + return (x0_new, init_loss, opt_state), loss_val return _backprop_epoch def _cycle_obsop(self, x0, obs_values, obs_loc_indices, obs_error_sd, + obs_time_mask, obs_loc_mask, H=None, h=None, R=None, B=None, obs_window_indices=None, n_steps=1): if H is None and h is None: if self.H is None: if self.h is None: - H = self._calc_default_H(obs_values, obs_loc_indices) + H = self._calc_default_H(obs_loc_indices) + # Apply obs loc mask + # NOTE: nonstationary observer case runs MUCH slower. Not sure why + # Ideally, this conditional would not be necessary, but this is a + # workaround to prevent slowing down stationary observer case. 
+ Hs = jax.lax.cond( + self._obs_vector.stationary_observers, + lambda: H, + lambda: (obs_loc_mask[:, :, jnp.newaxis] * H)) else: h = self.h else: - H = self.H - Ht = H.T + # Assumes self.H is for a single timestep + H = self.H[jnp.newaxis] + Hs = jax.lax.cond( + self._obs_vector.stationary_observers, + lambda: jnp.repeat(H, obs_values.shape[0], axis=0), + lambda: (obs_loc_mask[:, :, jnp.newaxis] * H)) if R is None: if self.R is None: @@ -195,10 +239,11 @@ def _cycle_obsop(self, x0, obs_values, obs_loc_indices, obs_error_sd, loss_func = self._make_loss( x0, obs_values, - Ht, + Hs, Binv, Rinv, obs_window_indices, + obs_time_mask, n_steps=n_steps) lr = optax.exponential_decay( @@ -210,11 +255,12 @@ def _cycle_obsop(self, x0, obs_values, obs_loc_indices, obs_error_sd, # Make initial forecast and calculate loss backprop_epoch_func = self._make_backprop_epoch(loss_func, optimizer) + epoch_state_tuple, loss_vals = jax.lax.scan( - backprop_epoch_func, init=(x0, x0, 0., 0, opt_state), - xs=None, length=self.num_iters) + backprop_epoch_func, init=(x0, 0., opt_state), + xs=jnp.arange(self.num_iters)) - x0, xb0, init_loss, i, opt_state = epoch_state_tuple + x0, init_loss, opt_state = epoch_state_tuple xa = self.step_forecast( vector.StateVector(values=x0, store_as_jax=True), @@ -222,13 +268,16 @@ def _cycle_obsop(self, x0, obs_values, obs_loc_indices, obs_error_sd, return xa, loss_vals - def step_cycle(self, xb, yo, H=None, h=None, R=None, B=None, n_steps=1, - obs_window_indices=[0]): + def step_cycle(self, xb, yo, obs_time_mask, obs_loc_mask, + obs_window_indices, H=None, h=None, R=None, B=None, + n_steps=1): """Perform one step of DA Cycle""" if H is not None or h is None: return self._cycle_obsop( xb.values, yo.values, yo.location_indices, yo.error_sd, - H, R, B, obs_window_indices=obs_window_indices, n_steps=n_steps) + obs_time_mask=obs_time_mask, obs_loc_mask=obs_loc_mask, + H=H, R=R, B=B, + obs_window_indices=obs_window_indices, n_steps=n_steps) else: return 
self._cycle_obsop( xb, yo, h, R, B, obs_window_indices=obs_window_indices, @@ -249,88 +298,144 @@ def step_forecast(self, xa, n_steps=1): out.append(xi) return vector.StateVector(jnp.vstack(xi), store_as_jax=True) - def _cycle_and_forecast(self, cur_state_vals, filtered_idx): - obs_vals = self._obs_vector.values - obs_loc_indices = self._obs_vector.location_indices + def _cycle_and_forecast(self, cur_state_vals_time_tuple, filtered_idx): + cur_state_vals, cur_time = cur_state_vals_time_tuple obs_error_sd = self._obs_error_sd - cur_obs_vals = jax.lax.dynamic_slice_in_dim(obs_vals, filtered_idx[0], - len(filtered_idx)) - cur_obs_loc_indices = jax.lax.dynamic_slice_in_dim(obs_loc_indices, - filtered_idx[0], - len(filtered_idx)) + # Calculate obs_time_mask and restore filtered_idx to original values + obs_time_mask = filtered_idx > 0 + filtered_idx = filtered_idx - 1 + + cur_obs_vals = jnp.array(self._obs_vector.values).at[filtered_idx].get() + cur_obs_loc_indices = jnp.array(self._obs_vector.location_indices).at[filtered_idx].get() + cur_obs_times = jnp.array(self._obs_vector.times).at[filtered_idx].get() + cur_obs_loc_mask = jnp.array(self._obs_loc_masks).at[filtered_idx].get().astype(bool) + + # Calculate obs window indices: closest model timesteps that match obs + obs_window_indices = jax.lax.cond( + self.obs_window_indices is None, + lambda: jnp.array([ + jnp.argmin( + jnp.abs(obs_time - (cur_time + self._model_timesteps)) + ) for obs_time in cur_obs_times + ]), + lambda: jnp.array(self.obs_window_indices) + ) + analysis, loss_vals = self.step_cycle( vector.StateVector(values=cur_state_vals, store_as_jax=True), vector.ObsVector(values=cur_obs_vals, location_indices=cur_obs_loc_indices, error_sd=obs_error_sd, store_as_jax=True), + obs_time_mask=obs_time_mask, + obs_loc_mask=cur_obs_loc_mask, n_steps=self.steps_per_window, - obs_window_indices=self.obs_window_indices) + obs_window_indices=obs_window_indices) + new_time = cur_time + self.analysis_window - return 
analysis.values[-1], (analysis.values[:-1], loss_vals) + return (analysis.values[-1], new_time), (analysis.values[:-1], loss_vals) def cycle(self, input_state, start_time, obs_vector, obs_error_sd, - timesteps, + n_cycles, analysis_window, - analysis_time_in_window=None): + analysis_time_in_window=0, + return_forecast=False): """Perform DA cycle repeatedly, including analysis and forecast Args: input_state (vector.StateVector): Input state. + start_time (float or datetime-like): Starting time. obs_vector (vector.ObsVector): Observations vector. obs_error_sd (float): Standard deviation of observation error. Typically not known, so provide a best-guess. - start_time (float or datetime-like): Starting time. - timesteps (int): Number of timesteps, in model time. - analysis_window (float): Time window from which to gather - observations for DA Cycle. - analysis_time_in_window (float): Where within analysis_window + n_cycles (int): Number of analysis cycles to run, each of length + analysis_window. + analysis_window (float): Length of time window from which to gather + observations for each DA Cycle, in model time units. + analysis_time_in_window (float): At what time within analysis_window to perform analysis. For example, 0.0 is the start of the - window. Default is None, which selects the middle of the - window. + window. Default is 0, the start of the window. + return_forecast (bool): If True, returns forecast at each model + timestep. If False, returns only analyses, one per analysis + cycle. Default is False. Returns: vector.StateVector of analyses and times. """ - + if (not obs_vector.stationary_observers and + (self.H is not None or self.h is not None)): + warnings.warn( + "Provided obs vector has nonstationary observers. When" + " providing a custom obs operator (H/h), the Var4DBackprop" + "DA cycler may not function properly. 
If you encounter " + "errors, try again with an observer where" + "stationary_observers=True or without specifying H or h (a " + "default H matrix will be used to map observations to system " + "space)." + ) + self.analysis_window = analysis_window + + # If don't specify analysis_time_in_window, is assumed to be middle if analysis_time_in_window is None: - analysis_time_in_window = analysis_window/2 + analysis_time_in_window = self.analysis_window/2 + + # Time offset from middle of time window, for gathering observations + _time_offset = (analysis_window/2) - analysis_time_in_window # Set up for jax.lax.scan, which is very fast - all_times = ( - jnp.repeat(start_time + analysis_time_in_window, timesteps) - + jnp.arange(0, timesteps*analysis_window, - analysis_window) - ) + all_times = dac_utils._get_all_times(start_time, analysis_window, + n_cycles) + + if self.steps_per_window is None: + self.steps_per_window = round(analysis_window/self.delta_t) + 1 + self._model_timesteps = jnp.arange(self.steps_per_window)*self.delta_t + # Get the obs vectors for each analysis window - all_filtered_idx = jnp.stack([jnp.where( - # Greater than start of window - (obs_vector.times > cur_time - analysis_window/2) - # AND Less than end of window - * (obs_vector.times < cur_time + analysis_window/2) - # OR Equal to start of window end - + jnp.isclose(obs_vector.times, cur_time - analysis_window/2, - rtol=0) - # OR Equal to end of window - + jnp.isclose(obs_vector.times, cur_time + analysis_window/2, - rtol=0) - )[0] for cur_time in all_times]) + all_filtered_idx = dac_utils._get_obs_indices( + obs_times=obs_vector.times, + analysis_times=all_times+_time_offset, + start_inclusive=True, + end_inclusive=True, + analysis_window=analysis_window + ) + + all_filtered_padded = dac_utils._pad_time_indices(all_filtered_idx) self._obs_vector = obs_vector self._obs_error_sd = obs_error_sd - cur_state, all_values = jax.lax.scan( + + # Padding observations + if obs_vector.stationary_observers: + 
self._obs_loc_masks = jnp.ones(obs_vector.values.shape, dtype=bool) + else: + obs_vals, obs_locs, obs_loc_masks = dac_utils._pad_obs_locs( + obs_vector) + self._obs_vector.values = obs_vals + self._obs_vector.location_indices = obs_locs + self._obs_loc_masks = jnp.array(obs_loc_masks) + + cur_state, all_results = jax.lax.scan( self._cycle_and_forecast, - init=input_state.values, - xs=all_filtered_idx) - all_losses = all_values[1] - print(all_losses[:, -3:]) - all_values = all_values[0] - - return vector.StateVector( - values=jnp.vstack(all_values), - store_as_jax=True) + init=(input_state.values, start_time), + xs=all_filtered_padded) + self.loss_values = all_results[1] + all_values = all_results[0] + + if return_forecast: + all_times_forecast = jnp.arange( + 0, + n_cycles*analysis_window, + self.delta_t + ) + start_time + return vector.StateVector(values=jnp.concatenate(all_values), + times=all_times_forecast) + else: + return vector.StateVector(values=jnp.vstack([ + forecast[0] for forecast in all_values] + ), + times=all_times) diff --git a/dabench/observer/_observer.py b/dabench/observer/_observer.py index b1fedb7..80998e6 100644 --- a/dabench/observer/_observer.py +++ b/dabench/observer/_observer.py @@ -3,6 +3,8 @@ Input is generated data, returns ObsVector with values, times, coords, etc """ +import warnings + import numpy as np import jax.numpy as jnp @@ -101,7 +103,17 @@ def __init__(self, self.stationary_observers = stationary_observers self.random_seed = random_seed - self.store_as_jax = store_as_jax + if (store_as_jax and self.random_location_density != 1. and + not self.stationary_observers): + warnings.warn( + "store_as_jax=True is not compatible with irregular" + "observations (i.e. stationary_observers = False AND" + "random_location_density != 1. 
Setting store_ax_jax" + " to False and proceeding.") + self.store_as_jax = False + else: + self.store_as_jax = store_as_jax + self.error_bias = error_bias self.error_sd = error_sd @@ -413,4 +425,5 @@ def observe(self): error_sd=self.error_sd, error_bias=self.error_bias, store_as_jax=self.store_as_jax, + stationary_observers=self.stationary_observers ) diff --git a/dabench/vector/_obs_vector.py b/dabench/vector/_obs_vector.py index 6b2a060..a228f39 100644 --- a/dabench/vector/_obs_vector.py +++ b/dabench/vector/_obs_vector.py @@ -37,6 +37,10 @@ class ObsVector(_vector._Vector): times (array): 1d array of times associated with each observation store_as_jax (bool): Store values as jax array instead of numpy array. Default is False (store as numpy). + stationary_observers (bool): If True, samples are from same indices at + each time step. If False, observations can be from different + indices, including irregular numbers of observations and different + time steps. """ def __init__(self, num_obs=None, @@ -51,6 +55,7 @@ def __init__(self, error_bias=None, times=None, store_as_jax=False, + stationary_observers=True, **kwargs): self.num_obs = num_obs @@ -59,6 +64,7 @@ def __init__(self, self.error_bias = error_bias self.time_indices = time_indices self.location_indices = location_indices + self.stationary_observers = stationary_observers super().__init__(times=times, store_as_jax=store_as_jax, diff --git a/tests/dacycler_base_test.py b/tests/dacycler_base_test.py new file mode 100644 index 0000000..f55a4bb --- /dev/null +++ b/tests/dacycler_base_test.py @@ -0,0 +1,19 @@ +"""Tests for base Data Assimilation Cycler class (dabench.dacycler._dacycler)""" + +import pytest +import dabench as dab + + +def test_dacycler_init(): + """Tests initialization of dacycler""" + + params = {'system_dim': 6, + 'delta_t': 0.5, + 'ensemble': True} + + test_dac = dab.dacycler.DACycler(**params) + + assert test_dac.system_dim == 6 + assert test_dac.delta_t == 0.5 + assert test_dac.ensemble + 
assert not test_dac.in_4d diff --git a/tests/dacycler_etkf_test.py b/tests/dacycler_etkf_test.py new file mode 100644 index 0000000..d9852fb --- /dev/null +++ b/tests/dacycler_etkf_test.py @@ -0,0 +1,106 @@ +"""Tests for ETKF Data Assimilation Cycler (dabench.dacycler._etkf)""" + +import pytest +import numpy as np +import jax.numpy as jnp +import jax.random as jrand +import dabench as dab + + +key = jrand.PRNGKey(42) + + +@pytest.fixture +def lorenz96(): + """Defines class Lorenz96 object for rest of tests.""" + l96 = dab.data.Lorenz96(system_dim=5, store_as_jax=True, delta_t=0.01) + l96.generate(n_steps=25) + + return l96 + +@pytest.fixture +def obs_vec_l96(lorenz96): + """Generate observations for rest of tests.""" + obs_l96 = dab.observer.Observer( + lorenz96, + time_indices=np.arange(0, 25, 5), + random_location_count=3, + error_bias=0.1, + error_sd=1.0, + random_seed=91, + stationary_observers=True, + store_as_jax=True + ) + + return obs_l96.observe() + +@pytest.fixture +def l96_fc_model(): + model_l96 = dab.data.Lorenz96(system_dim=5, store_as_jax=True) + + class L96Model(dab.model.Model): + """Defines model wrapper for Lorenz96 to test forecasting.""" + def forecast(self, state_vec, n_steps): + self.model_obj.generate(x0=state_vec.values, n_steps=n_steps) + new_vals = self.model_obj.values[:n_steps] + + new_vec = dab.vector.StateVector(values=new_vals, + store_as_jax=True) + + return new_vec + + return L96Model(model_obj=model_l96) + +@pytest.fixture +def etkf_cycler(l96_fc_model): + dc = dab.dacycler.ETKF( + system_dim=5, + delta_t=0.01, + ensemble_dim=8, + model_obj=l96_fc_model) + + return dc + +def test_etkf_l96(lorenz96, obs_vec_l96, etkf_cycler): + + cur_tstep=10 + init_noise = jrand.normal(key, shape=(8, 5)) + init_state = dab.vector.StateVector( + values=lorenz96.values[cur_tstep] + init_noise, + store_as_jax=True) + start_time = lorenz96.times[cur_tstep] + + out_sv = etkf_cycler.cycle( + input_state=init_state, + start_time=start_time, + 
obs_vector=obs_vec_l96, + obs_error_sd=1.5, + analysis_window=0.1, + n_cycles=10, + return_forecast=True + ) + + out_sv_mean = np.mean(out_sv.values, axis=1) + + assert out_sv.values.shape == (100, 8, 5) + assert out_sv_mean.shape == (100, 5) + # Check that ensemble members are different + assert not jnp.allclose( + out_sv.values[-1, 1, :], + out_sv.values[-1, 0, :], + ) + # Check first cycle against presaved results + assert jnp.allclose( + out_sv.values[0, 0, :], + jnp.array([-0.85402591, 1.03480315, 0.51005132, 6.61546551, 8.1166806]) + ) + # Check last cycle against presaved results + assert jnp.allclose( + out_sv.values[-1, 0, :], + jnp.array([0.66697948, 3.15465627, 5.39288975, -4.96130847, 2.17202611]) + ) + # Check mean against presaved results + assert jnp.allclose( + out_sv_mean[-1, :], + jnp.array([1.45024252, 3.81627191, 5.4507981, 1.21646539, 0.09439264]) + ) diff --git a/tests/dacycler_var3d_test.py b/tests/dacycler_var3d_test.py new file mode 100644 index 0000000..24ae1ff --- /dev/null +++ b/tests/dacycler_var3d_test.py @@ -0,0 +1,89 @@ +"""Tests for Var3D Data Assimilation Cycler (dabench.dacycler._var3d)""" + +import pytest +import jax.numpy as jnp +import jax.random as jrand +import dabench as dab + + +key = jrand.PRNGKey(42) + + +@pytest.fixture +def lorenz96(): + """Defines class Lorenz96 object for rest of tests.""" + l96 = dab.data.Lorenz96(system_dim=6, store_as_jax=True) + l96.generate(n_steps=50) + + return l96 + +@pytest.fixture +def obs_vec_l96(lorenz96): + """Generate observations for rest of tests.""" + obs_l96 = dab.observer.Observer( + lorenz96, + random_time_density = 0.7, + random_location_count = 3, + error_bias = 0.0, + error_sd = 0.7, + random_seed=94, + stationary_observers=True + ) + + return obs_l96.observe() + +@pytest.fixture +def l96_fc_model(): + model_l96 = dab.data.Lorenz96(system_dim=6) + + class L96Model(dab.model.Model): + """Defines model wrapper for Lorenz96 to test forecasting.""" + def forecast(self, state_vec, 
n_steps): + self.model_obj.generate(x0=state_vec.values, n_steps=n_steps) + new_vals = self.model_obj.values[:n_steps] + + new_vec = dab.vector.StateVector(values=new_vals, store_as_jax=True) + + return new_vec + + return L96Model(model_obj=model_l96) + +@pytest.fixture +def var3d_cycler(l96_fc_model): + dc = dab.dacycler.Var3D( + system_dim=6, + delta_t=0.05, + model_obj=l96_fc_model) + + return dc + +def test_var3d_l96(lorenz96, obs_vec_l96, var3d_cycler): + + # Adding some noise to our initial state and getting the start time in model units + init_noise = jrand.normal(key, shape=(6,)) + init_state = dab.vector.StateVector( + values=lorenz96.values[0] + init_noise, + store_as_jax=True) + start_time = lorenz96.times[0] + + # To run the experiment, we use the cycle() method: + out_sv = var3d_cycler.cycle( + input_state = init_state, + start_time = start_time, + obs_vector = obs_vec_l96, + n_cycles=10, + analysis_window=0.25, + return_forecast=False) + + assert out_sv.values.shape == (10, 6) + assert jnp.allclose( + out_sv.values[0], + # Presaved results + jnp.array([-0.90632236, -1.20861455, 1.64865068, + 5.11034063, 4.399881, -3.75779771]) + ) + assert jnp.allclose( + out_sv.values[-1], + jnp.array([3.92060079, 3.97290102, -0.763032, + -1.5979558, -0.0086728, 2.60395146]) + ) diff --git a/tests/dacycler_var4d_var4dbp_test.py b/tests/dacycler_var4d_var4dbp_test.py new file mode 100644 index 0000000..861eeca --- /dev/null +++ b/tests/dacycler_var4d_var4dbp_test.py @@ -0,0 +1,157 @@ +"""Tests for Var4D and Var4D-Backprop Data Assimilation Cyclers""" + +import pytest +import jax.numpy as jnp +import jax.random as jrand +import dabench as dab + + +key = jrand.PRNGKey(42) + + +@pytest.fixture +def lorenz96(): + """Defines class Lorenz96 object for rest of tests.""" + l96 = dab.data.Lorenz96(system_dim=6, store_as_jax=True, delta_t=0.01) + l96.generate(n_steps=120) + + return l96 + +@pytest.fixture +def obs_vec_l96(lorenz96): + """Generate observations for rest of 
tests.""" + obs_l96 = dab.observer.Observer( + lorenz96, + time_indices=jnp.arange(0, 120, 5), + random_location_count = 3, + error_bias = 0.0, + error_sd = 0.3, + random_seed=94, + stationary_observers=True, + store_as_jax=True + ) + + return obs_l96.observe() + +@pytest.fixture +def l96_fc_model(): + model_l96 = dab.data.Lorenz96(system_dim=6, store_as_jax=True) + + class L96Model(dab.model.Model): + """Defines model wrapper for Lorenz96 to test forecasting.""" + def forecast(self, state_vec, n_steps): + # NOTE: n_steps = 2 because the initial state counts as a "step" + self.model_obj.generate(x0=state_vec.values, n_steps=n_steps) + new_vals = self.model_obj.values + + new_vec = dab.vector.StateVector(values=new_vals, store_as_jax=True) + + return new_vec + def compute_tlm(self, state_vec, n_steps): + """Compute TLM""" + M = self.model_obj.generate(n_steps=n_steps, x0=state_vec.values, + return_tlm=True) + return M, self.model_obj.values + + return L96Model(model_obj=model_l96) + +@pytest.fixture +def var4d_cycler(l96_fc_model): + dc = dab.dacycler.Var4D( + system_dim=6, + delta_t=0.01, + model_obj=l96_fc_model, + obs_window_indices=[0,5, 10], + steps_per_window=11 + ) + + return dc + + +@pytest.fixture +def var4d_backprop_cycler(l96_fc_model): + B = jnp.identity(6)*0.05 + dc = dab.dacycler.Var4DBackprop( + system_dim=6, + delta_t=0.01, + model_obj=l96_fc_model, + obs_window_indices=[0,5, 10], + steps_per_window=11, + learning_rate=0.1, + lr_decay=0.5, + B=B + ) + + return dc + +def test_var4d_l96(lorenz96, obs_vec_l96, var4d_cycler): + """Test 4D-Var cycler""" + init_noise = jrand.normal(key, shape=(6,)) + init_state = dab.vector.StateVector( + values=lorenz96.values[0] + init_noise, + store_as_jax=True) + start_time = lorenz96.times[0] + + out_sv = var4d_cycler.cycle( + input_state = init_state, + start_time = start_time, + obs_vector = obs_vec_l96, + obs_error_sd=obs_vec_l96.error_sd*1.5, + n_cycles=10, + analysis_window=0.1, + return_forecast=True) + + assert 
out_sv.values.shape == (100, 6) + + # Check that timeseries is evolving + assert not jnp.allclose( + out_sv.values[0,:], + out_sv.values[5,:], + ) + # Check against presaved results + assert jnp.allclose( + out_sv.values[0,:], + jnp.array([4.5784335 , 10.70937771, 3.97859892, 0.25609285, -1.89681598, + -1.34747704]) + ) + assert jnp.allclose( + out_sv.values[-1,:], + jnp.array([-2.32350141, 2.66564733, 9.1592932 , 0.26887161, -2.72295144, + 1.24513147]) + ) + +def test_var4d_backprop_l96(lorenz96, obs_vec_l96, var4d_backprop_cycler): + """Test 4DVar-Backprop cycler""" + init_noise = jrand.normal(key, shape=(6,)) + init_state = dab.vector.StateVector( + values=lorenz96.values[0] + init_noise, + store_as_jax=True) + start_time = lorenz96.times[0] + + out_sv = var4d_backprop_cycler.cycle( + input_state = init_state, + start_time = start_time, + obs_vector = obs_vec_l96, + obs_error_sd=obs_vec_l96.error_sd*1.5, + n_cycles=10, + analysis_window=0.1, + return_forecast=True) + + assert out_sv.values.shape == (100, 6) + + # Check that timeseries is evolving + assert not jnp.allclose( + out_sv.values[0,:], + out_sv.values[5,:], + ) + # Check against presaved results + assert jnp.allclose( + out_sv.values[0,:], + jnp.array([4.53548062, 9.00637144, 3.07940726, 3.24252952, -2.77042587, + -2.01121753]) + ) + assert jnp.allclose( + out_sv.values[-1,:], + jnp.array([3.91514756, 6.5823489 , -1.60393758, -2.85701674, -0.5386405, + 0.11637277]) + ) diff --git a/tests/data_barotropic_test.py b/tests/data_barotropic_test.py index 01da6ec..82f611d 100644 --- a/tests/data_barotropic_test.py +++ b/tests/data_barotropic_test.py @@ -1,6 +1,5 @@ """Tests for Barotropic class (dabench.data.barotropic)""" -from dabench.data import Barotropic import numpy as np import pytest @@ -11,7 +10,6 @@ from dabench.data import Barotropic - @pytest.fixture(scope='module') def barotropic(): """Defines class Barotropic object for rest of tests.""" diff --git a/tests/data_pyqg_test.py 
b/tests/data_pyqg_test.py index 57ef354..4cd97ab 100644 --- a/tests/data_pyqg_test.py +++ b/tests/data_pyqg_test.py @@ -10,7 +10,6 @@ from dabench.data import PyQG - @pytest.fixture(scope='module') def pyqg(): """Defines class PYQ object for rest of tests."""