Probably once #468 is fully landed... This is maybe not the easiest "good first issue" but I think it can be done.
The kernel is here: https://github.com/google/jax/blob/main/jax/experimental/pallas/ops/tpu/flash_attention.py. I think wiring it up isn't too bad; stanford-crfm/haliax#72 shows the way for how to integrate it, though I'd like to make the interface a bit nicer eventually. A rough sketch of the wiring is below.
(I'd be happy to help spot a Stanford student some TPU access to do this...)
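
For anyone picking this up, here's roughly what the wiring might look like. This is a minimal sketch, not Levanter's actual interface: the `attention` wrapper is hypothetical, and the `flash_attention` signature (shapes `[batch, num_heads, seq_len, head_dim]`, `causal`/`sm_scale` kwargs) is an assumption based on the linked file and may differ across JAX versions.

```python
import jax
import jax.numpy as jnp
# Assumed import path per the linked file; it may move between JAX versions.
from jax.experimental.pallas.ops.tpu.flash_attention import flash_attention

def attention(q, k, v, *, causal: bool = True):
    # The Pallas kernel expects [batch, num_heads, seq_len, head_dim];
    # sm_scale carries the usual 1/sqrt(head_dim) softmax scaling.
    return flash_attention(q, k, v, causal=causal,
                           sm_scale=1.0 / q.shape[-1] ** 0.5)

# Tiny smoke test (needs a TPU; Pallas TPU kernels don't run on CPU).
batch, heads, seq_len, head_dim = 2, 8, 1024, 64
q, k, v = (jax.random.normal(key, (batch, heads, seq_len, head_dim),
                             dtype=jnp.bfloat16)
           for key in jax.random.split(jax.random.PRNGKey(0), 3))
out = jax.jit(attention)(q, k, v)
print(out.shape)  # (2, 8, 1024, 64)
```

The real integration would presumably live behind Haliax's named-axis attention API (per stanford-crfm/haliax#72), with this kernel selected only when running on TPU.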