Skip to content

Commit

Permalink
bugfix relating to RNNO and celltype=lstm
Browse files Browse the repository at this point in the history
  • Loading branch information
simon-bachhuber committed Sep 28, 2024
1 parent 360f584 commit cf9e390
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions src/ring/ml/rnno_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,11 @@ def forward_fn(X):
         assert X.shape[-2] == 1

         for i, n_units in enumerate(rnn_layers):
-            n_units = _factor * n_units
-            state = hk.get_state(f"rnn_{i}", shape=[1, n_units], init=jnp.zeros)
-            X, state = hk.dynamic_unroll(_cell(n_units), X, state)
-            hk.set_state(f"rnn_{i}", state)
+            state = hk.get_state(
+                f"rnn_{i}", shape=[1, n_units * _factor], init=jnp.zeros
+            )
+            X, state = hk.dynamic_unroll(_cell(n_units), X[..., 0, :], state[0])
+            hk.set_state(f"rnn_{i}", state[None])

             if layernorm:
                 X = hk.LayerNorm(axis=-1, create_scale=False, create_offset=False)(X)
Expand Down

0 comments on commit cf9e390

Please sign in to comment.