Support for partitioning/sharded data with Pallas kernels? #72
Comments
Hey, I'm not aware of anybody trying this yet. Do you have an example of it working in e.g. Flax? My guess is that we need to use shard_map as they say, but I don't have much experience with that yet either.
OK, the best example I see is from MaxText: https://github.com/google/maxtext/blob/10a7c473e9feb1107894e7588b283b1bcfcbd679/MaxText/layers/attentions.py#L213 I think the basic idea is to get a PSpec for each input array (and similarly for the expected output shape), then call shard_map(kernel), and then just double-check that there's no sharding of axes that the kernel assumes to be single-device? So, I think what you'll want to do is to call
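A minimal sketch of the shard_map pattern described above, not the code from this thread. It assumes a one-dimensional mesh with a hypothetical axis name "data", shards only the batch dimension, and uses a plain jax.numpy function (`local_attention`, a made-up stand-in) in place of the Pallas kernel:

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P

try:  # newer JAX releases export shard_map at the top level
    from jax import shard_map
except ImportError:  # older releases keep it under jax.experimental
    from jax.experimental.shard_map import shard_map


def local_attention(q, k, v):
    """Stand-in for the Pallas kernel: non-causal linear attention on one shard.

    Inputs have shape (local_batch, seq, dim); nothing here communicates
    across devices, so only the batch axis may be sharded.
    """
    kv = jnp.einsum("bsd,bse->bde", k, v)     # per-example sum over positions
    return jnp.einsum("bsd,bde->bse", q, kv)  # apply the summary to each query


# Assumption: all available devices form a single "data" axis.
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("data",))

# One PartitionSpec per input and one for the output: batch is split over
# "data"; the sequence and feature axes stay whole on each device.
spec = P("data", None, None)

sharded_attention = jax.jit(
    shard_map(local_attention, mesh=mesh, in_specs=(spec, spec, spec), out_specs=spec)
)

batch = 4 * devices.size  # the global batch must divide evenly across devices
key_q, key_k, key_v = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(key_q, (batch, 16, 8))
k = jax.random.normal(key_k, (batch, 16, 8))
v = jax.random.normal(key_v, (batch, 16, 8))

out = sharded_attention(q, k, v)       # global result, shape (batch, 16, 8)
reference = local_attention(q, k, v)   # same computation without sharding
```

A real Pallas call would take the place of `local_attention`; the specs should leave unsharded any axis the kernel assumes lives on a single device.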
Thanks, that helps a lot. I'm able to call the kernel without errors now. However, I'm still trying to figure out how to turn the kernel output (a plain JAX array) back into a Haliax named array with the correct sharding. Here's my current code:
When I try to use the output of this function in the model, it results in this error:
The issue is that the output array is the "local" array inside the shard_map, so Haliax infers that batch is 64. Outside the shard_map, the raw JAX array is the concatenated (global) array, but JAX doesn't know about Haliax's named arrays, so the recorded axis sizes don't get updated. (I should change the way Haliax works to make this easier...) The easiest thing to do is to return a plain JAX array from attn_sharded and then wrap the array before returning from linear attention.
(I'm glad this turned out to be relatively straightforward!)
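To illustrate the local-versus-global distinction, here is a small pure-JAX sketch. The mesh axis name "data", the function `kernel_stand_in`, and the shapes are all assumptions for illustration. It records the shape the function sees inside shard_map and compares it with the shape of the array returned outside:

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P

try:  # newer JAX releases export shard_map at the top level
    from jax import shard_map
except ImportError:  # older releases keep it under jax.experimental
    from jax.experimental.shard_map import shard_map

seen_local_shapes = []  # filled in while the function is being traced


def kernel_stand_in(x):
    # Inside shard_map, x is the per-device block, so its batch size is local.
    seen_local_shapes.append(x.shape)
    return x * 2.0


devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("data",))
spec = P("data", None)

local_batch = 8
global_batch = local_batch * devices.size
x = jnp.ones((global_batch, 4))

out = jax.jit(
    shard_map(kernel_stand_in, mesh=mesh, in_specs=(spec,), out_specs=spec)
)(x)

local_shape = seen_local_shapes[-1]  # (local_batch, 4)
global_shape = out.shape             # (global_batch, 4)

# Any named-axis wrapper built inside kernel_stand_in would record local_batch.
# Wrapping `out` here, outside shard_map, records global_batch instead, e.g.
# (assumed Haliax usage, not run here): hax.named(out, (Batch, Feature))
```

On a single device the two shapes coincide; with several devices the local batch is the global batch divided by the device count, which is why the named array should be constructed from the array returned by shard_map.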
Great, it's all working now! Here's the final code I ended up with.
Sweet! I'll leave this open just as a "make it easy for people to do this"/make a tutorial issue.
Also, could you let me know what kind of speedup you get? We can try to prioritize getting it into Levanter if it's nontrivial.
That's nice! I actually just meant Pallas flash attention vs pure JAX attention on TPU |
I'm trying to train a model with a custom linear attention kernel I wrote in Pallas, but the following issue is occurring (only happens when the input data is sharded across multiple TPU devices).
jax._src.source_info_util.JaxStackTraceBeforeTransformation: NotImplementedError: Mosaic kernels cannot be automatically partitioned. Please wrap the call in a shard_map or xmap.
Here's the code that I'm trying to run: https://github.com/G-Levine/levanter/blob/9e78ab17e416d5e471f27255d13888d5fb98e632/src/levanter/models/linear_attention.py
Is there a recommended way to achieve this with Haliax? I tried to find examples of people using the Pallas FlashAttention kernel with Haliax/Levanter, but it appears nobody has tried this yet. It seems like an important use case to support, for anyone who wants to efficiently train transformer models on multiple TPUs.