Refactor to prepare for parallel sampling #100
Conversation
# entries for the older tokens.
if (
    len(self.current_batch) == 0
    and num_tokens > self.max_num_batched_tokens
Previously this line was `num_new_batched_tokens > self.max_num_batched_tokens`, but I believe the lhs should be `num_tokens`. Can you confirm? @elvin-n
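For context, a minimal sketch of the guard as I read it (the `try_schedule` and `cancel_request` names are hypothetical; only `current_batch` and `max_num_batched_tokens` come from the snippet above): with an empty batch, a request whose total token count already exceeds the per-batch token budget can never be scheduled, so it has to be rejected rather than deferred.

```python
def try_schedule(self, state, num_tokens: int) -> bool:
    # num_tokens is the request's total token count, not just the newly
    # batched tokens. If even an empty batch cannot hold it, the request
    # can never be scheduled, so drop it instead of deferring it.
    if len(self.current_batch) == 0 and num_tokens > self.max_num_batched_tokens:
        self.cancel_request(state.request_id)  # hypothetical helper
        return False
    return True
```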
output_text: str
is_finished: bool = False


@dataclass
class RequestState:
Now this class basically acts like `SequenceGroup` in vllm.
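For readers unfamiliar with vLLM: `SequenceGroup` bundles one request with all of its sampled sequences. A rough sketch of `RequestState` playing that role for parallel sampling; apart from the names visible in the diffs (`output_text`, `is_finished`, `generated_token_ids`, `prompt_token_ids`), the fields below are assumptions, not the actual definition.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class GenerationSequence:
    generated_token_ids: List[int]
    output_text: str
    is_finished: bool = False


@dataclass
class RequestState:
    # Plays the role of vLLM's SequenceGroup: one request owns several
    # generation sequences when parallel sampling (n > 1) is requested.
    request_id: str
    prompt_token_ids: List[int]
    generation_sequences: List[GenerationSequence] = field(default_factory=list)

    @property
    def is_finished(self) -> bool:
        return all(seq.is_finished for seq in self.generation_sequences)
```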
Force-pushed from e0770c3 to bdb0be3.
prefix_idx = generation_sequence.next_start_position

# TODO(masahi): Figure out a way to remove this concat
token_ids = prompt_tokens + generation_sequence.generated_token_ids
This list concat is probably causing a perf regression. Improving this is left for future work.
There should be very few tokens passed into `decode_last_output`; copying them into a new container and concatenating the new token is unlikely to cause a perf degradation.
if not gen_seq.is_finished:
    # TODO(masahi): No need to add prompt_token_ids here if we send
    # the prompt len instead
    token_ids = state.prompt_token_ids + gen_seq.generated_token_ids
This list concat is probably causing a perf regression. Improving this is left for future work. We need to update `paged_cache_model.py` if we remove prompt tokens here.
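A sketch of the direction the TODO points at (the `DecodeRequest` name and its fields are assumptions for illustration, not the actual structures): if the model side receives the prompt length instead of the prompt token ids, only the short list of generated ids needs to be passed per step and the concat disappears.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class DecodeRequest:
    prompt_len: int                 # length only; prompt tokens stay with the KV cache
    generated_token_ids: List[int]  # small list, grows by one per decode step

    @property
    def num_total_tokens(self) -> int:
        # paged_cache_model.py would use this instead of
        # len(prompt_token_ids + generated_token_ids)
        return self.prompt_len + len(self.generated_token_ids)
```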
Thank you @masahi for the refactoring! Would you try running `serve/tests/unittest/test_engine_with_samplers.py`? I'm getting errors with the Prometheus metric and the finish reason for the stopping condition.
"Preempt request to free %s tokens", | ||
len(request_to_remove.token_ids), | ||
) | ||
self.evict_request() | ||
|
||
if self.cache_manager.get_max_new_tokens() <= self.max_decode_steps: |
Is this why we cannot move `_adjust_batch` to `engine_common.py`?
No, in the sync engine, there are calls to `self._discard_cancelled_requests_from_queue()` that are not in the staging engine.
https://github.com/octoml/mlc-llm/blob/batch-serving/serve/mlc_serve/engine/sync_engine.py#L263
https://github.com/octoml/mlc-llm/blob/batch-serving/serve/mlc_serve/engine/sync_engine.py#L316
Interesting. Is there anything sync-engine-specific about this method? It seems the cancellation logic is a bit different. If the answer is no, maybe a follow-up PR can unify this further.
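One possible shape for that follow-up, sketched under the assumption that the only real difference is the queue cleanup (the `adjust_batch_common` name and the hook signature are hypothetical): make the sync-engine-specific call an optional callback so the shared logic can live in `engine_common.py`.

```python
from typing import Callable, Optional


def adjust_batch_common(
    engine,
    discard_cancelled: Optional[Callable[[], None]] = None,
) -> None:
    """Hypothetical shared _adjust_batch body for engine_common.py.

    The sync engine would pass engine._discard_cancelled_requests_from_queue
    as the hook; the staging engine would pass None.
    """
    if discard_cancelled is not None:
        discard_cancelled()
    # ... the batch-adjustment / preemption logic shared by both engines ...
```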
return True


class EngineBase:
Would it make sense to move this base class to `engine/base.py`? Also, I'm wondering if further unification with `InferenceEngine` and `ScopedInferenceEngine` is possible.
https://github.com/octoml/mlc-llm/blob/batch-serving/serve/mlc_serve/engine/base.py#L180
`engine/base.py` is mostly declarations of core data structures, while `engine_common.py` is specifically for sharing common code between the two engines. So `EngineBase` belongs to the latter.
`InferenceEngine` and `ScopedInferenceEngine` are interface classes. I don't understand what you mean by "unifying" them with `EngineBase`.
Ah, sorry. What I meant is that maybe we can unify `InferenceEngine` and `ScopedInferenceEngine`. I guess it is possible by letting the sync engine do nothing for `start` and `stop`. My tiny complaint is that we have many similar-looking classes with small differences that might confuse newcomers.
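A sketch of that unification (this is not the actual interface, and method names other than `start` and `stop` are illustrative): give `InferenceEngine` default no-op `start`/`stop` that the sync engine inherits unchanged and the staging engine overrides, so a separate `ScopedInferenceEngine` is no longer needed.

```python
from abc import ABC, abstractmethod


class InferenceEngine(ABC):
    def start(self) -> None:
        # No-op by default; the staging engine overrides this to spawn its worker.
        pass

    def stop(self) -> None:
        # No-op by default; the staging engine overrides this to shut the worker down.
        pass

    @abstractmethod
    def add(self, requests) -> None:
        ...

    @abstractmethod
    def step(self):
        ...
```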
Force-pushed from f73a430 to b16f787.
They are fixed now.
@elvin-n Fixed.
Force-pushed from 159ca14 to 242e3de.
The main changes are in `engine_common.py`. There is no functional change, but since the change is big, I request a careful review. @sunggg @elvin-n