tokenization error when using msiglip #126

Open

simran-khanuja opened this issue Aug 12, 2024 · 1 comment

simran-khanuja commented Aug 12, 2024

Hi, I get the following error when preprocessing text with the mSigLIP model. Any idea what might be wrong? I didn't change anything in the demo Colab.

Traceback (most recent call last):
  File "/home/${USER}/babelnet/labels/msiglip.py", line 131, in <module>
    _, ztxt, out = model.apply({'params': params}, None, txts)
  File "/home/${USER}/babelnet/big_vision/big_vision/models/proj/image_text/two_towers.py", line 55, in __call__
    ztxt, out_txt = text_model(text, **kw)
  File "/home/${USER}/babelnet/big_vision/big_vision/models/proj/image_text/text_transformer.py", line 64, in __call__
    x = out["embedded"] = embedding(text)
  File "/home/${USER}/miniconda3/envs/msiglip/lib/python3.10/site-packages/flax/linen/linear.py", line 1106, in setup
    self.embedding = self.param(
flax.errors.ScopeParamShapeError: Initializer expected to generate shape (256000, 1152) but got shape (250000, 1152) instead for parameter "embedding" in "/txt/Embed_0". (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.ScopeParamShapeError)
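
For what it's worth, the error itself localizes the problem: the text tower's config and the loaded checkpoint disagree on the vocabulary size (256,000 vs. 250,000 rows in the "/txt/Embed_0" embedding table), which usually means the tokenizer/vocab_size selected in the colab doesn't match the mSigLIP checkpoint. A quick diagnostic sketch, assuming the params dict from the colab; the spiece filename below is a placeholder, while the param key path is taken from the traceback:

import sentencepiece as spm

# Rows of the embedding table = vocab size the checkpoint was trained with.
# Key path ('txt' -> 'Embed_0' -> 'embedding') comes from the traceback above.
print("checkpoint vocab:", params["txt"]["Embed_0"]["embedding"].shape[0])

# Vocab size of the tokenizer actually used for preprocessing; the filename
# is a placeholder for the mSigLIP SentencePiece model.
sp = spm.SentencePieceProcessor(model_file="msiglip.spiece.model")
print("tokenizer vocab:", sp.GetPieceSize())

These two numbers, plus the vocab_size in the model config, all have to agree.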

merveenoyan commented Aug 23, 2024

@simran-khanuja I'm not entirely sure, but I ran into a similar issue with the latest released SigLIP (so400m patch16): in my case the tokenizer was different and the vocab dimension should have been 256k (I fixed it during initialization).

Maybe try a different tokenizer (I confirmed with the Google folks that there seems to be a mistake in the config for my case; it might be different for you as well). The mSigLIP tokenizer spiece model exists here, so swapping the tokenizer should work, e.g. along the lines of the sketch below.
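
A rough sketch of the swap, calling the SentencePiece model directly instead of the colab's preprocessing pipeline; the filename, sequence length, and padding convention are assumptions rather than values taken from the colab:

import numpy as np
import sentencepiece as spm

sp = spm.SentencePieceProcessor(model_file="msiglip.spiece.model")  # placeholder path

def tokenize(texts, max_len=64, pad_id=0):
    # Encode each string to ids, truncate to max_len, and right-pad with pad_id.
    batch = []
    for t in texts:
        ids = sp.EncodeAsIds(t)[:max_len]
        batch.append(ids + [pad_id] * (max_len - len(ids)))
    return np.asarray(batch, dtype=np.int32)

txts = tokenize(["a photo of a cat", "eine Katze auf dem Sofa"])
_, ztxt, out = model.apply({"params": params}, None, txts)  # as in the traceback above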
Also, if you're okay with using PyTorch, mSigLIP is implemented in transformers; you can use that for the time being here.
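
A minimal PyTorch sketch via transformers; the checkpoint id below is my best guess for the multilingual SigLIP release on the Hub, so double-check it:

import torch
from PIL import Image
from transformers import AutoModel, AutoProcessor

ckpt = "google/siglip-base-patch16-256-multilingual"  # assumed checkpoint id
model = AutoModel.from_pretrained(ckpt)
processor = AutoProcessor.from_pretrained(ckpt)

image = Image.open("cat.png")  # placeholder image
texts = ["a photo of a cat", "a photo of a dog"]
# SigLIP was trained with fixed-length, padded text, hence padding="max_length".
inputs = processor(text=texts, images=image, padding="max_length", return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

# SigLIP scores pairs with a sigmoid over the logits rather than a softmax.
probs = torch.sigmoid(outputs.logits_per_image)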

edit: the new notebook seems to be fixed
