Irregular obs dacycle #47

Merged on Aug 26, 2024 (65 commits)
Commits
fe8148f
Support for irregular observation timing in ETKF
kysolvik Jul 12, 2024
1a85e5f
Support missing time observations in var4d_backprop
kysolvik Jul 12, 2024
d7bd91a
Moving time filtering functionality to separate module, untested
kysolvik Jul 15, 2024
a88ea5b
Base test for dacycler
kysolvik Jul 15, 2024
6b8cc3f
3DVar L96 tests
kysolvik Jul 15, 2024
f353695
ETKF tests
kysolvik Jul 15, 2024
40e268b
Updated rcond (deprecated) to rtol
kysolvik Jul 15, 2024
f63e560
4D-Var and 4D-Var Backprop tests
kysolvik Jul 15, 2024
d5fb965
Corrected default R dimensions for var4d and var4dBackprop
kysolvik Jul 15, 2024
ccb2521
Resizing needs to be done with numpy, not jax.numpy
kysolvik Jul 15, 2024
48bcaac
Use utils for obs filtering in var4d_backprop
kysolvik Jul 15, 2024
7c1a913
Add AND NOT for start/end_inclusive false in time filtering
kysolvik Jul 15, 2024
5a557d8
Use utils for etkf time filtering
kysolvik Jul 15, 2024
93145a5
Add variable observation time support to var4d
kysolvik Jul 15, 2024
eb33368
Remove padding method from pbackprop (using _utils module instead)
kysolvik Jul 16, 2024
7cb3c2a
Use jax.lax.cond for obs masking
kysolvik Jul 16, 2024
ccab1fd
Use jax.lax.cond for obs masking in backprop
kysolvik Jul 16, 2024
544eeff
Use .at[].get() instead of dynamic slice to fix bug when padded indic…
kysolvik Jul 16, 2024
adc2bf4
Add ability to compute obs_window_indices on the fly
kysolvik Jul 16, 2024
7de6c58
Small corrections to backprop obs_window_indices
kysolvik Jul 16, 2024
c97f4f2
Automatically calculate obs_window_indices for 4dvar, WIP getting ver…
kysolvik Jul 16, 2024
1bc2888
var4d, bp tests with no provided obs_window_indices
kysolvik Jul 17, 2024
8d50749
Don't print loss values but save as attribute, and fix obs_mask out o…
kysolvik Jul 17, 2024
5b7ce6d
Fix out of bounds indexing for obs_mask
kysolvik Jul 17, 2024
9aa95c6
Remove debugging printing for obs_window_indices
kysolvik Jul 17, 2024
8c7deaa
Remove extra indexing leftover from debugging
kysolvik Jul 17, 2024
148698f
Restore testing to using obs_window_indices
kysolvik Jul 17, 2024
38195a3
Working obs_loc_mask for etkf, but crude
kysolvik Aug 5, 2024
e349291
Add stationary observer attribute to obs vector
kysolvik Aug 6, 2024
8a45cd3
Update dacycler for irregular observation locations
kysolvik Aug 6, 2024
2d66570
Add error/warning message for 4dvar variants and nonstationary observers
kysolvik Aug 6, 2024
c9fe631
Backprop with irregular obs, working but failing test
kysolvik Aug 6, 2024
051c6b6
Use initial H for hessian approximation, passing tests
kysolvik Aug 7, 2024
411a34c
Remove unnecessary casting to int
kysolvik Aug 7, 2024
1d8e63f
4DVar with obs_loc_mask, working
kysolvik Aug 12, 2024
e27d259
Clean up 4Dvar BP irregular obs, make custom obs ops a warning instea…
kysolvik Aug 12, 2024
86e874e
Streamline 4DVar
kysolvik Aug 12, 2024
6f440df
Fix obs_helper_mask comment
kysolvik Aug 12, 2024
65ddc4d
Add support for multi-dimensional locations in irregular obs
kysolvik Aug 12, 2024
5ed85e5
Set up H as int array
kysolvik Aug 12, 2024
fca6a7b
Removed commented stationary observers check
kysolvik Aug 12, 2024
31703d8
Added warning for nonstationary observers and store_as_jax=True
kysolvik Aug 12, 2024
e1c05ae
Updated var3d to run over analysis windows (not model timesteps) and …
kysolvik Aug 15, 2024
7645076
Updated ETKF to properly round over analysis windows, not model times…
kysolvik Aug 15, 2024
3712469
Updated 4dvar and backprop to use n_cycles and properly specific obs_…
kysolvik Aug 15, 2024
80b9ec3
ETKF and Var3d updated to match inputs, model prediction length of va…
kysolvik Aug 15, 2024
391623b
Updated dacycler tests
kysolvik Aug 15, 2024
eca2c47
Updated docstrings
kysolvik Aug 15, 2024
7d499e9
Fix calculation of model timesteps for 4dvar and backprop
kysolvik Aug 15, 2024
f998fa6
4Dvar and 4dvarbp require H to be jax array
kysolvik Aug 15, 2024
0d79b4f
Don't alter custom Hs in 4DVar and 4DVar Backprop
kysolvik Aug 16, 2024
51c6615
Fix calling Ht before assignment
kysolvik Aug 16, 2024
2a7234a
Improvements to var4d backprop speed
kysolvik Aug 20, 2024
e174889
Streamline iter looping
kysolvik Aug 20, 2024
f0de50f
Sped up backprop with irregular obs
kysolvik Aug 21, 2024
4fff86c
Workaround for obs_loc_mask, use conditional
kysolvik Aug 23, 2024
167cbb1
Rework obs window indices calc to use conditional
kysolvik Aug 23, 2024
b04a1b3
Faster regular obs case for var4d
kysolvik Aug 23, 2024
be37444
Fix linting errors for var4d_backprop
kysolvik Aug 23, 2024
9a40680
Irregular obs option for exact hessian
kysolvik Aug 23, 2024
f21200d
Updated var4d_backprop LR and decay defaults to 0.5, 0.5
kysolvik Aug 23, 2024
3bbcc1d
Removing pyqg tests
kysolvik Aug 23, 2024
fdbe443
Import warnings for observer
kysolvik Aug 23, 2024
c640cfe
Merge branch 'main' into irregular-obs-dacycle
kysolvik Aug 23, 2024
814cc93
Remove general skip from pyqg test
kysolvik Aug 23, 2024
60 changes: 43 additions & 17 deletions dabench/dacycler/_dacycler.py
@@ -53,52 +53,78 @@ def cycle(self,
               input_state,
               start_time,
               obs_vector,
-              timesteps,
+              n_cycles,
               analysis_window,
-              analysis_time_in_window=None):
+              analysis_time_in_window=None,
+              return_forecast=False):
         """Perform DA cycle repeatedly, including analysis and forecast

         Args:
             input_state (vector.StateVector): Input state.
             start_time (float or datetime-like): Starting time.
             obs_vector (vector.ObsVector): Observations vector.
-            timesteps (int): Number of timesteps, in model time.
+            n_cycles (int): Number of analysis cycles to run, each of length
+                analysis_window.
             analysis_window (float): Time window from which to gather
                 observations for DA Cycle.
             analysis_time_in_window (float): Where within analysis_window
                 to perform analysis. For example, 0.0 is the start of the
                 window. Default is None, which selects the middle of the
                 window.
+            return_forecast (bool): If True, returns forecast at each model
+                timestep. If False, returns only analyses, one per analysis
+                cycle. Default is False.

         Returns:
             vector.StateVector of analyses and times.
         """

         # If don't specify analysis_time_in_window, is assumed to be middle
         if analysis_time_in_window is None:
             analysis_time_in_window = analysis_window/2

+        # Time offset from middle of time window, for gathering observations
+        _time_offset = (analysis_window/2) - analysis_time_in_window
+
+        # Number of model steps to run per window
+        steps_per_window = round(analysis_window/self.delta_t) + 1
+
         # For storing outputs
-        all_analyses = []
+        all_output_states = []
         all_times = []
-        cur_time = start_time + analysis_time_in_window
+        cur_time = start_time
         cur_state = input_state

-        for i in range(timesteps):
-            # 1. Filter observations to plus/minus 0.1 from that time
+        for i in range(n_cycles):
+            # 1. Filter observations to inside analysis window
+            window_middle = cur_time + _time_offset
+            window_start = window_middle - analysis_window/2
+            window_end = window_middle + analysis_window/2
             obs_vec_timefilt = obs_vector.filter_times(
-                cur_time - analysis_window/2, cur_time + analysis_window/2)
+                window_start, window_end
+                )

             if obs_vec_timefilt.values.shape[0] > 0:
                 # 2. Calculate analysis
                 analysis, kh = self._step_cycle(cur_state, obs_vec_timefilt)
-                # 3. Forecast next timestep
-                cur_state = self._step_forecast(analysis)
+                # 3. Forecast through analysis window
+                forecast_states = self._step_forecast(analysis,
+                                                      n_steps=steps_per_window)
                 # 4. Save outputs
-                all_analyses.append(analysis.values)
-                all_times.append(cur_time)
-
-            cur_time += self.delta_t
-
-        return vector.StateVector(values=np.stack(all_analyses),
-                                  times=np.array(all_times))
+                if return_forecast:
+                    # Append forecast to current state, excluding last step
+                    all_output_states.append(forecast_states.values[:-1])
+                    all_times.append(
+                        np.arange(steps_per_window-1)*self.delta_t + cur_time
+                        )
+                else:
+                    all_output_states.append(analysis.values[np.newaxis])
+                    all_times.append([cur_time])
+
+            # Starting point for next cycle is last step of forecast
+            cur_state = forecast_states[-1]
+            cur_time += analysis_window
+
+        return vector.StateVector(values=np.concatenate(all_output_states),
+                                  times=np.concatenate(all_times))
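
Note on the windowing arithmetic in the updated `cycle`: the time bookkeeping can be checked in isolation. The following is a minimal sketch, not code from this PR. `window_bounds` and `cycle_output_times` are hypothetical helper names, and plain NumPy is assumed. It reproduces only the `_time_offset`, `steps_per_window`, and per-cycle time logic from the diff, with no model or analysis step.

```python
import numpy as np


def window_bounds(cur_time, analysis_window, analysis_time_in_window=None):
    """Return (window_start, window_end) for a cycle whose analysis is at cur_time.

    Hypothetical helper mirroring the arithmetic in the updated cycle().
    """
    # Default: the analysis sits in the middle of the window
    if analysis_time_in_window is None:
        analysis_time_in_window = analysis_window / 2
    # Offset of the window middle relative to the analysis time
    time_offset = analysis_window / 2 - analysis_time_in_window
    window_middle = cur_time + time_offset
    return (window_middle - analysis_window / 2,
            window_middle + analysis_window / 2)


def cycle_output_times(start_time, n_cycles, analysis_window, delta_t,
                       return_forecast=False):
    """Return the output times produced over n_cycles analysis cycles."""
    # +1 so the forecast includes the starting state of the next cycle
    steps_per_window = round(analysis_window / delta_t) + 1
    all_times = []
    cur_time = start_time
    for _ in range(n_cycles):
        if return_forecast:
            # One time per model step, excluding the last step of the window
            all_times.append(
                np.arange(steps_per_window - 1) * delta_t + cur_time)
        else:
            # One time per analysis cycle
            all_times.append([cur_time])
        cur_time += analysis_window
    return np.concatenate(all_times)


# Analysis at the middle of the window (default): window is centered on cur_time
print(window_bounds(1.0, 0.2))
# Analysis at the start of the window: window begins at cur_time
print(window_bounds(1.0, 0.2, analysis_time_in_window=0.0))
# Analysis-only output times versus per-model-step forecast times
print(cycle_output_times(0.0, 3, 0.2, 0.05))
print(cycle_output_times(0.0, 3, 0.2, 0.05, return_forecast=True))
```

With `analysis_time_in_window=0.0` the window starts at the analysis time, so each cycle uses observations from the interval that its own forecast then covers.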

146 changes: 103 additions & 43 deletions dabench/dacycler/_etkf.py
@@ -7,6 +7,7 @@
 from jax.scipy import linalg

 from dabench import dacycler, vector
+import dabench.dacycler._utils as dac_utils


 class ETKF(dacycler.DACycler):
@@ -60,18 +61,20 @@ def __init__(self,
                          ensemble=True,
                          B=B, R=R, H=H, h=h)

-    def _step_cycle(self, xb, yo, H=None, h=None, R=None, B=None):
+    def _step_cycle(self, xb, yo, obs_time_mask, obs_loc_mask,
+                    H=None, h=None, R=None, B=None):
         if H is not None or h is None:
             vals, kh = self._cycle_obsop(
-                xb.values, yo.values, yo.location_indices, yo.error_sd,
-                H, R, B)
+                xb.values, yo.values, yo.location_indices, yo.error_sd, obs_time_mask,
+                obs_loc_mask, H, R, B)
             return vector.StateVector(values=vals, store_as_jax=True), kh
         else:
             return self._cycle_general_obsop(xb, yo, h, R, B)

     def _calc_default_H(self, obs_values, obs_loc_indices):
         H = jnp.zeros((obs_values.flatten().shape[0], self.system_dim))
-        H = H.at[jnp.arange(H.shape[0]), obs_loc_indices.flatten()
+        H = H.at[jnp.arange(H.shape[0]),
+                 obs_loc_indices.flatten()
                  ].set(1)
         return H

@@ -82,6 +85,7 @@ def _calc_default_B(self):
         return jnp.identity(self.system_dim)

     def _cycle_obsop(self, Xbt, obs_values, obs_loc_indices, obs_error_sd,
+                     obs_time_mask, obs_loc_mask,
                      H=None, h=None, R=None, B=None):
         if H is None and h is None:
             if self.H is None:
@@ -107,6 +111,10 @@ def _cycle_obsop(self, Xbt, obs_values, obs_loc_indices, obs_error_sd,
                 'cycle:: model_forecast must have dimension {}x{}').format(
                     self.ensemble_dim, self.system_dim)

+        # Apply obs masks to H
+        H = jnp.where(obs_time_mask, H.T, 0).T
+        H = jnp.where(obs_loc_mask.flatten(), H.T, 0).T
+
         # Analysis cycles over all obs in data_obs
         Xa = self._compute_analysis(Xb=Xbt.T,
                                     y=obs_values,
@@ -117,15 +125,17 @@ def _cycle_obsop(self, Xbt, obs_values, obs_loc_indices, obs_error_sd,

         return Xa.T, 0

-    def _step_forecast(self, xa):
+    def _step_forecast(self, xa, n_steps):
         data_forecast = []
         for i in range(self.ensemble_dim):
             new_vals = self.model_obj.forecast(
-                    vector.StateVector(values=xa.values[i], store_as_jax=True)
+                    vector.StateVector(values=xa.values[i], store_as_jax=True),
+                    n_steps=n_steps
                     ).values
             data_forecast.append(new_vals)

-        return vector.StateVector(values=jnp.stack(data_forecast),
+        out_vals = jnp.moveaxis(jnp.stack(data_forecast), [0,1,2],[1,0,2])
+        return vector.StateVector(values=out_vals,
                                   store_as_jax=True)

     def _apply_obsop(self, Xb, H, h):
@@ -182,11 +192,11 @@ def _compute_analysis(self, Xb, y, H, h, R, rho=1.0, Yb=None):

         # Compute the analysis
         if len(R) > 0:
-            Rinv = jnp.linalg.pinv(R, rcond=1e-15)
+            Rinv = jnp.linalg.pinv(R, rtol=1e-15)

             Pa_ens = jnp.linalg.pinv((ensemble_dim-1)/rho*I
                                      + Yb_pert.T @ Rinv @ Yb_pert,
-                                     rcond=1e-15)
+                                     rtol=1e-15)
             Wa = linalg.sqrtm((ensemble_dim-1) * Pa_ens)
             Wa = Wa.real
         else:
@@ -206,49 +216,63 @@ def _compute_analysis(self, Xb, y, H, h, R, rho=1.0, Yb=None):
         return Xa

     def _cycle_and_forecast(self, state_obs_tuple, filtered_idx):
         # 1. Get data
         cur_state_vals = state_obs_tuple[0]
         obs_vals = state_obs_tuple[1]
         obs_times = state_obs_tuple[2]
         obs_loc_indices = state_obs_tuple[3]
-        obs_error_sd = state_obs_tuple[4]
+        obs_loc_masks = state_obs_tuple[4]
+        obs_error_sd = state_obs_tuple[5]
+
+        # 1-b. Calculate obs_time_mask and restore filtered_idx to original values
+        obs_time_mask = jnp.repeat(filtered_idx > 0, obs_loc_indices.shape[1])
+        filtered_idx = filtered_idx - 1

         # 2. Calculate analysis
-        new_obs_vals = jax.lax.dynamic_slice_in_dim(obs_vals, filtered_idx[0],
-                                                    len(filtered_idx))
-        new_obs_loc_indices = jax.lax.dynamic_slice_in_dim(
-            obs_loc_indices, filtered_idx[0], len(filtered_idx))
+        new_obs_vals = obs_vals[filtered_idx]
+        new_obs_loc_indices = obs_loc_indices[filtered_idx]
+        new_obs_loc_mask = obs_loc_masks[filtered_idx]
         analysis, kh = self._step_cycle(
                 vector.StateVector(values=cur_state_vals, store_as_jax=True),
                 vector.ObsVector(values=new_obs_vals,
                                  location_indices=new_obs_loc_indices,
-                                 error_sd=obs_error_sd, store_as_jax=True)
+                                 error_sd=obs_error_sd, store_as_jax=True),
+                obs_loc_mask=new_obs_loc_mask,
+                obs_time_mask=obs_time_mask
                 )
         # 3. Forecast next timestep
-        cur_state = self._step_forecast(analysis)
+        forecast_states = self._step_forecast(analysis, n_steps=self.steps_per_window)
+        next_state = forecast_states.values[-1]

-        return (cur_state.values, obs_vals, obs_times, obs_loc_indices,
-                obs_error_sd), cur_state.values
+        return (next_state, obs_vals, obs_times, obs_loc_indices,
+                obs_loc_masks, obs_error_sd), forecast_states.values[:-1]

     def cycle(self,
               input_state,
               start_time,
               obs_vector,
-              timesteps,
+              n_cycles,
               obs_error_sd=None,
-              analysis_window=0.2):
+              analysis_window=0.2,
+              analysis_time_in_window=None,
+              return_forecast=False):
         """Perform DA cycle repeatedly, including analysis and forecast

         Args:
             input_state (vector.StateVector): Input state.
             start_time (float or datetime-like): Starting time.
             obs_vector (vector.ObsVector): Observations vector.
-            timesteps (int): Number of timesteps, in model time.
+            n_cycles (int): Number of analysis cycles to run, each of length
+                analysis_window.
             analysis_window (float): Time window from which to gather
                 observations for DA Cycle.
+            analysis_time_in_window (float): Where within analysis_window
+                to perform analysis. For example, 0.0 is the start of the
+                window. Default is None, which selects the middle of the
+                window.
+            return_forecast (bool): If True, returns forecast at each model
+                timestep. If False, returns only analyses, one per analysis
+                cycle. Default is False.

         Returns:
             vector.StateVector of analyses and times.
@@ -257,26 +281,62 @@ def cycle(self,
         if obs_error_sd is None:
             obs_error_sd = obs_vector.error_sd
         self.analysis_window = analysis_window
-        all_times = (jnp.repeat(start_time, timesteps)
-                     + (jnp.arange(0, timesteps)*self.delta_t))
+
+        # If don't specify analysis_time_in_window, is assumed to be middle
+        if analysis_time_in_window is None:
+            analysis_time_in_window = analysis_window/2
+
+        # Steps per window + 1 to include start
+        self.steps_per_window = round(analysis_window/self.delta_t) + 1
+
+        # Time offset from middle of time window, for gathering observations
+        _time_offset = (analysis_window/2) - analysis_time_in_window
+
+        # Set up for jax.lax.scan, which is very fast
+        all_times = dac_utils._get_all_times(
+            start_time,
+            analysis_window,
+            n_cycles)
+
+
         # Get the obs vectors for each analysis window
-        all_filtered_idx = jnp.stack([jnp.where(
-            # Greater than start of window
-            (obs_vector.times > cur_time - analysis_window/2)
-            # AND Less than end of window
-            * (obs_vector.times < cur_time + analysis_window/2)
-            # AND not equal to end of window
-            * (1-jnp.isclose(obs_vector.times, cur_time + analysis_window/2,
-                             rtol=0))
-            # OR Equal to start of window
-            + jnp.isclose(obs_vector.times, cur_time - analysis_window/2,
-                          rtol=0)
-            )[0] for cur_time in all_times])
-        cur_state, all_values = jax.lax.scan(
-            self._cycle_and_forecast,
-            (input_state.values, obs_vector.values, obs_vector.times,
-             obs_vector.location_indices, obs_error_sd),
-            all_filtered_idx)
-
-        return vector.StateVector(values=jnp.stack(all_values),
-                                  times=all_times)
+        all_filtered_idx = dac_utils._get_obs_indices(
+            obs_times=obs_vector.times,
+            analysis_times=all_times+_time_offset,
+            start_inclusive=True,
+            end_inclusive=False,
+            analysis_window=analysis_window
+        )
+
+        all_filtered_padded = dac_utils._pad_time_indices(all_filtered_idx, add_one=True)
+
+        # Padding observations
+        if obs_vector.stationary_observers:
+            obs_loc_masks = jnp.ones(obs_vector.values.shape, dtype=bool)
+            cur_state, all_values = jax.lax.scan(
+                    self._cycle_and_forecast,
+                    (input_state.values, obs_vector.values, obs_vector.times,
+                     obs_vector.location_indices, obs_loc_masks, obs_error_sd),
+                    all_filtered_padded)
+        else:
+            obs_vals, obs_locs, obs_loc_masks = dac_utils._pad_obs_locs(obs_vector)
+            cur_state, all_values = jax.lax.scan(
+                    self._cycle_and_forecast,
+                    (input_state.values, obs_vals, obs_vector.times,
+                     obs_locs, obs_loc_masks, obs_error_sd),
+                    all_filtered_padded)
+
+
+        if return_forecast:
+            all_times_forecast = jnp.arange(
+                    0,
+                    n_cycles*analysis_window,
+                    self.delta_t
+                    ) + start_time
+            return vector.StateVector(values=jnp.concatenate(all_values),
+                                      times=all_times_forecast)
+        else:
+            return vector.StateVector(values=jnp.vstack([
+                    forecast[0][jnp.newaxis] for forecast in all_values]
+                    ),
+                    times=all_times)
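
Note on the padding and masking used for irregular observation times: `jax.lax.scan` needs every analysis window to carry the same number of observation indices, so ragged index lists have to be padded and the padded entries masked out of `H`. The sketch below illustrates that idea under stated assumptions. `pad_time_indices` is a hypothetical stand-in, since the real `dac_utils._pad_time_indices` is not shown in this diff. Its shift-by-one behaviour is inferred from the `filtered_idx > 0` and `filtered_idx - 1` lines in `_cycle_and_forecast`. NumPy is used here for brevity, where the PR code uses `jax.numpy`.

```python
import numpy as np


def pad_time_indices(index_lists, add_one=True):
    """Pad ragged per-window index lists with zeros into one 2-D array.

    Hypothetical stand-in for dac_utils._pad_time_indices. With add_one=True
    every real index is shifted up by one, so that 0 can only mean "padding".
    """
    offset = 1 if add_one else 0
    max_len = max(len(idx) for idx in index_lists)
    padded = np.zeros((len(index_lists), max_len), dtype=int)
    for row, idx in enumerate(index_lists):
        padded[row, :len(idx)] = np.asarray(idx, dtype=int) + offset
    return padded


# Three analysis windows holding 2, 1, and 2 observation times respectively
padded = pad_time_indices([[0, 1], [2], [3, 4]])

# What one scan step would see for the second window
filtered_idx = padded[1]
n_locs = 2  # assumed number of observation locations per observation time

# True for real observations, False for padding (repeated per location)
obs_time_mask = np.repeat(filtered_idx > 0, n_locs)
# Undo the +1 shift; padded entries become -1 and are masked out below
filtered_idx = filtered_idx - 1

# Default-style H: one row per (time, location) pair, a single 1 per row
system_dim = 5
loc_indices = np.array([0, 2, 0, 2])
H = np.zeros((obs_time_mask.shape[0], system_dim))
H[np.arange(H.shape[0]), loc_indices] = 1

# Same masking expression as in _cycle_obsop: zero the rows of padded obs
H_masked = np.where(obs_time_mask, H.T, 0).T

print(padded)
print(obs_time_mask)
print(H_masked)
```

Rows of `H` that belong to padded entries are all zero, so whatever values the `-1` index happens to pick up do not contribute to the observation-space projection.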