Skip to content

Commit

Permalink
Can specify per-variable SD and Bias in observer
Browse files Browse the repository at this point in the history
  • Loading branch information
kysolvik committed Dec 21, 2023
1 parent f652ed8 commit f82e8ca
Showing 1 changed file with 93 additions and 17 deletions.
110 changes: 93 additions & 17 deletions dabench/observer/_observer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""

import numpy as np
import jax.numpy as jnp

from dabench.vector import ObsVector

Expand All @@ -27,7 +28,7 @@ class Observer():
random_location_density.
random_time_count (int): Number of times to randomly select for
observing. Default is None. User should specify one of:
random_time_count, random_time_density, or time_indices.
random_time_count, random_time_density, or time_indices.
If random_time_count is specified, it takes precedence over
random_time_density.
location_indices (ndarray): Manually specified indices for observing.
Expand All @@ -40,21 +41,25 @@ class Observer():
time_indices (ndarray): Indices of times to gather observations from.
If not specified, randomly generate according to
random_time_density OR random_time_count. Default is None.
stationary_observers (bool): If True, samples from same indices at
each time step. If False, randomly generates/expects new
stationary_observers (bool): If True, samples from same indices at
each time step. If False, randomly generates/expects new
observation indices at each timestep. Default is True.
If False:
If using random_location_count, the same number of indices
will be randomly generated.
will be randomly generated.
If using random_location_density, indices are randomly
generated, with the possibility of a different number
of locations at each time step.
If using location_indices, expects indices to either be 2D
(time_dim, system_dim) or >2D (time_dim, original_dim).
error_bias (float): Mean of normal distribution of observation errors.
error_bias (float or array): Mean of normal distribution of
observation errors. If provided as an array, it is taken to be
variable-specific and the length must be equal to
data_obj.system_dim. Default is 0.
error_sd (float or array): Standard deviation of observation errors.
observation errors. If provided as an array, it is taken to be
variable-specific and the length must be equal to data_obj.system_dim.
Default is 0.
error_sd (float): Standard deviation of observation errors. Default is
0.
error_positive_only (bool): Clip errors to be positive only. Default is
False.
random_seed (int): Random seed for sampling times and locations.
Expand Down Expand Up @@ -95,11 +100,50 @@ def __init__(self,
self.random_location_count = random_location_count
self.stationary_observers = stationary_observers

self.random_seed = random_seed
self.store_as_jax = store_as_jax

self.error_bias = error_bias
self.error_sd = error_sd
if isinstance(self.error_bias, (list, np.ndarray, jnp.ndarray)):
if len(self.error_bias) == 1:
self._error_bias_is_list = False
elif not len(self.error_bias) == self.data_obj.system_dim:
raise ValueError(
"List of error biases has length {}."
"Must match either system_dim ({}) or "
"number of location indices ({})".format(
len(self.error_bias), self.data_obj.system_dim,
self.location_indices.shape[0]))
elif isinstance(self.error_bias, list):
if self.store_as_jax:
self.error_bias = jnp.array(self.error_bias)
else:
self.error_bias = np.array(self.error_bias)
self._error_bias_is_list = True
else:
self._error_bias_is_list = False

if isinstance(self.error_sd, (list, np.ndarray, jnp.ndarray)):
if len(self.error_sd) == 1:
self._error_sd_is_list = False
elif not len(self.error_sd) == self.data_obj.system_dim:
raise ValueError(
"List of error sds has length {}."
"Must match either system_dim ({}) or "
"number of location indices ({})".format(
len(self.error_sd), self.data_obj.system_dim,
self.location_indices.shape[0]))
elif isinstance(self.error_sd, list):
if self.store_as_jax:
self.error_sd = jnp.array(self.error_sd)
else:
self.error_sd = np.array(self.error_sd)
self._error_sd_is_list = True
else:
self._error_sd_is_list = False

self.error_positive_only = error_positive_only
self.random_seed = random_seed
self.store_as_jax = store_as_jax

def _generate_time_indices(self, rng):
if self.random_time_count is not None:
Expand Down Expand Up @@ -262,8 +306,16 @@ def observe(self):
self.location_dim = np.repeat(self.location_indices.shape[0],
self.time_dim)
errors_vec_size = (self.time_dim,) + (self.location_dim[0],)
errors_vector = rng.normal(loc=self.error_bias,
scale=self.error_sd,
if self._error_bias_is_list:
error_bias = self.error_bias[self.location_indices]
else:
error_bias = self.error_bias
if self._error_sd_is_list:
error_sd = self.error_sd[self.location_indices]
else:
error_sd = self.error_sd
errors_vector = rng.normal(loc=error_bias,
scale=error_sd,
size=errors_vec_size)

# Clip errors to positive only
Expand Down Expand Up @@ -307,12 +359,36 @@ def observe(self):
self.location_indices])

# Generate errors
errors_vector = np.array([
rng.normal(
loc=self.error_bias,
scale=self.error_sd,
size=ld)
for ld in self.location_dim], dtype=object)
if self._error_bias_is_list:
if self._error_sd_is_list:
errors_vector = np.array([
rng.normal(
loc=self.error_bias[ld],
scale=self.error_sd[ld],
size=ld)
for ld in self.location_dim], dtype=object)
else:
errors_vector = np.array([
rng.normal(
loc=self.error_bias[ld],
scale=self.error_sd,
size=ld)
for ld in self.location_dim], dtype=object)
else:
if self._error_sd_is_list:
errors_vector = np.array([
rng.normal(
loc=self.error_bias,
scale=self.error_sd[ld],
size=ld)
for ld in self.location_dim], dtype=object)
else:
errors_vector = np.array([
rng.normal(
loc=self.error_bias,
scale=self.error_sd,
size=ld)
for ld in self.location_dim], dtype=object)

if self.error_positive_only:
errors_vector = np.array([
Expand Down

0 comments on commit f82e8ca

Please sign in to comment.