
Refactor to prepare for parallel sampling #100

Merged (32 commits into octoml:batch-serving on Dec 11, 2023)

Conversation

masahi (Member) commented Dec 7, 2023

The main changes are

  • Change various data structures to associate a request with multiple generated sequences
  • Extract common bits in sync and staging engines into engine_common.py

There is no functional change, but since the change is big, I request a careful review.

@sunggg @elvin-n

# entries for the older tokens.
if (
    len(self.current_batch) == 0
    and num_tokens > self.max_num_batched_tokens
masahi (Member, Author):

Previously this line was

num_new_batched_tokens > self.max_num_batched_tokens

but I believe the lhs should be num_tokens. Can you confirm? @elvin-n
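
For context, a minimal standalone sketch of the guard under discussion, using the names visible in the excerpt above (current_batch, num_tokens, max_num_batched_tokens); the helper function itself is hypothetical, not code from this PR:

def can_ever_schedule(
    current_batch_size: int, num_tokens: int, max_num_batched_tokens: int
) -> bool:
    # A request whose own token count already exceeds the per-batch budget
    # while the batch is empty can never be scheduled, so the engine has to
    # reject it up front. This is why the comparison should use num_tokens
    # (the request's total token count), not the running batch counter.
    return not (current_batch_size == 0 and num_tokens > max_num_batched_tokens)

# Example: a 5000-token request against a 4096-token budget and an empty batch.
assert not can_ever_schedule(current_batch_size=0, num_tokens=5000,
                             max_num_batched_tokens=4096)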

    output_text: str
    is_finished: bool = False


@dataclass
class RequestState:
masahi (Member, Author):

Now this class basically acts like SequenceGroup in vLLM.
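
A rough sketch of the shape being described; field names that do not appear in this PR's excerpts (request_id, generation_sequences) are assumptions:

from dataclasses import dataclass, field
from typing import List

@dataclass
class GenerationSequence:
    # One generated sequence of a request (illustrative sketch).
    generated_token_ids: List[int] = field(default_factory=list)
    next_start_position: int = 0
    output_text: str = ""
    is_finished: bool = False

@dataclass
class RequestState:
    # Analogous to vLLM's SequenceGroup: one request that owns several
    # parallel-sampled sequences.
    request_id: str
    prompt_token_ids: List[int]
    generation_sequences: List[GenerationSequence] = field(default_factory=list)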

masahi force-pushed the parallel-sampling branch from e0770c3 to bdb0be3 on December 8, 2023 00:07
masahi marked this pull request as draft on December 8, 2023 00:14
masahi marked this pull request as ready for review on December 8, 2023 05:44
prefix_idx = generation_sequence.next_start_position

# TODO(masahi): Figure out a way to remove this concat
token_ids = prompt_tokens + generation_sequence.generated_token_ids
masahi (Member, Author):

This list concat is probably causing perf regression. Improving this is left for future work.

elvin-n:

Only a few tokens should be passed into decode_last_output; copying them into a new container and concatenating the new token is unlikely to cause a perf regression.

sunggg (Member):

PR #102 changes this part and actually adds more concats. I also agree with @elvin-n, but it would be good to collect data and confirm.
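
One quick way to collect such data outside the engine is a micro-benchmark along these lines (a standalone sketch; the token counts are made up):

import timeit

prompt_tokens = list(range(2048))          # pretend prompt
generated_token_ids = list(range(256))     # pretend generated tokens so far

# Time the per-step list concat that the TODO above refers to.
per_call_s = timeit.timeit(
    "prompt_tokens + generated_token_ids",
    globals={"prompt_tokens": prompt_tokens,
             "generated_token_ids": generated_token_ids},
    number=10_000,
) / 10_000
print(f"list concat: {per_call_s * 1e6:.2f} us per call")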

if not gen_seq.is_finished:
    # TODO(masahi): No need to add prompt_token_ids here if we send
    # the prompt len instead
    token_ids = state.prompt_token_ids + gen_seq.generated_token_ids
masahi (Member, Author):

This list concat is probably causing perf regression. Improving this is left for future work.
We need to update paged_cache_model.py if we remove prompt tokens here.
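
To illustrate the direction of the TODO, a hedged sketch of passing the prompt length instead of the full prompt token list; the DecodeInput name and the corresponding change in paged_cache_model.py are assumptions, not part of this PR:

from dataclasses import dataclass
from typing import List

@dataclass
class DecodeInput:
    # Ship only the prompt length plus the generated tokens, instead of
    # rebuilding prompt_token_ids + generated_token_ids on every step.
    prompt_len: int
    generated_token_ids: List[int]

    @property
    def total_len(self) -> int:
        return self.prompt_len + len(self.generated_token_ids)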

sunggg (Member) left a comment:

Thank you @masahi for the refactoring!
Would you try running serve/tests/unittest/test_engine_with_samplers.py? I'm getting errors with the Prometheus metric and the finish reason for the stopping condition.

"Preempt request to free %s tokens",
len(request_to_remove.token_ids),
)
self.evict_request()

if self.cache_manager.get_max_new_tokens() <= self.max_decode_steps:
sunggg (Member):

Is this why we cannot move _adjust_batch to engine_common.py?

masahi (Member, Author):

No, in the sync engine, there are calls to self._discard_cancelled_requests_from_queue() that are not in the staging engine.

https://github.com/octoml/mlc-llm/blob/batch-serving/serve/mlc_serve/engine/sync_engine.py#L263
https://github.com/octoml/mlc-llm/blob/batch-serving/serve/mlc_serve/engine/sync_engine.py#L316

sunggg (Member):

Interesting. Is there anything sync-engine-specific about this method? It seems the cancellation logic is a bit different. If not, maybe a follow-up PR can unify this further.
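
Purely as an illustration of what such a follow-up unification could look like (all names below are hypothetical, not code from this PR):

class EngineCommonMixin:
    # Shared batching logic with a hook that the sync engine overrides and the
    # staging engine leaves as a no-op.
    def adjust_batch(self) -> None:
        self._discard_cancelled_requests_from_queue()
        # ... shared eviction / batch-growing logic would follow here ...

    def _discard_cancelled_requests_from_queue(self) -> None:
        pass  # default: nothing to do (staging-engine behaviour)

class SyncEngineSketch(EngineCommonMixin):
    def __init__(self) -> None:
        self.queue: list = []
        self.cancelled_ids: set = set()

    def _discard_cancelled_requests_from_queue(self) -> None:
        # Sync-engine-specific: drop cancelled requests inline.
        self.queue = [r for r in self.queue if r not in self.cancelled_ids]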

(resolved review thread on serve/mlc_serve/engine/engine_common.py)
return True


class EngineBase:
sunggg (Member):

Would it make sense to move this base class to engine/base.py?
Also, I'm wondering if further unification with InferenceEngine and ScopedInferenceEngine is possible.

https://github.com/octoml/mlc-llm/blob/batch-serving/serve/mlc_serve/engine/base.py#L180

masahi (Member, Author):

engine/base.py is mostly declarations of core data structures while engine_common.py is specifically for sharing common code between the two engines. So EngineBase belongs to the latter.

InferenceEngine and ScopedInferenceEngine are interface classes. I don't understand what you mean by "unifying" them with EngineBase.

sunggg (Member):

Ah, sorry. What I meant is that maybe we can unify InferenceEngine and ScopedInferenceEngine. I guess it is possible by letting the sync engine do nothing for start and stop. My tiny complaint is that we have many similar-looking classes with small differences, which might confuse newcomers.
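
To make the idea concrete, a hypothetical sketch (not code from this repo) of a single interface whose start/stop default to no-ops, so a sync engine inherits them unchanged while a staging engine overrides them:

from abc import ABC, abstractmethod

class UnifiedInferenceEngine(ABC):
    def start(self) -> None:
        pass  # no-op by default; a sync engine needs no background worker

    def stop(self) -> None:
        pass

    @abstractmethod
    def step(self) -> None:
        ...

class SyncEngineExample(UnifiedInferenceEngine):
    def step(self) -> None:
        ...  # run one synchronous generation step

class StagingEngineExample(UnifiedInferenceEngine):
    def start(self) -> None:
        ...  # spawn the generation loop worker here

    def stop(self) -> None:
        ...  # join/terminate the worker here

    def step(self) -> None:
        ...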

masahi (Member, Author) commented Dec 11, 2023

> Would you try running serve/tests/unittest/test_engine_with_samplers.py? I'm getting errors with Prometheus metric and finish reason for stopping condition.

They are fixed now.

elvin-n commented Dec 11, 2023

python -m pytest serve/tests

================================================================================== short test summary info ===================================================================================
FAILED serve/tests/unittest/test_staging_engine.py::test_single_request - RuntimeError: Error when calling GenerationLoopWorker: 'DummyCacheManager' object has no attribute 'free_request'
FAILED serve/tests/unittest/test_staging_engine.py::test_single_request_step_to_finish - RuntimeError: Error when calling GenerationLoopWorker: 'DummyCacheManager' object has no attribute 'free_request'
FAILED serve/tests/unittest/test_staging_engine.py::test_multiple_requests_wait_queue - RuntimeError: Error when calling GenerationLoopWorker: 'DummyCacheManager' object has no attribute 'free_request'
FAILED serve/tests/unittest/test_staging_engine.py::test_multiple_requests_preempt - RuntimeError: Error when calling GenerationLoopWorker: 'DummyCacheManager' object has no attribute 'free_request'
FAILED serve/tests/unittest/test_staging_engine.py::test_cache_evict_hang_staging - RuntimeError: Error when calling GenerationLoopWorker: 'DummyCacheManager' object has no attribute 'free_request'
FAILED serve/tests/unittest/test_staging_engine.py::test_big_prompt_fit_to_cache_staging - RuntimeError: Error when calling GenerationLoopWorker: 'DummyCacheManager' object has no attribute 'free_request'
FAILED serve/tests/unittest/test_sync_engine.py::test_single_request_step_to_finish - AttributeError: 'DummyCacheManager' object has no attribute 'free_request'
FAILED serve/tests/unittest/test_sync_engine.py::test_multiple_requests_wait_queue - AttributeError: 'DummyCacheManager' object has no attribute 'free_request'
FAILED serve/tests/unittest/test_sync_engine.py::test_multiple_requests_preempt - AttributeError: 'DummyCacheManager' object has no attribute 'free_request'
FAILED serve/tests/unittest/test_sync_engine.py::test_cache_evict_hang - AttributeError: 'DummyCacheManager' object has no attribute 'free_request'
FAILED serve/tests/unittest/test_sync_engine.py::test_big_prompt_fit_to_cache - AttributeError: 'DummyCacheManager' object has no attribute 'free_request'
========================================================================= 11 failed, 6 passed, 3 warnings in 10.30s ==========================================================================
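
The failures say the DummyCacheManager test double lacks the new free_request hook. A minimal sketch of the kind of stub that would satisfy it (the exact signature is an assumption, not taken from the repo):

class DummyCacheManager:
    # Test-double sketch; only the missing hook is shown.
    def free_request(self, state) -> None:
        # A real cache manager would release the KV-cache blocks held by every
        # sequence of the request; the dummy has nothing to free.
        pass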

masahi (Member, Author) commented Dec 11, 2023

@elvin-n Fixed.

(resolved review thread on serve/mlc_serve/model/paged_cache_model.py)
(resolved review thread on serve/mlc_serve/engine/staging_engine.py)
sunggg (Member) left a comment:

Thank you @masahi for the refactoring and @elvin-n for the review!

sunggg merged commit 745ce71 into octoml:batch-serving on Dec 11, 2023 (1 check passed).