Quantization appears to be broken, at least for AWQ and BnB #722
Comments
Here is another related case, #611, referencing container ghcr.io/predibase/lorax:07addea, which I could get to work with https://huggingface.co/hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4, but that image is very old.
The problem appears to be here: `if quantize in ["gptq", "awq"]:` (lorax/server/lorax_server/utils/weights.py, line 120 in 69bb989), and again at lorax/server/lorax_server/utils/weights.py, line 141 in 69bb989. The weight can have a different number of parameters, which is not accounted for here.
Hey @codybum! Thanks for the investigation. Would you be willing to contribute to LoRAX by pushing up a fix with your suggested changes?
I see where it is breaking, though not necessarily how to fix it, but I will look into this further and see if I can get something working.
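For readers following the thread, here is a minimal, self-contained sketch of the failure mode being described. The helper name and return values below are illustrative only, not the actual LoRAX code: the point is that a quantized loading path can return a tuple with a different number of elements than the non-quantized path, so the unconditional three-way unpack in `_load_gqa` raises the `ValueError` seen in the logs.

```python
# Illustration of the unpack-arity mismatch (hypothetical helper, not LoRAX code).
# A quantized weight loader may return packed tensors plus metadata, so the
# number of returned items depends on the quantization scheme.

def get_multi_weights_col(quantize: str):
    if quantize in ["gptq", "awq"]:
        # Packed quantized weights typically carry extra metadata,
        # so the tuple has more than three elements.
        return ("qweight", "qzeros", "scales", "extra_metadata")
    # Non-quantized / fp8-style path: exactly three elements.
    return ("weight", "input_scale", "weight_scale")

weight = get_multi_weights_col("awq")
try:
    # Mirrors the unconditional unpack in _load_gqa (flash_llama_modeling.py).
    w, input_scale, weight_scale = weight
except ValueError as err:
    print(err)  # too many values to unpack (expected 3)
```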
System Info
I have tried the following LoRAX versions:
(official version) ghcr.io/predibase/lorax:0.12
(locally compiled) lorax:69bb989
CUDA:
12.4
12.6
Information
Tasks
Reproduction
Expected behavior
I would expect LoRAX to start normally, as it does with a non-quantized model. However, I have tried several different AWQ and BnB 4-bit models, different versions of LoRAX, and different CUDA versions/drivers, and all of them fail.
Example errors (also seen in other issues):
AWQ: "ValueError: too many values to unpack (expected 3)"
BnB: "AssertionError: [12582912, 1] != [6144, 4096]"
**Full AWQ Error:**
2024-12-21T13:26:16.679407Z INFO shard-manager: lorax_launcher: Starting shard rank=0
2024-12-21T13:26:25.250809Z ERROR lorax_launcher: server.py:317 Error when initializing model
Traceback (most recent call last):
File "/opt/conda/bin/lorax-server", line 8, in
sys.exit(app())
File "/opt/conda/lib/python3.10/site-packages/typer/main.py", line 311, in call
return get_command(self)(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 1157, in call
return self.main(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/typer/core.py", line 778, in main
return _main(
File "/opt/conda/lib/python3.10/site-packages/typer/core.py", line 216, in _main
rv = self.invoke(ctx)
File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 1688, in invoke
return _process_result(sub_ctx.command.invoke(sub_ctx))
File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 1434, in invoke
return ctx.invoke(self.callback, **ctx.params)
File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 783, in invoke
return __callback(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/typer/main.py", line 683, in wrapper
return callback(**use_params) # type: ignore
File "/opt/conda/lib/python3.10/site-packages/lorax_server/cli.py", line 92, in serve
server.serve(
File "/opt/conda/lib/python3.10/site-packages/lorax_server/server.py", line 449, in serve
asyncio.run(
File "/opt/conda/lib/python3.10/asyncio/runners.py", line 44, in run
return loop.run_until_complete(main)
File "/opt/conda/lib/python3.10/asyncio/base_events.py", line 636, in run_until_complete
self.run_forever()
File "/opt/conda/lib/python3.10/asyncio/base_events.py", line 603, in run_forever
self._run_once()
File "/opt/conda/lib/python3.10/asyncio/base_events.py", line 1909, in _run_once
handle._run()
File "/opt/conda/lib/python3.10/asyncio/events.py", line 80, in _run
self._context.run(self._callback, *self._args)
2024-12-21T13:26:26.400592Z ERROR shard-manager: lorax_launcher: Shard complete standard error output:
2024-12-21 13:26:20.746 | INFO | lorax_server.utils.state:<module>:22 - Backend = fa2
2024-12-21 13:26:20.746 | INFO | lorax_server.utils.state:<module>:24 - Prefix caching = False
2024-12-21 13:26:20.746 | INFO | lorax_server.utils.state:<module>:25 - Chunked prefill = False
/opt/conda/lib/python3.10/site-packages/torch/distributed/c10d_logger.py:79: FutureWarning: You are using a Backend <class 'lorax_server.utils.dist.FakeGroup'> as a ProcessGroup. This usage is deprecated since PyTorch 2.0. Please use a public API of PyTorch Distributed instead.
return func(*args, **kwargs)
Traceback (most recent call last):
File "/opt/conda/bin/lorax-server", line 8, in
sys.exit(app())
File "/opt/conda/lib/python3.10/site-packages/lorax_server/cli.py", line 92, in serve
server.serve(
File "/opt/conda/lib/python3.10/site-packages/lorax_server/server.py", line 449, in serve
asyncio.run(
File "/opt/conda/lib/python3.10/asyncio/runners.py", line 44, in run
return loop.run_until_complete(main)
File "/opt/conda/lib/python3.10/asyncio/base_events.py", line 649, in run_until_complete
return future.result()
File "/opt/conda/lib/python3.10/site-packages/lorax_server/server.py", line 302, in serve_inner
model = get_model(
File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/init.py", line 186, in get_model
return FlashLlama(
File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/flash_llama.py", line 40, in init
super().init(
File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/flash_causal_lm.py", line 1119, in init
model = model_cls(prefix, config, weights)
File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 570, in init
self.model = FlashLlamaModel(prefix, config, weights, create_layer_fn)
File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 499, in init
[
File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 500, in
create_layer_fn(
File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 431, in init
self.self_attn = FlashLlamaAttention(
File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 265, in init
self.query_key_value = load_attention(config, prefix, weights, layer_id)
File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 162, in load_attention
base_layer = load_attention_multi(config, prefix, weights)
File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 179, in load_attention_multi
return _load_gqa(config, prefix, weights)
File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 202, in _load_gqa
weight, input_scale, weight_scale = weight
ValueError: too many values to unpack (expected 3)
rank=0
2024-12-21T13:26:26.485424Z ERROR lorax_launcher: Shard 0 failed to start
2024-12-21T13:26:26.485464Z INFO lorax_launcher: Shutting down shards
Error: ShardCannotStart
**Full BnB Error:**
2024-12-21T13:33:35.146822Z INFO shard-manager: lorax_launcher: Starting shard rank=0
2024-12-21T13:33:43.387125Z ERROR lorax_launcher: server.py:317 Error when initializing model
Traceback (most recent call last):
File "/opt/conda/bin/lorax-server", line 8, in
sys.exit(app())
File "/opt/conda/lib/python3.10/site-packages/typer/main.py", line 311, in call
return get_command(self)(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 1157, in call
return self.main(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/typer/core.py", line 778, in main
return _main(
File "/opt/conda/lib/python3.10/site-packages/typer/core.py", line 216, in _main
rv = self.invoke(ctx)
File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 1688, in invoke
return _process_result(sub_ctx.command.invoke(sub_ctx))
File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 1434, in invoke
return ctx.invoke(self.callback, **ctx.params)
File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 783, in invoke
return __callback(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/typer/main.py", line 683, in wrapper
return callback(**use_params) # type: ignore
File "/opt/conda/lib/python3.10/site-packages/lorax_server/cli.py", line 92, in serve
server.serve(
File "/opt/conda/lib/python3.10/site-packages/lorax_server/server.py", line 449, in serve
asyncio.run(
File "/opt/conda/lib/python3.10/asyncio/runners.py", line 44, in run
return loop.run_until_complete(main)
File "/opt/conda/lib/python3.10/asyncio/base_events.py", line 636, in run_until_complete
self.run_forever()
File "/opt/conda/lib/python3.10/asyncio/base_events.py", line 603, in run_forever
self._run_once()
File "/opt/conda/lib/python3.10/asyncio/base_events.py", line 1909, in _run_once
handle._run()
File "/opt/conda/lib/python3.10/asyncio/events.py", line 80, in _run
self._context.run(self._callback, *self._args)
2024-12-21T13:33:44.667486Z ERROR shard-manager: lorax_launcher: Shard complete standard error output:
2024-12-21 13:33:39.687 | INFO | lorax_server.utils.state:<module>:22 - Backend = fa2
2024-12-21 13:33:39.687 | INFO | lorax_server.utils.state:<module>:24 - Prefix caching = False
2024-12-21 13:33:39.687 | INFO | lorax_server.utils.state:<module>:25 - Chunked prefill = False
/opt/conda/lib/python3.10/site-packages/torch/distributed/c10d_logger.py:79: FutureWarning: You are using a Backend <class 'lorax_server.utils.dist.FakeGroup'> as a ProcessGroup. This usage is deprecated since PyTorch 2.0. Please use a public API of PyTorch Distributed instead.
return func(*args, **kwargs)
Traceback (most recent call last):
File "/opt/conda/bin/lorax-server", line 8, in
sys.exit(app())
File "/opt/conda/lib/python3.10/site-packages/lorax_server/cli.py", line 92, in serve
server.serve(
File "/opt/conda/lib/python3.10/site-packages/lorax_server/server.py", line 449, in serve
asyncio.run(
File "/opt/conda/lib/python3.10/asyncio/runners.py", line 44, in run
return loop.run_until_complete(main)
File "/opt/conda/lib/python3.10/asyncio/base_events.py", line 649, in run_until_complete
return future.result()
File "/opt/conda/lib/python3.10/site-packages/lorax_server/server.py", line 302, in serve_inner
model = get_model(
File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/init.py", line 186, in get_model
return FlashLlama(
File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/flash_llama.py", line 40, in init
super().init(
File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/flash_causal_lm.py", line 1119, in init
model = model_cls(prefix, config, weights)
File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 570, in init
self.model = FlashLlamaModel(prefix, config, weights, create_layer_fn)
File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 499, in init
[
File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 500, in
create_layer_fn(
File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 431, in init
self.self_attn = FlashLlamaAttention(
File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 265, in init
self.query_key_value = load_attention(config, prefix, weights, layer_id)
File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 162, in load_attention
base_layer = load_attention_multi(config, prefix, weights)
File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 179, in load_attention_multi
return _load_gqa(config, prefix, weights)
File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 210, in _load_gqa
assert list(weight.shape) == [
AssertionError: [12582912, 1] != [6144, 4096]
rank=0
2024-12-21T13:33:44.753438Z ERROR lorax_launcher: Shard 0 failed to start
2024-12-21T13:33:44.753471Z INFO lorax_launcher: Shutting down shards
Error: ShardCannotStart