Skip to content

Commit

Permalink
v1.6.10; adds RCMG.to_folder
Browse files Browse the repository at this point in the history
  • Loading branch information
simon-bachhuber committed Aug 26, 2024
1 parent 7b41019 commit 8ca470e
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 12 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "imt-ring"
version = "1.6.9"
version = "1.6.10"
authors = [
{ name="Simon Bachhuber", email="[email protected]" },
]
Expand Down
45 changes: 41 additions & 4 deletions src/ring/algorithms/generator/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import jax.numpy as jnp
import numpy as np
import tree_utils
from tree_utils import PyTree

from ring import base
from ring import utils
Expand Down Expand Up @@ -143,9 +144,7 @@ def _number_of_executions_required(size: int) -> int:

return n_calls

def to_list(
self, sizes: int | list[int] = 1, seed: int = 1
) -> list[tree_utils.PyTree[np.ndarray]]:
def _generators_ncalls(self, sizes: int | list[int] = 1):
"Returns list of unbatched sequences as numpy arrays."
repeats = self._compute_repeats(sizes)
sizes = list(jnp.array(repeats) * jnp.array(self._size_of_generators))
Expand All @@ -165,7 +164,45 @@ def to_list(
batch.generators_lazy([self.gens[i]], [reduced_repeats[i]], jits[i])
)

return batch.generators_eager_to_list(gens, n_calls, seed, self._disable_tqdm)
return gens, n_calls

def to_list(
self, sizes: int | list[int] = 1, seed: int = 1
) -> list[tree_utils.PyTree[np.ndarray]]:
"Returns list of unbatched sequences as numpy arrays."
gens, n_calls = self._generators_ncalls(sizes)

data = []
batch.generators_eager(
gens, n_calls, lambda d: data.extend(d), seed, self._disable_tqdm
)
return data

def to_folder(
self,
path: str,
sizes: int | list[int] = 1,
seed: int = 1,
overwrite: bool = True,
file_prefix: str = "seq",
save_fn: Callable[[PyTree[np.ndarray], str], None] = utils.pickle_save,
verbose: bool = True,
):

i = 0

def callback(data: list[PyTree[np.ndarray]]) -> None:
nonlocal i
data = utils.replace_elements_w_nans(data, verbose=verbose)
for d in data:
file = utils.parse_path(
path, file_prefix + str(i), file_exists_ok=overwrite
)
save_fn(d, file)
i += 1

gens, n_calls = self._generators_ncalls(sizes)
batch.generators_eager(gens, n_calls, callback, seed, self._disable_tqdm)

def to_pickle(
self,
Expand Down
14 changes: 7 additions & 7 deletions src/ring/algorithms/generator/batch.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from typing import Callable

import jax
import jax.numpy as jnp
import numpy as np
from tqdm import tqdm
import tree_utils
from tree_utils import PyTree

from ring import utils
from ring.algorithms.generator import types
Expand Down Expand Up @@ -50,15 +52,15 @@ def generator(key):
return generator


def generators_eager_to_list(
def generators_eager(
generators: list[types.BatchedGenerator],
n_calls: list[int],
callback: Callable[[list[PyTree[np.ndarray]]], None],
seed: int = 1,
disable_tqdm: bool = False,
) -> list[tree_utils.PyTree]:
) -> None:

key = jax.random.PRNGKey(seed)
data = []
for gen, n_call in tqdm(
zip(generators, n_calls),
desc="executing generators",
Expand All @@ -81,6 +83,4 @@ def generators_eager_to_list(

sample_flat, _ = jax.tree_util.tree_flatten(sample)
size = 1 if len(sample_flat) == 0 else sample_flat[0].shape[0]
data.extend([jax.tree_map(lambda a: a[i], sample) for i in range(size)])

return data
callback([jax.tree_map(lambda a: a[i], sample) for i in range(size)])

0 comments on commit 8ca470e

Please sign in to comment.