diff --git a/pyproject.toml b/pyproject.toml index aae0d75..211e086 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "imt-ring" -version = "1.6.9" +version = "1.6.10" authors = [ { name="Simon Bachhuber", email="simon.bachhuber@fau.de" }, ] diff --git a/src/ring/algorithms/generator/base.py b/src/ring/algorithms/generator/base.py index 65dc099..fa4b2e7 100644 --- a/src/ring/algorithms/generator/base.py +++ b/src/ring/algorithms/generator/base.py @@ -6,6 +6,7 @@ import jax.numpy as jnp import numpy as np import tree_utils +from tree_utils import PyTree from ring import base from ring import utils @@ -143,9 +144,7 @@ def _number_of_executions_required(size: int) -> int: return n_calls - def to_list( - self, sizes: int | list[int] = 1, seed: int = 1 - ) -> list[tree_utils.PyTree[np.ndarray]]: + def _generators_ncalls(self, sizes: int | list[int] = 1): "Returns list of unbatched sequences as numpy arrays." repeats = self._compute_repeats(sizes) sizes = list(jnp.array(repeats) * jnp.array(self._size_of_generators)) @@ -165,7 +164,45 @@ def to_list( batch.generators_lazy([self.gens[i]], [reduced_repeats[i]], jits[i]) ) - return batch.generators_eager_to_list(gens, n_calls, seed, self._disable_tqdm) + return gens, n_calls + + def to_list( + self, sizes: int | list[int] = 1, seed: int = 1 + ) -> list[tree_utils.PyTree[np.ndarray]]: + "Returns list of unbatched sequences as numpy arrays." + gens, n_calls = self._generators_ncalls(sizes) + + data = [] + batch.generators_eager( + gens, n_calls, lambda d: data.extend(d), seed, self._disable_tqdm + ) + return data + + def to_folder( + self, + path: str, + sizes: int | list[int] = 1, + seed: int = 1, + overwrite: bool = True, + file_prefix: str = "seq", + save_fn: Callable[[PyTree[np.ndarray], str], None] = utils.pickle_save, + verbose: bool = True, + ): + + i = 0 + + def callback(data: list[PyTree[np.ndarray]]) -> None: + nonlocal i + data = utils.replace_elements_w_nans(data, verbose=verbose) + for d in data: + file = utils.parse_path( + path, file_prefix + str(i), file_exists_ok=overwrite + ) + save_fn(d, file) + i += 1 + + gens, n_calls = self._generators_ncalls(sizes) + batch.generators_eager(gens, n_calls, callback, seed, self._disable_tqdm) def to_pickle( self, diff --git a/src/ring/algorithms/generator/batch.py b/src/ring/algorithms/generator/batch.py index a496110..2c0185a 100644 --- a/src/ring/algorithms/generator/batch.py +++ b/src/ring/algorithms/generator/batch.py @@ -1,8 +1,10 @@ +from typing import Callable + import jax import jax.numpy as jnp import numpy as np from tqdm import tqdm -import tree_utils +from tree_utils import PyTree from ring import utils from ring.algorithms.generator import types @@ -50,15 +52,15 @@ def generator(key): return generator -def generators_eager_to_list( +def generators_eager( generators: list[types.BatchedGenerator], n_calls: list[int], + callback: Callable[[list[PyTree[np.ndarray]]], None], seed: int = 1, disable_tqdm: bool = False, -) -> list[tree_utils.PyTree]: +) -> None: key = jax.random.PRNGKey(seed) - data = [] for gen, n_call in tqdm( zip(generators, n_calls), desc="executing generators", @@ -81,6 +83,4 @@ def generators_eager_to_list( sample_flat, _ = jax.tree_util.tree_flatten(sample) size = 1 if len(sample_flat) == 0 else sample_flat[0].shape[0] - data.extend([jax.tree_map(lambda a: a[i], sample) for i in range(size)]) - - return data + callback([jax.tree_map(lambda a: a[i], sample) for i in range(size)])