
Integrate Pallas TPU Flash Attention #514

Closed
dlwh opened this issue Mar 7, 2024 · 0 comments
Labels
good first issue · help wanted · p2 · perf · tpu

Comments

dlwh (Member) commented Mar 7, 2024

Probably once #468 is fully landed... This is maybe not the easiest "good first issue" but I think it can be done.

The kernel is here: https://github.com/google/jax/blob/main/jax/experimental/pallas/ops/tpu/flash_attention.py. I think wiring it up isn't too bad. stanford-crfm/haliax#72 shows the way for how to integrate it, though I'd like to maybe make the interface a bit nicer eventually.

(I'd be happy to help a Stanford student get TPU access to do this...)

dlwh added the good first issue, help wanted, p2, perf, and tpu labels on Mar 7, 2024
dlwh closed this as completed on Nov 30, 2024