To make leaked-tracer errors like the one below less scary, we should add a guide on how to fix them and detect them automatically.

This one was fixed with `jax.ensure_compile_time_eval()`, but sometimes the right answer is to not use a static field at all.
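A minimal sketch of both fixes in plain JAX (no equinox or levanter), under stated assumptions. The traceback shows the tracers ending up in `_FlattenedData.static_field_values`, i.e. outside the traced outputs, so a plain dict `STATIC` stands in for a static field here. `cos_sin_cache`, `init_leaky`, `init_fixed` and `RotaryCache` are hypothetical names, and the head dimension of 8 is chosen only so the tables have the `float32[128,4]` shape seen in the traceback.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


def cos_sin_cache(seq_len: int, head_dim: int):
    """Rotary-style cos/sin tables, each of shape (seq_len, head_dim // 2).

    Hypothetical stand-in for levanter's _get_cos_sin_cache.
    """
    inv_freq = 1.0 / (10000.0 ** (jnp.arange(0, head_dim, 2, dtype=jnp.float32) / head_dim))
    freqs = jnp.outer(jnp.arange(seq_len, dtype=jnp.float32), inv_freq)
    return jnp.cos(freqs), jnp.sin(freqs)


# Hypothetical stand-in for a static field: whatever is stored here escapes
# the trace untouched, like values kept in a pytree's static/aux data.
STATIC = {}


def init_leaky(x):
    # Under eval_shape every jnp op is staged out, so cos/sin are tracers
    # that outlive the trace once stashed in STATIC.
    STATIC["leaky"] = cos_sin_cache(128, 8)
    return x * 2


def init_fixed(x):
    # Fix 1: ensure_compile_time_eval evaluates ops on concrete inputs
    # eagerly, so STATIC receives real arrays instead of tracers.
    with jax.ensure_compile_time_eval():
        STATIC["fixed"] = cos_sin_cache(128, 8)
    return x * 2


@register_pytree_node_class
class RotaryCache:
    """Fix 2 (hypothetical module): keep the tables as ordinary pytree leaves."""

    def __init__(self, cos, sin, seq_len):
        self.cos, self.sin, self.seq_len = cos, sin, seq_len

    def tree_flatten(self):
        # cos/sin are children, so eval_shape/jit treat them like any other
        # array; only the hashable int goes into the aux (static) data.
        return (self.cos, self.sin), self.seq_len

    @classmethod
    def tree_unflatten(cls, seq_len, children):
        return cls(*children, seq_len)


jax.eval_shape(init_leaky, jnp.ones(3))
jax.eval_shape(init_fixed, jnp.ones(3))
shapes = jax.eval_shape(lambda: RotaryCache(*cos_sin_cache(128, 8), 128))

print(isinstance(STATIC["leaky"][0], jax.core.Tracer))  # True: leaked tracer
print(isinstance(STATIC["fixed"][0], jax.core.Tracer))  # False: concrete array
print(shapes.cos.shape)  # (128, 4)
```

Which fix fits depends on the inputs: `ensure_compile_time_eval` only yields concrete arrays when everything the computation depends on is concrete (here, two Python ints), whereas a regular leaf also works when the values depend on traced inputs, at the cost of the tables being carried around as ordinary array data.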
```
The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Users/ivan/dev/levanter/local_dev_roundtrip.py", line 145, in <module>
    main()
  File "/Users/ivan/dev/levanter/local_dev_roundtrip.py", line 41, in main
    model = converter.load_pretrained(LlamaLMHeadModel, RepoRef(model_id))
  File "/Users/ivan/dev/levanter/src/levanter/compat/hf_checkpoints.py", line 469, in load_pretrained
    lev_model = eqx.filter_eval_shape(lm_model_cls.init, Vocab, config, key=PRNGKey(0))
  File "/Users/ivan/dev/miniconda3/envs/levanter/lib/python3.10/site-packages/equinox/_eval_shape.py", line 36, in filter_eval_shape
    dynamic_out, static_out = jax.eval_shape(ft.partial(_fn, static), dynamic)
  File "/Users/ivan/dev/miniconda3/envs/levanter/lib/python3.10/contextlib.py", line 142, in __exit__
    next(self.gen)
Exception: Leaked trace MainTrace(1,DynamicJaxprTrace). Leaked tracer(s):

Traced<ShapedArray(float32[128,4])>with<DynamicJaxprTrace(level=1/0)>
This DynamicJaxprTracer was created on line /Users/ivan/dev/levanter/src/levanter/models/llama.py:220 (_get_cos_sin_cache)
<DynamicJaxprTracer 6163549296> is referred to by <NamedArray 6163407600>.array
<NamedArray 6163407600> is referred to by <tuple 6163920192>[2]
<tuple 6163920192> is referred to by <_FlattenedData 6163738976>.static_field_values

Traced<ShapedArray(float32[128,4])>with<DynamicJaxprTrace(level=1/0)>
This DynamicJaxprTracer was created on line /Users/ivan/dev/levanter/src/levanter/models/llama.py:219 (_get_cos_sin_cache)
<DynamicJaxprTracer 6163548816> is referred to by <NamedArray 6163404000>.array
<NamedArray 6163404000> is referred to by <tuple 6163920192>[1]
<tuple 6163920192> is referred to by <_FlattenedData 6163738976>.static_field_values
```
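The "Leaked trace ... Leaked tracer(s):" report, with its chain of "is referred to by" lines, appears to come from JAX's tracer-leak checker (the `jax_check_tracer_leaks` option). A minimal sketch of automatic detection, assuming we run a model's init under `jax.checking_leaks()` and treat the resulting exception as a failure, e.g. in a unit test. `STASH`, `init_leaky`, `init_ok` and `leaks_tracer` are hypothetical names, and matching on the word "Leaked" is an assumption: JAX raises a plain `Exception` whose exact wording has varied across versions.

```python
import jax
import jax.numpy as jnp

# Hypothetical container standing in for a static field.
STASH = []


def init_leaky(x):
    # jnp.cos is staged out under eval_shape, so a tracer escapes into STASH.
    STASH.append(jnp.cos(jnp.arange(4.0)))
    return x * 2


def init_ok(x):
    # Same computation evaluated eagerly: STASH receives a concrete array.
    with jax.ensure_compile_time_eval():
        STASH.append(jnp.cos(jnp.arange(4.0)))
    return x * 2


def leaks_tracer(fn, *args) -> bool:
    """Trace fn with jax.eval_shape under the leak checker; True if it reports a leak."""
    try:
        with jax.checking_leaks():
            jax.eval_shape(fn, *args)
    except Exception as err:  # the checker raises a plain Exception
        if "Leaked" in str(err):
            return True
        raise
    return False


# Each function is checked once: a second eval_shape call could hit JAX's
# trace cache and skip re-tracing.
LEAKY_DETECTED = leaks_tracer(init_leaky, jnp.ones(3))
OK_DETECTED = leaks_tracer(init_ok, jnp.ones(3))
print(LEAKY_DETECTED, OK_DETECTED)
```

Leak checking can slow tracing down (some caching is disabled while it is on), so it probably fits a test suite that builds each model once better than every training run.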