Skip to content

Commit

Permalink
No public description
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 575252495
  • Loading branch information
Jake VanderPlas authored and learned_optimization authors committed Oct 20, 2023
1 parent 1f9823d commit d13285c
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions learned_optimization/optimizers/optax_opts.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,11 +384,13 @@ def __init__(self,

# SM3 doesn't support scalars, so we have to reshape the params and grads.

def init(self,
params: Any,
model_state: Optional[Any] = None,
num_steps: Optional[int] = None,
key: chex.PRNGKey = None) -> SM3OptState:
def init(
self,
params: Any,
model_state: Optional[Any] = None,
num_steps: Optional[int] = None,
key: Optional[chex.PRNGKey] = None,
) -> SM3OptState:
should_reshape = jax.tree_util.tree_map(lambda x: len(x.shape) == 0, params) # pylint: disable=g-explicit-length-test
params = jax.tree_util.tree_map(_expand_scalar, params, should_reshape)
out = super().init(params, model_state, num_steps, key)
Expand Down

0 comments on commit d13285c

Please sign in to comment.