Quantization appears to be broken, at least for AWQ and BnB #722

Open
codybum opened this issue Dec 21, 2024 · 5 comments

codybum commented Dec 21, 2024

System Info

I have tried the following Lorax versions:
(official version) ghcr.io/predibase/lorax:0.12
(locally compiled) lorax:69bb989

CUDA:
12.4
12.6

Information

  • Docker
  • The CLI directly

Tasks

  • An officially supported command
  • My own modifications

Reproduction

  1. Obtain an AWQ- or BnB-quantized model, such as https://huggingface.co/unsloth/Meta-Llama-3.1-8B-bnb-4bit or https://huggingface.co/hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4
  2. Using otherwise default settings, configure Lorax to use a quantized model

Expected behavior

I would expect Lorax to start normally, as it does with a non-quantized model. However, I have tried several AWQ and BnB 4-bit models, multiple versions of Lorax, and different CUDA/driver combinations, and all of them fail.

Example errors (also seen in other issues):
AWQ: "ValueError: too many values to unpack (expected 3)"
BnB: "AssertionError: [12582912, 1] != [6144, 4096]"

Full AWQ Error:

2024-12-21T13:26:16.679407Z INFO shard-manager: lorax_launcher: Starting shard rank=0
2024-12-21T13:26:25.250809Z ERROR lorax_launcher: server.py:317 Error when initializing model
Traceback (most recent call last):
File "/opt/conda/bin/lorax-server", line 8, in
sys.exit(app())
File "/opt/conda/lib/python3.10/site-packages/typer/main.py", line 311, in call
return get_command(self)(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 1157, in call
return self.main(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/typer/core.py", line 778, in main
return _main(
File "/opt/conda/lib/python3.10/site-packages/typer/core.py", line 216, in _main
rv = self.invoke(ctx)
File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 1688, in invoke
return _process_result(sub_ctx.command.invoke(sub_ctx))
File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 1434, in invoke
return ctx.invoke(self.callback, **ctx.params)
File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 783, in invoke
return __callback(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/typer/main.py", line 683, in wrapper
return callback(**use_params) # type: ignore
File "/opt/conda/lib/python3.10/site-packages/lorax_server/cli.py", line 92, in serve
server.serve(
File "/opt/conda/lib/python3.10/site-packages/lorax_server/server.py", line 449, in serve
asyncio.run(
File "/opt/conda/lib/python3.10/asyncio/runners.py", line 44, in run
return loop.run_until_complete(main)
File "/opt/conda/lib/python3.10/asyncio/base_events.py", line 636, in run_until_complete
self.run_forever()
File "/opt/conda/lib/python3.10/asyncio/base_events.py", line 603, in run_forever
self._run_once()
File "/opt/conda/lib/python3.10/asyncio/base_events.py", line 1909, in _run_once
handle._run()
File "/opt/conda/lib/python3.10/asyncio/events.py", line 80, in _run
self._context.run(self._callback, *self._args)

File "/opt/conda/lib/python3.10/site-packages/lorax_server/server.py", line 302, in serve_inner
model = get_model(
File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/init.py", line 186, in get_model
return FlashLlama(
File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/flash_llama.py", line 40, in init
super().init(
File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/flash_causal_lm.py", line 1119, in init
model = model_cls(prefix, config, weights)
File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 570, in init
self.model = FlashLlamaModel(prefix, config, weights, create_layer_fn)
File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 499, in init
[
File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 500, in
create_layer_fn(
File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 431, in init
self.self_attn = FlashLlamaAttention(
File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 265, in init
self.query_key_value = load_attention(config, prefix, weights, layer_id)
File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 162, in load_attention
base_layer = load_attention_multi(config, prefix, weights)
File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 179, in load_attention_multi
return _load_gqa(config, prefix, weights)
File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 202, in _load_gqa
weight, input_scale, weight_scale = weight
ValueError: too many values to unpack (expected 3)

2024-12-21T13:26:26.400592Z ERROR shard-manager: lorax_launcher: Shard complete standard error output:

2024-12-21 13:26:20.746 | INFO | lorax_server.utils.state::22 - Backend = fa2
2024-12-21 13:26:20.746 | INFO | lorax_server.utils.state::24 - Prefix caching = False
2024-12-21 13:26:20.746 | INFO | lorax_server.utils.state::25 - Chunked prefill = False
/opt/conda/lib/python3.10/site-packages/torch/distributed/c10d_logger.py:79: FutureWarning: You are using a Backend <class 'lorax_server.utils.dist.FakeGroup'> as a ProcessGroup. This usage is deprecated since PyTorch 2.0. Please use a public API of PyTorch Distributed instead.
return func(*args, **kwargs)
Traceback (most recent call last):

File "/opt/conda/bin/lorax-server", line 8, in
sys.exit(app())

File "/opt/conda/lib/python3.10/site-packages/lorax_server/cli.py", line 92, in serve
server.serve(

File "/opt/conda/lib/python3.10/site-packages/lorax_server/server.py", line 449, in serve
asyncio.run(

File "/opt/conda/lib/python3.10/asyncio/runners.py", line 44, in run
return loop.run_until_complete(main)

File "/opt/conda/lib/python3.10/asyncio/base_events.py", line 649, in run_until_complete
return future.result()

File "/opt/conda/lib/python3.10/site-packages/lorax_server/server.py", line 302, in serve_inner
model = get_model(

File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/init.py", line 186, in get_model
return FlashLlama(

File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/flash_llama.py", line 40, in init
super().init(

File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/flash_causal_lm.py", line 1119, in init
model = model_cls(prefix, config, weights)

File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 570, in init
self.model = FlashLlamaModel(prefix, config, weights, create_layer_fn)

File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 499, in init
[

File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 500, in
create_layer_fn(

File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 431, in init
self.self_attn = FlashLlamaAttention(

File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 265, in init
self.query_key_value = load_attention(config, prefix, weights, layer_id)

File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 162, in load_attention
base_layer = load_attention_multi(config, prefix, weights)

File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 179, in load_attention_multi
return _load_gqa(config, prefix, weights)

File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 202, in _load_gqa
weight, input_scale, weight_scale = weight

ValueError: too many values to unpack (expected 3)
rank=0
2024-12-21T13:26:26.485424Z ERROR lorax_launcher: Shard 0 failed to start
2024-12-21T13:26:26.485464Z INFO lorax_launcher: Shutting down shards
Error: ShardCannotStart

Full BnB Error:
2024-12-21T13:33:35.146822Z INFO shard-manager: lorax_launcher: Starting shard rank=0
2024-12-21T13:33:43.387125Z ERROR lorax_launcher: server.py:317 Error when initializing model
Traceback (most recent call last):
File "/opt/conda/bin/lorax-server", line 8, in
sys.exit(app())
File "/opt/conda/lib/python3.10/site-packages/typer/main.py", line 311, in call
return get_command(self)(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 1157, in call
return self.main(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/typer/core.py", line 778, in main
return _main(
File "/opt/conda/lib/python3.10/site-packages/typer/core.py", line 216, in _main
rv = self.invoke(ctx)
File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 1688, in invoke
return _process_result(sub_ctx.command.invoke(sub_ctx))
File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 1434, in invoke
return ctx.invoke(self.callback, **ctx.params)
File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 783, in invoke
return __callback(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/typer/main.py", line 683, in wrapper
return callback(**use_params) # type: ignore
File "/opt/conda/lib/python3.10/site-packages/lorax_server/cli.py", line 92, in serve
server.serve(
File "/opt/conda/lib/python3.10/site-packages/lorax_server/server.py", line 449, in serve
asyncio.run(
File "/opt/conda/lib/python3.10/asyncio/runners.py", line 44, in run
return loop.run_until_complete(main)
File "/opt/conda/lib/python3.10/asyncio/base_events.py", line 636, in run_until_complete
self.run_forever()
File "/opt/conda/lib/python3.10/asyncio/base_events.py", line 603, in run_forever
self._run_once()
File "/opt/conda/lib/python3.10/asyncio/base_events.py", line 1909, in _run_once
handle._run()
File "/opt/conda/lib/python3.10/asyncio/events.py", line 80, in _run
self._context.run(self._callback, *self._args)

File "/opt/conda/lib/python3.10/site-packages/lorax_server/server.py", line 302, in serve_inner
model = get_model(
File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/init.py", line 186, in get_model
return FlashLlama(
File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/flash_llama.py", line 40, in init
super().init(
File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/flash_causal_lm.py", line 1119, in init
model = model_cls(prefix, config, weights)
File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 570, in init
self.model = FlashLlamaModel(prefix, config, weights, create_layer_fn)
File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 499, in init
[
File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 500, in
create_layer_fn(
File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 431, in init
self.self_attn = FlashLlamaAttention(
File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 265, in init
self.query_key_value = load_attention(config, prefix, weights, layer_id)
File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 162, in load_attention
base_layer = load_attention_multi(config, prefix, weights)
File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 179, in load_attention_multi
return _load_gqa(config, prefix, weights)
File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 210, in _load_gqa
assert list(weight.shape) == [
AssertionError: [12582912, 1] != [6144, 4096]

2024-12-21T13:33:44.667486Z ERROR shard-manager: lorax_launcher: Shard complete standard error output:

2024-12-21 13:33:39.687 | INFO | lorax_server.utils.state::22 - Backend = fa2
2024-12-21 13:33:39.687 | INFO | lorax_server.utils.state::24 - Prefix caching = False
2024-12-21 13:33:39.687 | INFO | lorax_server.utils.state::25 - Chunked prefill = False
/opt/conda/lib/python3.10/site-packages/torch/distributed/c10d_logger.py:79: FutureWarning: You are using a Backend <class 'lorax_server.utils.dist.FakeGroup'> as a ProcessGroup. This usage is deprecated since PyTorch 2.0. Please use a public API of PyTorch Distributed instead.
return func(*args, **kwargs)
Traceback (most recent call last):

File "/opt/conda/bin/lorax-server", line 8, in
sys.exit(app())

File "/opt/conda/lib/python3.10/site-packages/lorax_server/cli.py", line 92, in serve
server.serve(

File "/opt/conda/lib/python3.10/site-packages/lorax_server/server.py", line 449, in serve
asyncio.run(

File "/opt/conda/lib/python3.10/asyncio/runners.py", line 44, in run
return loop.run_until_complete(main)

File "/opt/conda/lib/python3.10/asyncio/base_events.py", line 649, in run_until_complete
return future.result()

File "/opt/conda/lib/python3.10/site-packages/lorax_server/server.py", line 302, in serve_inner
model = get_model(

File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/init.py", line 186, in get_model
return FlashLlama(

File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/flash_llama.py", line 40, in init
super().init(

File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/flash_causal_lm.py", line 1119, in init
model = model_cls(prefix, config, weights)

File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 570, in init
self.model = FlashLlamaModel(prefix, config, weights, create_layer_fn)

File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 499, in init
[

File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 500, in
create_layer_fn(

File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 431, in init
self.self_attn = FlashLlamaAttention(

File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 265, in init
self.query_key_value = load_attention(config, prefix, weights, layer_id)

File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 162, in load_attention
base_layer = load_attention_multi(config, prefix, weights)

File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 179, in load_attention_multi
return _load_gqa(config, prefix, weights)

File "/opt/conda/lib/python3.10/site-packages/lorax_server/models/custom_modeling/flash_llama_modeling.py", line 210, in _load_gqa
assert list(weight.shape) == [

AssertionError: [12582912, 1] != [6144, 4096]
rank=0
2024-12-21T13:33:44.753438Z ERROR lorax_launcher: Shard 0 failed to start
2024-12-21T13:33:44.753471Z INFO lorax_launcher: Shutting down shards
Error: ShardCannotStart

codybum commented Dec 21, 2024

The issue appears to be related to #595 and #607, both of which were closed with no clear resolution.

codybum commented Dec 21, 2024

Here is another related case, #611, referencing container ghcr.io/predibase/lorax:07addea, which I could get to work with https://huggingface.co/hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4, but that image is very old.

codybum commented Dec 21, 2024

The problem appears to be here:

if quantize in ["gptq", "awq"]:
...
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)"
...
else:
...
return weight, input_scale, weight_scale
weight = torch.cat(weight_list, dim=dim)
...

weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)

"
weight = torch.cat(weight_list, dim=dim)

Weight can have a different number of parameters, which is not accounted for here:

weight, input_scale, weight_scale = weight
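
A rough sketch of the kind of guard that could tolerate the different return shapes (the helper name _unpack_qkv_weight and its branches are my own illustration, not LoRAX code):

def _unpack_qkv_weight(weight):
    # Normalize whatever the weight loader returned into the
    # (weight, input_scale, weight_scale) triple that _load_gqa expects.
    if isinstance(weight, tuple) and len(weight) == 3:
        # fp8-style return: (weight, input_scale, weight_scale)
        return weight
    # gptq/awq-style return: packed tuple such as
    # (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama),
    # or a plain tensor (e.g. bitsandbytes, which quantizes at layer init)
    return weight, None, None

# Example: a 7-element AWQ-style tuple no longer raises "too many values to unpack"
weight, input_scale, weight_scale = _unpack_qkv_weight(("qw", "qz", "sc", "gi", 4, 128, False))
print(input_scale, weight_scale)  # None None

The pre-packed BnB shape assertion would still need separate handling further down in _load_gqa.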

arnavgarg1 (Contributor) commented

Hey @codybum ! Thanks for the investigation. Would you be willing to contribute to LoRAX by pushing up a fix with your suggested changes?

codybum commented Jan 3, 2025

I see where it is breaking, but not necessarily how to fix it. I will look into this further and see if I can get something working.
