fixed mypy
masahi committed Dec 7, 2023
1 parent 1c983df commit 389d999
Showing 3 changed files with 16 additions and 18 deletions.
serve/mlc_serve/engine/engine_common.py (4 changes: 2 additions & 2 deletions)

@@ -115,8 +115,8 @@ def update_sequence(

 def get_requests_to_process(
     current_states: list[RequestState], cache_manager: KVCacheManager
-) -> Tuple[Union[list[PrefillRequest], list[DecodeRequest]], bool]:
-    requests = []
+) -> Tuple[list[Union[PrefillRequest, DecodeRequest]], bool]:
+    requests : list[Union[PrefillRequest, DecodeRequest]] = []
     # TODO: consider having hybrid batch if the underlying attention kernel supports
     # mixing prefill and decode.
     is_prompt_batch = any(
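The return-annotation change above is the substance of this fix: the old Tuple[Union[list[PrefillRequest], list[DecodeRequest]], bool] promised a list holding only one kind of request, while the function builds a single list that may contain both, and the unannotated requests = [] gives mypy nothing to infer from. A minimal standalone sketch of the annotated pattern, using empty stand-in classes rather than the engine's real request types:

from typing import Tuple, Union

class PrefillRequest: ...   # stand-in for the real request class
class DecodeRequest: ...    # stand-in for the real request class

def get_mixed_requests() -> Tuple[list[Union[PrefillRequest, DecodeRequest]], bool]:
    # Annotating the empty list avoids mypy's 'Need type annotation for
    # "requests"' error and lets both request kinds be appended to it.
    requests: list[Union[PrefillRequest, DecodeRequest]] = []
    requests.append(PrefillRequest())
    requests.append(DecodeRequest())
    return requests, True

Because list is invariant, list[Union[PrefillRequest, DecodeRequest]] is also not interchangeable with Union[list[PrefillRequest], list[DecodeRequest]], which is presumably why the return annotation had to change together with the variable annotation.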
serve/mlc_serve/engine/sync_engine.py (28 changes: 13 additions & 15 deletions)

@@ -4,7 +4,7 @@

 import logging
 from collections import defaultdict
-from typing import Set
+from typing import Set, DefaultDict

 from .base import (
     FinishReason,
@@ -108,19 +108,17 @@ def step(self) -> InferenceStepResult:
                 finish_reason = FinishReason.Length

             if finish_reason is not None:
-                seq_outputs = [
-                    SequenceOutput(
-                        i,
-                        finish_reason=finish_reason,
-                        num_generated_tokens=len(gen_seq.generated_token_ids),
-                    )
-                    for i, gen_seq in enumerate(state.generation_sequences)
-                ]
-
                 outputs.append(
                     RequestOutput(
                         state.request_id,
-                        seq_outputs,
+                        [
+                            SequenceOutput(
+                                i,
+                                finish_reason=finish_reason,
+                                num_generated_tokens=len(gen_seq.generated_token_ids),
+                            )
+                            for i, gen_seq in enumerate(state.generation_sequences)
+                        ],
                         num_prompt_tokens=state.prompt_len,
                     )
                 )
@@ -158,7 +156,7 @@ def step(self) -> InferenceStepResult:
             return InferenceStepResult(outputs)

         requests, _ = get_requests_to_process(
-            self.current_batch.values(), self.cache_manager
+            list(self.current_batch.values()), self.cache_manager
         )
         results = self.text_generator.generate(requests, self.cache_manager.get_cache())
         logger.debug("Finished text generation.")
@@ -182,7 +180,7 @@ def step(self) -> InferenceStepResult:
             else:
                 valid_results.append(res)

-        seq_outputs = defaultdict(list)
+        seq_outputs: DefaultDict[RequestId, list[SequenceOutput]] = defaultdict(list)

         for res in valid_results:
             request_id = res.sequence_id.request_id
@@ -216,12 +214,12 @@ def step(self) -> InferenceStepResult:
                 )
             )

-        for request_id, seq_outputs in seq_outputs.items():
+        for request_id, out_seqs in seq_outputs.items():
             state = self.current_batch[request_id]
             outputs.append(
                 RequestOutput(
                     request_id,
-                    sequences=seq_outputs,
+                    sequences=out_seqs,
                     num_prompt_tokens=state.prompt_len,
                 )
             )
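The sync_engine.py changes are of the same flavor: dict.values() returns a view rather than the list[RequestState] that get_requests_to_process declares, so the call site wraps it in list(...); the empty defaultdict(list) gets an explicit DefaultDict annotation; inlining the earlier list comprehension removes one reuse of the name seq_outputs; and the final loop stops rebinding that name to a value of a different type. A small self-contained sketch of the dictionary-related errors being avoided, using placeholder names and types rather than the engine's classes:

from collections import defaultdict
from typing import DefaultDict

def takes_list(states: list[int]) -> None:
    print(states)

batch = {1: 10, 2: 20}
# batch.values() is a dict view, not a list, so mypy rejects passing it
# directly to a parameter annotated list[int]; list(...) converts it.
takes_list(list(batch.values()))

# Without this annotation mypy reports: Need type annotation for "by_request"
by_request: DefaultDict[int, list[str]] = defaultdict(list)
by_request[1].append("token")

# Using a fresh loop variable avoids rebinding 'by_request' (a DefaultDict)
# to one of its list values, which mypy flags as an incompatible assignment.
for request_id, outs in by_request.items():
    print(request_id, outs)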
serve/mlc_serve/model/paged_cache_model.py (2 changes: 1 addition & 1 deletion)

@@ -661,7 +661,7 @@ def generate(
             outputs.append(
                 TextGenerationResult(
                     sequence_id=sequence_id,
-                    generated_tokens=[maybe_new_token[0]],
+                    generated_tokens=[maybe_new_token[0]], # type: ignore
                     error=None,
                 )
             )
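The last change suppresses the checker rather than refactoring: the diff does not show how maybe_new_token is typed in paged_cache_model.py, but the targeted # type: ignore suggests mypy could not prove the [0] index safe, for example when the value is an Optional whose None check happens elsewhere. A generic sketch of that kind of situation, with made-up names:

from typing import Optional

def maybe_sample() -> Optional[list[int]]:
    # Placeholder for whatever produces maybe_new_token in the real code.
    return [42]

maybe_new_token = maybe_sample()
# If the None check lives in another branch or helper, mypy still sees
# Optional[list[int]] here and reports that the value is not indexable;
# a targeted ignore silences the false positive without restructuring.
first_token = [maybe_new_token[0]]  # type: ignore
print(first_token)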
