From f82e8ca501c281e8be4f5111f21589ae775483cc Mon Sep 17 00:00:00 2001 From: Kylen Solvik Date: Thu, 21 Dec 2023 16:10:06 -0700 Subject: [PATCH] Can specify per-variable SD and Bias in observer --- dabench/observer/_observer.py | 110 ++++++++++++++++++++++++++++------ 1 file changed, 93 insertions(+), 17 deletions(-) diff --git a/dabench/observer/_observer.py b/dabench/observer/_observer.py index 19a347c..b1fedb7 100644 --- a/dabench/observer/_observer.py +++ b/dabench/observer/_observer.py @@ -4,6 +4,7 @@ """ import numpy as np +import jax.numpy as jnp from dabench.vector import ObsVector @@ -27,7 +28,7 @@ class Observer(): random_location_density. random_time_count (int): Number of times to randomly select for observing. Default is None. User should specify one of: - random_time_count, random_time_density, or time_indices. + random_time_count, random_time_density, or time_indices. If random_time_count is specified, it takes precedent over random_time_density. location_indices (ndarray): Manually specified indices for observing. @@ -40,21 +41,25 @@ class Observer(): time_indices (ndarray): Indices of times to gather observations from. If not specified, randomly generate according to random_time_density OR random_time_count. Default is None. - stationary_observers (bool): If True, samples from same indices at - each time step. If False, randomly generates/expects new + stationary_observers (bool): If True, samples from same indices at + each time step. If False, randomly generates/expects new observation indices at each timestep. Default is True. If False: If using random_location_count, the same number of indices - will be randomly generated. + will be randomly generated. If using random_location_density, indices are randomly generated, with the possibility of a different number of locations at each times step.. If using location_indices, expects indices to either be 2D (time_dim, system_dim) or >2D (time_dim, original_dim). - error_bias (float): Mean of normal distribution of observation errors. + error_bias (float or array): Mean of normal distribution of + observation errors. If provided as an array, it is taken to be + variable-specific and the length must be equal to + data_obj.system_dim. Default is 0. + error_sd (float or array): Standard deviation of observation errors. + observation errors. If provided as an array, it is taken to be + variable-specific and the length be equal to data_obj.system_dim. Default is 0. - error_sd (float): Standard deviation of observation errors. Default is - 0. error_positive_only (bool): Clip errors to be positive only. Default is False. random_seed (int): Random seed for sampling times and locations. @@ -95,11 +100,50 @@ def __init__(self, self.random_location_count = random_location_count self.stationary_observers = stationary_observers + self.random_seed = random_seed + self.store_as_jax = store_as_jax + self.error_bias = error_bias self.error_sd = error_sd + if isinstance(self.error_bias, (list, np.ndarray, jnp.ndarray)): + if len(self.error_bias) == 1: + self._error_bias_is_list = False + elif not len(self.error_bias) == self.data_obj.system_dim: + raise ValueError( + "List of error biases has length {}." + "Must match either system_dim ({}) or " + "number of location indices ({})".format( + len(self.error_bias), self.data_obj.system_dim, + self.location_indices.shape[0])) + elif isinstance(self.error_bias, list): + if self.store_as_jax: + self.error_bias = jnp.array(self.error_bias) + else: + self.error_bias = np.array(self.error_bias) + self._error_bias_is_list = True + else: + self._error_bias_is_list = False + + if isinstance(self.error_sd, (list, np.ndarray, jnp.ndarray)): + if len(self.error_sd) == 1: + self._error_sd_is_list = False + elif not len(self.error_sd) == self.data_obj.system_dim: + raise ValueError( + "List of error sds has length {}." + "Must match either system_dim ({}) or " + "number of location indices ({})".format( + len(self.error_sd), self.data_obj.system_dim, + self.location_indices.shape[0])) + elif isinstance(self.error_sd, list): + if self.store_as_jax: + self.error_sd = jnp.array(self.error_sd) + else: + self.error_sd = np.array(self.error_sd) + self._error_sd_is_list = True + else: + self._error_sd_is_list = False + self.error_positive_only = error_positive_only - self.random_seed = random_seed - self.store_as_jax = store_as_jax def _generate_time_indices(self, rng): if self.random_time_count is not None: @@ -262,8 +306,16 @@ def observe(self): self.location_dim = np.repeat(self.location_indices.shape[0], self.time_dim) errors_vec_size = (self.time_dim,) + (self.location_dim[0],) - errors_vector = rng.normal(loc=self.error_bias, - scale=self.error_sd, + if self._error_bias_is_list: + error_bias = self.error_bias[self.location_indices] + else: + error_bias = self.error_bias + if self._error_sd_is_list: + error_sd = self.error_sd[self.location_indices] + else: + error_sd = self.error_sd + errors_vector = rng.normal(loc=error_bias, + scale=error_sd, size=errors_vec_size) # Clip errors to positive only @@ -307,12 +359,36 @@ def observe(self): self.location_indices]) # Generate errors - errors_vector = np.array([ - rng.normal( - loc=self.error_bias, - scale=self.error_sd, - size=ld) - for ld in self.location_dim], dtype=object) + if self._error_bias_is_list: + if self._error_sd_is_list: + errors_vector = np.array([ + rng.normal( + loc=self.error_bias[ld], + scale=self.error_sd[ld], + size=ld) + for ld in self.location_dim], dtype=object) + else: + errors_vector = np.array([ + rng.normal( + loc=self.error_bias[ld], + scale=self.error_sd, + size=ld) + for ld in self.location_dim], dtype=object) + else: + if self._error_sd_is_list: + errors_vector = np.array([ + rng.normal( + loc=self.error_bias, + scale=self.error_sd[ld], + size=ld) + for ld in self.location_dim], dtype=object) + else: + errors_vector = np.array([ + rng.normal( + loc=self.error_bias, + scale=self.error_sd, + size=ld) + for ld in self.location_dim], dtype=object) if self.error_positive_only: errors_vector = np.array([