
Attention without for loops #10

Open
wants to merge 1 commit into main

Conversation

AliYoussef97

Reimplemented compute_middle_attn_vector without for loops. The function computes the attention almost instantly, even for very long sequences. Additionally, added the ability to compute the entire attention vector or only up to a preferred cls position. A unit test script is provided as well.
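
A minimal sketch of the idea behind the change, not the PR's actual code: for a fixed target position t, the products of the decays between each source position s and t can be obtained from a single reversed torch.cumprod, replacing the nested Python loops. The shapes and names below (A, B, C as per-step decay, input and output projections of shape (L, D)) are illustrative assumptions.

```python
import torch

def attn_vector_loops(A, B, C, t):
    # Reference version with explicit Python loops: O(L^2) work for one vector.
    L, D = A.shape
    out = torch.zeros(L)
    for s in range(t + 1):
        prod = torch.ones(D)
        for k in range(s + 1, t + 1):
            prod = prod * A[k]          # prod_{k=s+1}^{t} A_k, elementwise
        out[s] = (C[t] * prod * B[s]).sum()
    return out

def attn_vector_cumprod(A, B, C, t):
    # Reversed cumulative product gives prod_{k=s+1}^{t} A_k for every s at
    # once, with no division involved.
    L, D = A.shape
    rev = torch.flip(A[1 : t + 1], dims=[0])                   # [A_t, ..., A_1]
    prods = torch.flip(torch.cumprod(rev, dim=0), dims=[0])    # s = 0 .. t-1
    prods = torch.cat([prods, torch.ones(1, D)], dim=0)        # s = t: empty product
    out = torch.zeros(L)
    out[: t + 1] = (C[t] * prods * B[: t + 1]).sum(dim=-1)
    return out

# Quick sanity check on random inputs:
L, D = 64, 8
A = torch.rand(L, D) * 0.5 + 0.4
B, C = torch.randn(L, D), torch.randn(L, D)
assert torch.allclose(attn_vector_loops(A, B, C, L - 1),
                      attn_vector_cumprod(A, B, C, L - 1), atol=1e-5)
```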

@AliYoussef97 (Author) commented May 5, 2024

I apologise, I forgot to include timing in the unit test script.

However, for a sequence length of 1000, the time taken to compute the full attention vectors is:

Time taken for full vector with loops for sequence length 1000:  13.185152053833008
Time taken for full vector without loops for sequence length 1000:  0.009944677352905273
..
----------------------------------------------------------------------
Ran 2 tests in 13.430s

OK

and for a sequence length of 5000, the full attention vector time is:

Time taken for full vector with loops for sequence length 5000:  255.85072016716003
Time taken for full vector without loops for sequence length 5000:  0.00858449935913086
..
----------------------------------------------------------------------
Ran 2 tests in 256.023s

OK

Tests executed on CPU
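
For reference, a hypothetical sketch of how such a timed comparison could look inside a unittest (this is not the PR's test script; it reuses the illustrative attn_vector_loops / attn_vector_cumprod functions from the sketch above):

```python
import time
import unittest
import torch

class TestAttnVector(unittest.TestCase):
    def test_loop_vs_cumprod(self):
        L, D = 1000, 16                          # illustrative sizes
        A = torch.rand(L, D) * 0.5 + 0.4
        B, C = torch.randn(L, D), torch.randn(L, D)
        t = L - 1                                # full attention vector

        start = time.time()
        ref = attn_vector_loops(A, B, C, t)      # from the sketch above
        print(f"with loops, L={L}:", time.time() - start)

        start = time.time()
        fast = attn_vector_cumprod(A, B, C, t)
        print(f"without loops, L={L}:", time.time() - start)

        self.assertTrue(torch.allclose(ref, fast, atol=1e-4))

if __name__ == "__main__":
    unittest.main()
```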

@Itamarzimm (Collaborator)

Hi @AliYoussef97 , thanks for this valuable contribution, and I apologize for the delay in response!

The idea of computing hidden attention without loops, via torch.cumprod, is great and we have investigated it for quite some time (in fact, we have been using similar variants internally since the beginning of this project). While your code is completely accurate and stable, we observe that there are numerical issues when using torch.cumprod to compute the entire hidden attention matrix, since it requires division by relatively small numbers.
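
To make the concern concrete, here is an illustrative snippet (my own sketch, not the repository's code) showing how forming the full matrix of pairwise decay products as a ratio of cumulative products breaks down for long sequences, and how the division could in principle be avoided by working in log space:

```python
import torch

torch.manual_seed(0)
L = 2000
a = torch.rand(L) * 0.5 + 0.45               # hypothetical per-step decays in (0.45, 0.95)
cp = torch.cumprod(a, dim=0)                 # underflows to exactly 0.0 long before index L
ratio = cp.unsqueeze(1) / cp.unsqueeze(0)    # ratio[t, s] = cp[t] / cp[s]
lower = torch.tril(torch.ones(L, L, dtype=torch.bool))
print(cp[-1])                                # tensor(0.)
print(torch.isfinite(ratio[lower]).all())    # tensor(False): 0/0 = nan after the underflow

# One common workaround (an assumption on my part, not necessarily the
# maintainers' planned fix): take logs, so the ratio becomes a difference of
# cumulative sums; only the needed (t >= s) entries are exponentiated at the end.
log_cs = torch.cumsum(torch.log(a), dim=0)
log_ratio = log_cs.unsqueeze(1) - log_cs.unsqueeze(0)    # log(cp[t] / cp[s])
stable = torch.where(lower, torch.exp(log_ratio), torch.zeros(L, L))
print(torch.isfinite(stable).all())          # tensor(True)
```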

There are some similarities between our code and that of Tiny Mamba. I have some ideas regarding fast and stable computation, and I'll share them here once I finish. Feel free to share your thoughts on this topic!

@AliYoussef97 (Author)

@Itamarzimm Hello, that is completely fine!

I am not entirely sure where the division term comes from; however, computing the entire attention matrix without loops proved to be rather difficult compared to the attention vector. I am, however, still working on a solution to compute the matrix in one go. One idea I am working on right now is to expand the dimensions of all parameters and compute the cumulative product over the expanded dimension first.
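
One way such a dimension expansion could look, purely as an illustrative sketch and not code from this PR: broadcast the per-step decays onto an (L, L) grid, set every entry at or below the source position to 1, and let a single torch.cumprod along the time dimension produce all segment products with no division.

```python
import torch

def pairwise_products_no_division(a):
    # a: (L,) per-step decay (per channel this would be (L, D); kept 1-D here
    # for clarity). Returns P with P[t, s] = prod_{k=s+1}^{t} a_k for t >= s
    # and 1.0 elsewhere (the empty product).
    L = a.shape[0]
    k = torch.arange(L).unsqueeze(1)      # time index of each row
    s = torch.arange(L).unsqueeze(0)      # source position of each column
    # Row k, column s holds a[k] only when k > s; otherwise 1, so it does not
    # contribute to the running product for that column.
    expanded = torch.where(k > s, a.unsqueeze(1).expand(L, L), torch.ones(L, L))
    return torch.cumprod(expanded, dim=0)

# Quick check against the direct product on a small example:
a = torch.rand(6) * 0.5 + 0.4
P = pairwise_products_no_division(a)
assert torch.allclose(P[5, 2], a[3:6].prod())
```

The trade-off is the (L, L) intermediate tensor, so memory rather than numerical stability becomes the limiting factor for very long sequences.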

Nonetheless, please do share if you reach a solution for computing the attention matrix without loops. If I manage to find a swift and stable method, I will update this pull request or create a new one to avoid confusion.

Thank you 😆!
