Improve block weighting with uniform and hat functions #147

Open: lsorber wants to merge 1 commit into base: main


@lsorber commented Jan 2, 2025

This PR makes the current uniform weighting scheme explicit, and adds an improved hat weighting scheme.

The rationale behind hat weighting is that predictions for tokens near the beginning or end of the block will be less accurate than predictions for tokens near the middle of the block, where the model has maximal context.

For instance, let's say we use stride=128 and block_size=256 and compare the predictions for the token with index 128:

  1. With uniform weighting, its prediction will be 0.5 * first_block[128] + 0.5 * second_block[0].
  2. With hat weighting, its prediction will be approximately 1 * first_block[128] + 1/256 * second_block[0].

In this example, hat weighting is preferable because the first token of the second block is likely to be much less accurate than the middle token of the first block.
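The comparison above can be reproduced with a minimal sketch. This is not the code from this PR: `block_weights` and `blend_blocks` are hypothetical helpers, and the particular hat formula (`1 - |2x - 1|` evaluated at position midpoints) is an assumption chosen because it gives weights close to the 1 and 1/256 quoted in the example.

```python
import numpy as np


def block_weights(block_size: int, scheme: str = "uniform") -> np.ndarray:
    """Per-position weights for one block (hypothetical helper, not the PR's API)."""
    if scheme == "uniform":
        return np.ones(block_size)
    if scheme == "hat":
        # Assumed hat shape: 1 - |2x - 1| with x at the midpoint of each position,
        # so weights peak near the middle of the block and stay nonzero at the edges.
        x = (np.arange(block_size) + 0.5) / block_size
        return 1.0 - np.abs(2.0 * x - 1.0)
    raise ValueError(f"unknown scheme: {scheme}")


def blend_blocks(block_preds: np.ndarray, num_tokens: int, stride: int, scheme: str) -> np.ndarray:
    """Weighted average of overlapping block predictions, one value per token.

    block_preds has shape (num_blocks, block_size); block b starts at token b * stride.
    Assumes every token is covered by at least one block.
    """
    block_size = block_preds.shape[1]
    weights = block_weights(block_size, scheme)
    total = np.zeros(num_tokens)
    norm = np.zeros(num_tokens)
    for b, preds in enumerate(block_preds):
        start = b * stride
        end = min(start + block_size, num_tokens)
        n = end - start
        total[start:end] += weights[:n] * preds[:n]
        norm[start:end] += weights[:n]
    # Normalize so the weights contributing to each token sum to 1.
    return total / norm


# Toy data: the first block predicts 1.0 everywhere, the second predicts 0.0,
# so the blended value at a token equals the share given to the first block.
preds = np.stack([np.ones(256), np.zeros(256)])
uniform = blend_blocks(preds, num_tokens=384, stride=128, scheme="uniform")
hat = blend_blocks(preds, num_tokens=384, stride=128, scheme="hat")
print(uniform[128])  # 0.5: both blocks count equally
print(hat[128])      # 0.99609375 (255/256): the first block dominates
```

With this assumed formula the two weights at token 128 are 255/256 and 1/256, which already sum to 1. Other hat definitions would give slightly different numbers, which is why the figures in the example are approximate.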

Anecdotally, I've also observed that hat weighting improves output quality on test data.

@markus583
Collaborator

Hi! Thanks a lot for implementing this. Interesting idea, cool stuff! It intuitively makes sense, but I'm unsure whether it makes a practical difference. It would be interesting to test it on some benchmarks. For the time being, I'd be happy to add it as a feature and leave the default as uniform. Would you agree @bminixhofer?
