Refactor to prepare for parallel sampling (#100)
* wip

* wip

* wip

* fix

* fix

* fix

* refactor

* more refactor

* wip

* wip

* more refactor

* more refactor

* fixed

* fixed mypy

* minor

* msg clean

* fix missing finish_reason

* remove unnecessary type annot on defaultdict

* Return requests state from get_requests_to_process

* simplify typing

* reduced list concat

* remove dict add and lookup

* wrong comment

* Revert "remove dict add and lookup"

This reverts commit 5382004.

* fix sampler test

* make it possible to disable prometheus metrics

* collect metrics only in staging engine

* return False in stop_by_length if request is already finished

* move check_stopping_sequences to engine_common.py

* add missing free_request method to Dummy cache manager

* update Dummy cache manager to operate on sequence

* fix request finish condition
masahi authored Dec 11, 2023
1 parent 5aaf55d commit 745ce71
Showing 12 changed files with 745 additions and 629 deletions.
4 changes: 3 additions & 1 deletion serve/mlc_serve/engine/__init__.py
@@ -11,6 +11,8 @@
RequestOutput,
StoppingCriteria,
MLCServeEngineConfig,
get_engine_config
get_engine_config,
SequenceId,
RequestState,
)
from .sampling_params import SamplingParams, SamplingType
85 changes: 59 additions & 26 deletions serve/mlc_serve/engine/base.py
@@ -11,6 +11,7 @@

RequestId = str


# TODO(@sunggg): consider transition to something like Pydantic
@dataclass
class MLCServeEngineConfig:
@@ -39,6 +40,7 @@ def _from_json(config_cls, json_obj: Dict[Any, Any]):
}
)


def get_engine_config(dict_config):
engine_config = MLCServeEngineConfig._from_json(dict_config)
# Checks to make sure engine configs are set correctly
@@ -51,26 +53,32 @@ def get_engine_config(dict_config):
assert isinstance(engine_config.min_decode_steps, int)

# TODO(@sunggg): engine allows -1 for these params. figure out the behavior and enable checks properly
assert engine_config.max_num_batched_tokens == -1, \
"`max_num_batched_tokens` is not supposed to be configured directly. \
assert (
engine_config.max_num_batched_tokens == -1
), "`max_num_batched_tokens` is not supposed to be configured directly. \
Use `max_num_sequences` and `max_input_len` instead."
assert engine_config.max_input_len > 0
assert engine_config.max_num_sequences > 0
engine_config.max_num_batched_tokens = engine_config.max_num_sequences * engine_config.max_input_len
engine_config.max_num_batched_tokens = (
engine_config.max_num_sequences * engine_config.max_input_len
)

assert (engine_config.min_decode_steps > 0) and (engine_config.max_decode_steps > 0)
assert engine_config.max_decode_steps > engine_config.min_decode_steps

return engine_config
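For illustration, a minimal sketch of how these checks combine, assuming `_from_json` fills any fields not listed in the dict with their defaults (an assumption; only the asserted fields and the derivation of `max_num_batched_tokens` come from the code above):

from mlc_serve.engine import get_engine_config

# Hypothetical config dict; values are made up.
config = get_engine_config({
    "max_num_batched_tokens": -1,  # must stay -1: it is derived, not configured directly
    "max_input_len": 512,
    "max_num_sequences": 8,
    "min_decode_steps": 12,
    "max_decode_steps": 16,
})

# Derived as max_num_sequences * max_input_len.
assert config.max_num_batched_tokens == 8 * 512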


@dataclass
class StoppingCriteria:
"""
Parameters about when to stop text generation.
"""

max_tokens: Optional[int] = None
stop_sequences: Optional[list[str]] = None


@dataclass
class ChatMessage:
role: str
@@ -89,16 +97,20 @@ class FinishReason(Enum):
Length = "length"
Cancelled = "cancelled"


# A single token.
Token = int


@dataclass
class ValidationError:
msg: str


# The type signature of the token validation callback.
ValidateTokensCallback = Callable[["Request", List[Token]], ValidationError]


@dataclass
class Request:
request_id: RequestId
@@ -110,7 +122,9 @@ class Request:
# Options for sampling.
sampling_params: SamplingParams = field(default_factory=SamplingParams)
# Options for stopping.
stopping_criteria: StoppingCriteria = field(default_factory=lambda: StoppingCriteria())
stopping_criteria: StoppingCriteria = field(
default_factory=lambda: StoppingCriteria()
)
# Options for debugging.
debug_options: DebugOptions = field(default_factory=DebugOptions)
# Perform request validation post-tokenization, used by the HTTP layer to control validation.
@@ -238,38 +252,57 @@ def stop(self) -> None:
...


@dataclass(frozen=True)
class SequenceId:
"""
SequenceId identifies a unique sequence to be generated.
Each request will have `n` unique SequenceIds, where `n` is
the `n` from SamplingParams.
"""

request_id: RequestId
sequence_index: int


@dataclass
class GenerationSequence:
seq_id: SequenceId
generated_token_ids: list[int]
next_start_position: int
output_text: str
is_finished: bool = False


@dataclass
class RequestState:
"""
The internal state of a request in the InferenceEngine.
"""

request_id: RequestId
token_ids: list[int]
output_text: str
prompt_len: int
next_start_position: int
prompt_token_ids: list[int]
sampling_params: SamplingParams
generation_sequences: list[GenerationSequence]
stopping_criteria: StoppingCriteria
debug_options: DebugOptions
arrival_timestamp: float
is_ended: bool = False
validation_err: Optional[ValidationError] = None

def check_stopping_sequences(stopping_criteria, output_text, delta, is_ended):
if stopping_criteria.stop_sequences:
for t in stopping_criteria.stop_sequences:
if t in output_text:
# The stop sequence can end partway through the newly generated token, so both
# the accumulated output_text and the latest delta must be trimmed back to the
# end of the stop sequence. For example, with "I " as a stop sequence: the previous
# output_text was "I" and the new token " am" extends it to "I am"; after matching,
# output_text is cut back to "I " and delta to " ".
if not output_text.endswith(t):
sub_index = output_text.find(t)
delta = delta[:-(len(output_text) - sub_index - len(t))]
output_text = output_text[:output_text.find(t) + len(t)]
is_ended = True
break
return output_text, delta, is_ended
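A worked example of the trimming logic above. Per the commit message this helper now lives in engine_common.py, so the import path is an assumption; the input strings are illustrative:

from mlc_serve.engine import StoppingCriteria
from mlc_serve.engine.engine_common import check_stopping_sequences  # assumed location after the move

criteria = StoppingCriteria(stop_sequences=["I "])

# Previously accumulated text was "I"; the newly decoded token is " am",
# so output_text is "I am" and delta is " am" going into the check.
output_text, delta, is_ended = check_stopping_sequences(
    criteria, output_text="I am", delta=" am", is_ended=False
)

# The match ends inside the new token, so both strings are trimmed back to the stop sequence.
assert (output_text, delta, is_ended) == ("I ", " ", True)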
@property
def is_finished(self) -> bool:
return all(seq.is_finished for seq in self.generation_sequences)

@property
def prompt_len(self) -> int:
return len(self.prompt_token_ids)

@property
def num_sequences(self) -> int:
return len(self.generation_sequences)

@property
def num_total_tokens(self) -> int:
return self.prompt_len + sum(
len(gen_seq.generated_token_ids) for gen_seq in self.generation_sequences
)
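Taken together, the new classes let a single request own several generation sequences, which is the point of preparing for parallel sampling. A small illustrative sketch follows: the request contents and token ids are made up, the module path for DebugOptions and GenerationSequence is an assumption, and only the class and field names come from the diff above.

import time

from mlc_serve.engine import RequestState, SamplingParams, SequenceId, StoppingCriteria
from mlc_serve.engine.base import DebugOptions, GenerationSequence  # assumed module path

# Suppose one request fans out into two sequences (e.g. n=2 in SamplingParams).
request_id = "req-0"
sequences = [
    GenerationSequence(
        seq_id=SequenceId(request_id=request_id, sequence_index=i),
        generated_token_ids=[],
        next_start_position=0,
        output_text="",
    )
    for i in range(2)
]

state = RequestState(
    request_id=request_id,
    prompt_token_ids=[101, 2023, 2003, 1037],  # made-up prompt tokens
    sampling_params=SamplingParams(),
    generation_sequences=sequences,
    stopping_criteria=StoppingCriteria(max_tokens=16),
    debug_options=DebugOptions(),
    arrival_timestamp=time.time(),
)

assert state.num_sequences == 2
assert state.prompt_len == 4        # derived from prompt_token_ids
assert state.num_total_tokens == 4  # prompt tokens plus nothing generated yet
assert not state.is_finished        # no sequence has finished

# The request is finished only once every sequence is finished.
for seq in sequences:
    seq.is_finished = True
assert state.is_finished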