Skip to content

Commit

Permalink
Add labeled arrays (#291)
Browse files Browse the repository at this point in the history
* add labeled arrays

* fix typo

* fix tests

* update array interpolate linear
  • Loading branch information
ValentinaHutter authored Nov 13, 2024
1 parent 9a79761 commit c128316
Show file tree
Hide file tree
Showing 3 changed files with 235 additions and 31 deletions.
179 changes: 151 additions & 28 deletions openeo_processes_dask/process_implementations/arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@
ArrayElementNotAvailable,
ArrayElementParameterConflict,
ArrayElementParameterMissing,
ArrayLabelConflict,
ArrayLengthMismatch,
ArrayNotLabeled,
LabelExists,
TooManyDimensions,
)

Expand All @@ -26,11 +30,14 @@
__all__ = [
"array_element",
"array_create",
"array_create_labeled",
"array_modify",
"array_concat",
"array_append",
"array_contains",
"array_find",
"array_find_label",
"array_filter",
"array_labels",
"array_apply",
"array_interpolate_linear",
Expand All @@ -43,6 +50,18 @@
]


def get_labels(data, dimension="labels"):
    """Split ``data`` into its labels and its raw values.

    For an ``xr.DataArray`` the labels are the coordinate values of
    ``dimension``. A one-dimensional DataArray always uses its only
    dimension, regardless of ``dimension``. Lists are converted to NumPy
    arrays; any other input is passed through unchanged.

    Args:
        data: An ``xr.DataArray``, a list, or any other array-like object.
        dimension: Name of the coordinate that carries the labels when
            ``data`` is a DataArray with more than one dimension.

    Returns:
        A ``(labels, values)`` tuple. ``labels`` is an empty list when
        ``data`` is not a DataArray, or when it is a DataArray that has no
        coordinate named ``dimension``.
    """
    if not isinstance(data, xr.DataArray):
        if isinstance(data, list):
            data = np.asarray(data)
        return [], data

    if len(data.dims) == 1:
        dimension = data.dims[0]
    try:
        labels = data[dimension].values
    except KeyError:
        # This helper is called on whatever the array processes receive, so a
        # multi-dimensional DataArray without the requested coordinate must
        # be reported as "unlabeled" rather than crash with a KeyError.
        labels = []
    return labels, data.values


def array_element(
data: ArrayLike,
index: Optional[int] = None,
Expand All @@ -62,7 +81,14 @@ def array_element(
"The process `array_element` only allows that either the `index` or the `labels` parameter is set."
)

if isinstance(data, xr.DataArray):
dim_labels, data = get_labels(data)

if label is not None:
if len(dim_labels) == 0:
raise ArrayNotLabeled(
"The array is not a labeled array, but the `label` parameter is set. Use the `index` instead."
)
if isinstance(label, DateTime):
label = label.to_numpy()
(index,) = np.where(dim_labels == label)
Expand Down Expand Up @@ -105,28 +131,58 @@ def array_create(
return np.tile(data, reps=repeat)


def array_create_labeled(data: ArrayLike, labels: ArrayLike) -> ArrayLike:
    """Build a one-dimensional labeled array from values and their labels.

    Args:
        data: The values; a list is converted to a NumPy array first.
        labels: One label per value, used as coordinates of the ``labels``
            dimension.

    Returns:
        An ``xr.DataArray`` with a single dimension named ``labels``.

    Raises:
        ArrayLengthMismatch: If ``data`` and ``labels`` differ in length.
    """
    values = np.array(data) if isinstance(data, list) else data
    if len(values) != len(labels):
        raise ArrayLengthMismatch(
            "The number of values in the parameters `data` and `labels` don't match."
        )
    return xr.DataArray(values, dims=["labels"], coords={"labels": labels})


def array_modify(
    data: ArrayLike,
    values: ArrayLike,
    index: int,
    length: Optional[int] = 1,
) -> ArrayLike:
    """Replace ``length`` elements of ``data``, starting at ``index``, by ``values``.

    When ``data`` is labeled, its labels are spliced the same way using the
    labels of ``values`` and the result is returned as a labeled array.

    Args:
        data: The array to modify (labeled or unlabeled).
        values: The elements to insert in place of the removed ones.
        index: Position at which the replacement starts.
        length: Number of elements of ``data`` that are dropped.

    Returns:
        The modified array; labeled if ``data`` was labeled.

    Raises:
        ArrayElementNotAvailable: If ``index`` exceeds the length of ``data``.
        ArrayLabelConflict: If ``data`` and ``values`` share a label.
    """
    data_labels, data = get_labels(data)
    new_labels, values = get_labels(values)

    if index > len(data):
        raise ArrayElementNotAvailable(
            "The array can't be modified as the given index is larger than the number of elements in the array."
        )
    shared_labels = np.intersect1d(data_labels, new_labels)
    if len(shared_labels) > 0:
        raise ArrayLabelConflict(
            "At least one label exists in both arrays and the conflict must be resolved before."
        )

    resume_at = index + length

    def _splice(base, insert):
        # Head of `base`, then `insert`, then whatever is left of `base`
        # after the replaced stretch.
        spliced = np.append(base[:index], insert)
        if resume_at < len(base):
            spliced = np.append(spliced, base[resume_at:])
        return spliced

    result = _splice(data, values)
    if len(data_labels) == 0:
        return result
    return array_create_labeled(
        data=result, labels=_splice(data_labels, new_labels)
    )


def array_concat(array1: ArrayLike, array2: ArrayLike) -> ArrayLike:
if isinstance(array1, list):
array1 = np.asarray(array1)
if isinstance(array2, list):
array2 = np.asarray(array2)
labels1, array1 = get_labels(array1)
labels2, array2 = get_labels(array2)

if len(np.intersect1d(labels1, labels2)) > 0:
raise ArrayLabelConflict(
"At least one label exists in both arrays and the conflict must be resolved before."
)

concat = np.concatenate([array1, array2])

Expand All @@ -138,12 +194,28 @@ def array_concat(array1: ArrayLike, array2: ArrayLike) -> ArrayLike:
f"array_concat: different datatypes for array1 ({array1.dtype}) and array2 ({array2.dtype}), cast to {concat.dtype}"
)

if len(labels1) > 0 and len(labels2) > 0:
labels = np.concatenate([labels1, labels2])
concat = array_create_labeled(data=concat, labels=labels)
return concat


def array_append(data: ArrayLike, value: Any, label: Optional[Any] = None) -> ArrayLike:
def array_append(
data: ArrayLike,
value: Any,
label: Optional[Any] = None,
dim_labels=None,
) -> ArrayLike:
if dim_labels:
data = array_create_labeled(data=data, labels=dim_labels)
if label is not None:
raise NotImplementedError("labelled arrays are currently not implemented.")
labels, _ = get_labels(data)
if label in labels:
raise LabelExists(
"An array element with the specified label already exists."
)
value = array_create_labeled(data=[value], labels=[label])
return array_concat(data, value)

if (
not isinstance(value, list)
Expand All @@ -158,7 +230,7 @@ def array_append(data: ArrayLike, value: Any, label: Optional[Any] = None) -> Ar
def array_contains(data: ArrayLike, value: Any, axis=None) -> bool:
# TODO: Contrary to the process spec, our implementation does interpret temporal strings before checking them here
# This is somewhat implicit in how we currently parse parameters, so cannot be easily changed.

labels, data = get_labels(data)
value_is_valid = False
valid_dtypes = [np.number, np.bool_, np.str_]
for dtype in valid_dtypes:
Expand All @@ -178,8 +250,7 @@ def array_find(
reverse: Optional[bool] = False,
axis: Optional[int] = None,
) -> np.number:
if isinstance(data, list):
data = np.asarray(data)
labels, data = get_labels(data)

if reverse:
data = np.flip(data, axis=axis)
Expand Down Expand Up @@ -208,20 +279,60 @@ def array_find(
return masked_idxs


def array_labels(data: ArrayLike) -> ArrayLike:
logger.warning(
"Labelled arrays are currently not supported, array_labels will only return indices."
)
if isinstance(data, list):
data = np.asarray(data)
if len(data.shape) > 1:
def array_find_label(data: ArrayLike, label: Union[str, int, float], dim_labels=None):
    """Look up ``label`` among the labels of ``data``.

    Args:
        data: The array whose labels are searched when ``dim_labels`` is not
            supplied.
        label: The label to search for.
        dim_labels: Optional explicit labels; when given and non-empty they
            are used instead of the labels attached to ``data``.

    Returns:
        The result of ``array_find(labels, label)``, or ``None`` when there
        are no labels to search.
    """
    # `dim_labels` may be a NumPy array, whose truth value is ambiguous for
    # more than one element, so test presence and length explicitly.
    if dim_labels is not None and len(dim_labels) > 0:
        labels = dim_labels
    else:
        labels, _ = get_labels(data)
    if len(labels) == 0:
        return None
    return array_find(labels, label)


def array_filter(
    data: ArrayLike, condition: Callable, context: Optional[Any] = None
) -> ArrayLike:
    """Keep only the elements of ``data`` for which ``condition`` is truthy.

    ``condition`` is applied element-wise through ``np.vectorize`` and is
    called with the keyword arguments ``positional_parameters`` and
    ``named_parameters``.

    Args:
        data: The array to filter (labeled or unlabeled).
        condition: Callable evaluated for every element.
        context: Optional extra data handed to ``condition`` inside
            ``named_parameters``; a falsy value becomes an empty dict.

    Returns:
        The kept elements. If ``data`` was labeled, a labeled array holding
        the kept elements together with their labels.

    Raises:
        Exception: If ``condition`` is not callable.
    """
    labels, data = get_labels(data)
    if not callable(condition):
        raise Exception("Array could not be filtered as condition is not callable. ")
    if not context:
        context = {}

    # Convert the result to a boolean mask once and use it for both the
    # values and the labels, so the two stay aligned even when `condition`
    # returns ints or floats instead of booleans.
    keep = np.vectorize(condition)(
        data,
        positional_parameters={"x": 0},
        named_parameters={"x": data, "context": context},
    ).astype(bool)

    filtered = data[keep]
    if len(labels) > 0:
        filtered = array_create_labeled(filtered, labels[keep])
    return filtered


def array_labels(data: ArrayLike, axis=None, dim_labels=None) -> ArrayLike:
    """Return the labels of ``data``, falling back to its integer indices.

    Args:
        data: The array whose labels are requested.
        axis: Axis to read the labels (or indices) from when ``data`` has
            more than one dimension.
        dim_labels: Optional explicit labels; when given and non-empty they
            are returned unchanged.

    Returns:
        ``dim_labels`` if supplied, otherwise the labels attached to
        ``data``, otherwise ``np.arange`` over the relevant axis length.

    Raises:
        TooManyDimensions: If ``data`` is unlabeled, has more than one
            dimension and no ``axis`` is given.
    """
    # Explicit checks instead of truthiness: `dim_labels` may be a NumPy
    # array (ambiguous truth value) and `axis=0` is a valid axis.
    if dim_labels is not None and len(dim_labels) > 0:
        return dim_labels

    if isinstance(data, xr.DataArray) and axis is not None:
        labels, data = get_labels(data, data.dims[axis])
    else:
        labels, data = get_labels(data)
    if len(labels) > 0:
        return labels

    if len(np.shape(data)) > 1:
        if axis is None:
            raise TooManyDimensions("array_labels is only implemented for 1D arrays.")
        return np.arange(data.shape[axis])
    return np.arange(len(data))


def array_apply(
data: ArrayLike, process: Callable, context: Optional[Any] = None
) -> ArrayLike:
labels, data = get_labels(data)
if not context:
context = {}
positional_parameters = {"x": 0}
Expand All @@ -233,12 +344,25 @@ def array_apply(
positional_parameters=positional_parameters,
named_parameters=named_parameters,
)


def array_interpolate_linear(data: ArrayLike):
if isinstance(data, list):
data = np.array(data)
x = np.arange(len(data))
raise Exception(f"Could not apply process as it is not callable. ")


def array_interpolate_linear(data: ArrayLike, dim_labels=None):
x, data = get_labels(data)
if len(x) > 0:
dim_labels = x
if dim_labels:
x = np.array(dim_labels)
if np.array(x).dtype.type is np.str_:
try:
x = np.array(x, dtype="datetime64").astype(float)
except Exception:
try:
x = np.array(x, dtype=float)
except Exception:
x = np.arange(len(data))
if len(x) == 0:
x = np.arange(len(data))
valid = np.isfinite(data)
if len(x[valid]) < 2:
return data
Expand Down Expand Up @@ -291,8 +415,7 @@ def order(
nodata: Optional[bool] = None,
axis: Optional[int] = None,
):
if isinstance(data, list):
data = np.asarray(data)
labels, data = get_labels(data)
if len(data) == 0:
return data

Expand Down Expand Up @@ -334,8 +457,7 @@ def rearrange(
):
if len(data) == 0:
return data
if isinstance(data, list):
data = np.asarray(data)
labels, data = get_labels(data)
if len(data.shape) == 1 and axis is None:
axis = 0
if isinstance(order, list):
Expand All @@ -353,8 +475,7 @@ def sort(
nodata: Optional[bool] = None,
axis: Optional[int] = None,
):
if isinstance(data, list):
data = np.asarray(data)
labels, data = get_labels(data)
if len(data) == 0:
return data
if asc:
Expand Down Expand Up @@ -384,6 +505,7 @@ def count(
axis=None,
keepdims=False,
):
labels, data = get_labels(data)
if condition is None:
valid = is_valid(data)
return np.nansum(valid, axis=axis, keepdims=keepdims)
Expand All @@ -395,3 +517,4 @@ def count(
context.pop("x", None)
count = condition(x=data, **context)
return np.nansum(count, axis=axis, keepdims=keepdims)
raise Exception(f"Could not count values as condition is not callable. ")
12 changes: 12 additions & 0 deletions openeo_processes_dask/process_implementations/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,18 @@ class ArrayElementNotAvailable(OpenEOException):
pass


class ArrayLabelConflict(OpenEOException):
    """Raised when two labeled arrays share at least one label."""


class ArrayLengthMismatch(OpenEOException):
    """Raised when the number of values and the number of labels differ."""


class LabelExists(OpenEOException):
    """Raised when an element is added under a label the array already has."""


class TooManyDimensions(OpenEOException):
    """Raised when a process limited to 1D arrays gets more dimensions."""

Expand Down
Loading

0 comments on commit c128316

Please sign in to comment.