Make a better error message for abstract tracer errors #28

Open
dlwh opened this issue Aug 30, 2023 · 0 comments


@dlwh
Member

dlwh commented Aug 30, 2023

To make these errors less scary, we should add a guide on how to fix them and detect them automatically.
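One way automatic detection could work: when static data is collected during tracing, walk it and raise an error that names the offending path and suggests the two fixes mentioned in this issue. The sketch below is a minimal, hypothetical illustration in plain JAX, not Levanter's or Equinox's API. `find_tracers`, `check_static_values`, `StaticTracerLeakError`, `build`, and the `Labeled` wrapper are invented names, and it assumes static values are nested only in tuples, lists, dicts, or dataclasses.

```python
import dataclasses
import functools

import jax
import jax.numpy as jnp


class StaticTracerLeakError(ValueError):
    """Hypothetical error whose message doubles as a short fix-it guide."""


def find_tracers(obj, path="root"):
    """Yield (path, tracer) for each JAX tracer nested in tuples, lists,
    dicts, or dataclass fields. Other containers are not inspected, and
    the structure is assumed to be acyclic."""
    if isinstance(obj, jax.core.Tracer):
        yield path, obj
    elif isinstance(obj, dict):
        for key, value in obj.items():
            yield from find_tracers(value, f"{path}[{key!r}]")
    elif isinstance(obj, (list, tuple)):
        for i, value in enumerate(obj):
            yield from find_tracers(value, f"{path}[{i}]")
    elif dataclasses.is_dataclass(obj) and not isinstance(obj, type):
        for field in dataclasses.fields(obj):
            yield from find_tracers(getattr(obj, field.name), f"{path}.{field.name}")


def check_static_values(static_values, where):
    """Raise a descriptive error if any static value holds a tracer."""
    leaks = [path for path, _ in find_tracers(static_values)]
    if leaks:
        raise StaticTracerLeakError(
            f"{where}: static field values contain JAX tracers at:\n  "
            + "\n  ".join(leaks)
            + "\nStatic fields are not traced, so these tracers would leak. "
            "If the value depends only on concrete inputs, compute it inside "
            "`with jax.ensure_compile_time_eval():`; otherwise store it in a "
            "regular (non-static) field."
        )


@dataclasses.dataclass
class Labeled:
    """Hypothetical stand-in for an array wrapper such as NamedArray."""

    array: object


def build(x, *, eager):
    """Compute a small cache and store it the way a static field would."""
    if eager:
        with jax.ensure_compile_time_eval():
            cache = jnp.cos(jnp.arange(4.0))  # concrete array
    else:
        cache = jnp.cos(jnp.arange(4.0))  # tracer under eval_shape
    static_field_values = (None, None, Labeled(cache))
    check_static_values(static_field_values, "build")
    return x


x = jnp.zeros(3)
jax.eval_shape(functools.partial(build, eager=True), x)  # no error
try:
    jax.eval_shape(functools.partial(build, eager=False), x)
except StaticTracerLeakError as err:
    print(err)  # points at root[2].array and suggests both fixes
```

Checking at the point where static data is gathered lets the error name the exact field path, instead of surfacing later as a generic leaked-trace exception.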

The error below was fixed with `jax.ensure_compile_time_eval()`; in other cases, the fix is to avoid storing the value in a static field.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Users/ivan/dev/levanter/local_dev_roundtrip.py", line 145, in <module>
    main()
  File "/Users/ivan/dev/levanter/local_dev_roundtrip.py", line 41, in main
    model = converter.load_pretrained(LlamaLMHeadModel, RepoRef(model_id))
  File "/Users/ivan/dev/levanter/src/levanter/compat/hf_checkpoints.py", line 469, in load_pretrained
    lev_model = eqx.filter_eval_shape(lm_model_cls.init, Vocab, config, key=PRNGKey(0))
  File "/Users/ivan/dev/miniconda3/envs/levanter/lib/python3.10/site-packages/equinox/_eval_shape.py", line 36, in filter_eval_shape
    dynamic_out, static_out = jax.eval_shape(ft.partial(_fn, static), dynamic)
  File "/Users/ivan/dev/miniconda3/envs/levanter/lib/python3.10/contextlib.py", line 142, in __exit__
    next(self.gen)
Exception: Leaked trace MainTrace(1,DynamicJaxprTrace). Leaked tracer(s):

Traced<ShapedArray(float32[128,4])>with<DynamicJaxprTrace(level=1/0)>
This DynamicJaxprTracer was created on line /Users/ivan/dev/levanter/src/levanter/models/llama.py:220 (_get_cos_sin_cache)
<DynamicJaxprTracer 6163549296> is referred to by <NamedArray 6163407600>.array
<NamedArray 6163407600> is referred to by <tuple 6163920192>[2]
<tuple 6163920192> is referred to by <_FlattenedData 6163738976>.static_field_values

Traced<ShapedArray(float32[128,4])>with<DynamicJaxprTrace(level=1/0)>
This DynamicJaxprTracer was created on line /Users/ivan/dev/levanter/src/levanter/models/llama.py:219 (_get_cos_sin_cache)
<DynamicJaxprTracer 6163548816> is referred to by <NamedArray 6163404000>.array
<NamedArray 6163404000> is referred to by <tuple 6163920192>[1]
<tuple 6163920192> is referred to by <_FlattenedData 6163738976>.static_field_values
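Each leak report reads as a chain: the tracer was created in `_get_cos_sin_cache`, is held by a `NamedArray`'s `.array`, and that `NamedArray` sits in a tuple stored as `static_field_values`. In other words, a cos/sin cache computed during `eqx.filter_eval_shape` ended up in a static field, which is not traced, so the tracer escaped. The sketch below reproduces that pattern with plain JAX and shows both fixes. It rests on assumptions: `make_cos_sin_cache` is a hypothetical stand-in (its rotary-style formula is illustrative, not taken from Levanter), and a module-level dict plays the role of the static field.

```python
import jax
import jax.numpy as jnp
import numpy as np


def make_cos_sin_cache(seq_len: int, head_dim: int):
    # Hypothetical stand-in for a rotary cos/sin cache such as _get_cos_sin_cache.
    # All inputs are Python ints, so nothing here depends on traced values.
    inv_freq = 1.0 / (10000 ** (jnp.arange(0, head_dim, 2) / head_dim))
    positions = jnp.arange(seq_len)
    freqs = jnp.outer(positions, inv_freq)  # shape (seq_len, head_dim // 2)
    return jnp.cos(freqs), jnp.sin(freqs)


# Plays the role of a static field: values stored here bypass the traced
# outputs of eval_shape instead of being returned through them.
static_store = {}


def init_leaky(x):
    # Under eval_shape every jnp op is staged out, so these are tracers.
    static_store["leaky"] = make_cos_sin_cache(128, 8)
    return x


def init_compile_time(x):
    # Fix 1: evaluate eagerly; the results are concrete arrays.
    with jax.ensure_compile_time_eval():
        static_store["fixed"] = make_cos_sin_cache(128, 8)
    return x


def init_dynamic(x):
    # Fix 2: return the cache as an ordinary (non-static) output leaf.
    return x, make_cos_sin_cache(128, 8)


def is_concrete(value) -> bool:
    """True if `value` can be turned into a NumPy array (tracers cannot)."""
    try:
        np.asarray(value)
        return True
    except Exception:  # a leaked tracer raises TracerArrayConversionError
        return False


x = jnp.zeros(3)
jax.eval_shape(init_leaky, x)
jax.eval_shape(init_compile_time, x)
_, (cos_shape, sin_shape) = jax.eval_shape(init_dynamic, x)

print(is_concrete(static_store["leaky"][0]))  # False
print(is_concrete(static_store["fixed"][0]))  # True
print(cos_shape.shape, cos_shape.dtype)       # (128, 4) float32
```

Fix 1 keeps the cache static but makes it concrete, which only works because the cache depends solely on Python ints. Fix 2 lets `eval_shape` report just the shape and dtype of the cache, matching the `float32[128,4]` in the report.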