Refactor to prepare for parallel sampling #100

Merged
32 commits merged on Dec 11, 2023
Changes shown are from 23 of the 32 commits.

Commits (32)
712ccec  wip (masahi, Dec 6, 2023)
2d05640  wip (masahi, Dec 7, 2023)
f3742b9  wip (masahi, Dec 7, 2023)
a38c955  fix (masahi, Dec 7, 2023)
b960d4d  fix (masahi, Dec 7, 2023)
a5e6e37  fix (masahi, Dec 7, 2023)
3b2df21  refactor (masahi, Dec 7, 2023)
77e0e5f  more refactor (masahi, Dec 7, 2023)
5808958  wip (masahi, Dec 7, 2023)
e4f21b4  wip (masahi, Dec 7, 2023)
4080146  more refactor (masahi, Dec 7, 2023)
9d42deb  more refactor (masahi, Dec 7, 2023)
9eb92f8  fixed (masahi, Dec 7, 2023)
18b8e41  fixed mypy (masahi, Dec 7, 2023)
bdb0be3  minor (masahi, Dec 7, 2023)
27da1c2  msg clean (masahi, Dec 8, 2023)
f9747ac  fix missing finish_reason (masahi, Dec 8, 2023)
522edd7  remove unnecessary type annot on defaultdict (masahi, Dec 8, 2023)
5585197  Return requests state from get_requests_to_process (masahi, Dec 8, 2023)
421d2ea  simplify typing (masahi, Dec 8, 2023)
a83c494  reduced list concat (masahi, Dec 8, 2023)
5382004  remove dict add and lookup (masahi, Dec 8, 2023)
55045d0  wrong comment (masahi, Dec 8, 2023)
e7b6a3c  Revert "remove dict add and lookup" (masahi, Dec 10, 2023)
d962435  fix sampler test (masahi, Dec 10, 2023)
78ab330  make it possible to disable prometheus metrics (masahi, Dec 10, 2023)
70369fc  collect metrics only in staging engine (masahi, Dec 10, 2023)
b16f787  return False in stop_by_length if request is already finished (masahi, Dec 10, 2023)
fd39416  move check_stopping_sequences to engine_common.py (masahi, Dec 10, 2023)
c8b7f55  add missing free_request method to Dummy cache manager (masahi, Dec 11, 2023)
1853a54  update Dummy cache manager to operate on sequence (masahi, Dec 11, 2023)
242e3de  fix request finish condition (masahi, Dec 11, 2023)
4 changes: 3 additions & 1 deletion serve/mlc_serve/engine/__init__.py
@@ -11,6 +11,8 @@
     RequestOutput,
     StoppingCriteria,
     MLCServeEngineConfig,
-    get_engine_config
+    get_engine_config,
+    SequenceId,
+    RequestState,
 )
 from .sampling_params import SamplingParams, SamplingType
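With this change the sequence-level types are re-exported from the package root, so downstream code can import them from `mlc_serve.engine` directly. A minimal sketch (the request id value below is hypothetical):

```python
# All three names are now importable from the package root per the diff above.
from mlc_serve.engine import RequestState, SequenceId, get_engine_config

# Identify the third generated sequence of a (hypothetical) request.
seq_id = SequenceId(request_id="request-0", sequence_index=2)
```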
74 changes: 63 additions & 11 deletions serve/mlc_serve/engine/base.py
@@ -11,6 +11,7 @@
 
 RequestId = str
 
+
 # TODO(@sunggg): consider transition to something like Pydantic
 @dataclass
 class MLCServeEngineConfig:
@@ -39,6 +40,7 @@ def _from_json(config_cls, json_obj: Dict[Any, Any]):
         }
     )
 
+
 def get_engine_config(dict_config):
     engine_config = MLCServeEngineConfig._from_json(dict_config)
     # Checks to make sure engine configs are set correctly
@@ -51,26 +53,32 @@ def get_engine_config(dict_config):
     assert isinstance(engine_config.min_decode_steps, int)
 
     # TODO(@sunggg): engine allows -1 for these params. figure out the behavior and enable checks properly
-    assert engine_config.max_num_batched_tokens == -1, \
-        "`max_num_batched_tokens` is not supposed to be configured directly. \
+    assert (
+        engine_config.max_num_batched_tokens == -1
+    ), "`max_num_batched_tokens` is not supposed to be configured directly. \
         Use `max_num_sequences` and `max_input_len` instead."
     assert engine_config.max_input_len > 0
     assert engine_config.max_num_sequences > 0
-    engine_config.max_num_batched_tokens = engine_config.max_num_sequences * engine_config.max_input_len
+    engine_config.max_num_batched_tokens = (
+        engine_config.max_num_sequences * engine_config.max_input_len
+    )
 
     assert (engine_config.min_decode_steps > 0) and (engine_config.max_decode_steps > 0)
     assert engine_config.max_decode_steps > engine_config.min_decode_steps
 
     return engine_config
 
+
 @dataclass
 class StoppingCriteria:
     """
     Parameters about when to stop text generation.
     """
+
     max_tokens: Optional[int] = None
     stop_sequences: Optional[list[str]] = None
 
+
 @dataclass
 class ChatMessage:
     role: str
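To illustrate the config checks in the hunk above: `max_num_batched_tokens` must stay at its `-1` default and is then derived as `max_num_sequences * max_input_len`. A minimal sketch with made-up values, assuming `_from_json` leaves any field missing from the dict at its dataclass default:

```python
from mlc_serve.engine import get_engine_config

# Hypothetical values; max_num_batched_tokens is deliberately not set here.
engine_config = get_engine_config(
    {
        "max_input_len": 2048,
        "max_num_sequences": 8,
        "min_decode_steps": 12,
        "max_decode_steps": 16,
    }
)

# Derived inside get_engine_config rather than configured directly.
assert engine_config.max_num_batched_tokens == 8 * 2048
```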
@@ -89,16 +97,20 @@ class FinishReason(Enum):
     Length = "length"
     Cancelled = "cancelled"
 
+
 # A single token.
 Token = int
 
+
 @dataclass
 class ValidationError:
     msg: str
 
+
 # The type signature of the token validation callback.
 ValidateTokensCallback = Callable[["Request", List[Token]], ValidationError]
 
+
 @dataclass
 class Request:
     request_id: RequestId
@@ -110,7 +122,9 @@ class Request:
     # Options for sampling.
     sampling_params: SamplingParams = field(default_factory=SamplingParams)
     # Options for stopping.
-    stopping_criteria: StoppingCriteria = field(default_factory=lambda: StoppingCriteria())
+    stopping_criteria: StoppingCriteria = field(
+        default_factory=lambda: StoppingCriteria()
+    )
     # Options for debugging.
     debug_options: DebugOptions = field(default_factory=DebugOptions)
     # Perform request validation post-tokenization, used by the HTTP layer to control validation.
@@ -238,24 +252,62 @@ def stop(self) -> None:
         ...
 
 
+@dataclass(frozen=True)
+class SequenceId:
+    """
+    SequenceId identified a unique sequence to be generated.
+
+    Each request will have `n` unique SequenceIds, where `n` is
+    the `n` from SamplingParams.
+    """
+
+    request_id: RequestId
+    sequence_index: int
+
+
+@dataclass
+class GenerationSequence:
+    seq_id: SequenceId
+    generated_token_ids: list[int]
+    next_start_position: int
+    output_text: str
+    is_finished: bool = False
+
+
 @dataclass
 class RequestState:

[Review comment from masahi (Member, Author) on class RequestState]: Now this class basically acts like SequenceGroup in vllm.

"""
The internal state of request in the InferenceEngine.
"""

request_id: RequestId
token_ids: list[int]
output_text: str
prompt_len: int
next_start_position: int
prompt_token_ids: list[int]
sampling_params: SamplingParams
generation_sequences: list[GenerationSequence]
stopping_criteria: StoppingCriteria
debug_options: DebugOptions
arrival_timestamp: float
is_ended: bool = False
validation_err: Optional[ValidationError] = None

@property
def is_finished(self) -> bool:
return all(seq.is_finished for seq in self.generation_sequences)

@property
def prompt_len(self) -> int:
return len(self.prompt_token_ids)

@property
def num_sequences(self) -> int:
return len(self.generation_sequences)

@property
def num_total_tokens(self) -> int:
return self.prompt_len + sum(
len(gen_seq.generated_token_ids) for gen_seq in self.generation_sequences
)


def check_stopping_sequences(stopping_criteria, output_text, delta, is_ended):
if stopping_criteria.stop_sequences:
for t in stopping_criteria.stop_sequences:
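To make the new data model above concrete: a request now owns one `GenerationSequence` per parallel sample, and the old per-request fields (`prompt_len`, `next_start_position`, `output_text`) either become properties or move onto the sequences. A small hypothetical sketch, assuming `GenerationSequence` and `DebugOptions` can be imported from `mlc_serve.engine.base` and that default `SamplingParams` validate; the token ids and values below are made up:

```python
import time

from mlc_serve.engine import RequestState, SamplingParams, SequenceId, StoppingCriteria
from mlc_serve.engine.base import DebugOptions, GenerationSequence

request_id = "request-0"

# Two parallel sequences generated for one request.
gen_seqs = [
    GenerationSequence(
        seq_id=SequenceId(request_id, i),
        generated_token_ids=[],
        next_start_position=0,
        output_text="",
    )
    for i in range(2)
]

state = RequestState(
    request_id=request_id,
    prompt_token_ids=[101, 102, 103, 104],  # hypothetical prompt tokens
    sampling_params=SamplingParams(),       # assuming the defaults validate
    generation_sequences=gen_seqs,
    stopping_criteria=StoppingCriteria(max_tokens=16),
    debug_options=DebugOptions(),
    arrival_timestamp=time.time(),
)

assert state.prompt_len == 4 and state.num_sequences == 2

# Simulate one decode step: sequence 0 finishes, sequence 1 keeps going.
gen_seqs[0].generated_token_ids.extend([7, 8])
gen_seqs[0].is_finished = True
gen_seqs[1].generated_token_ids.append(9)

assert state.num_total_tokens == 4 + 2 + 1
assert not state.is_finished  # the request finishes only when every sequence has finished
```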
@@ -268,8 +320,8 @@ def check_stopping_sequences(stopping_criteria, output_text, delta, is_ended):
                 # While eventually we need to return "I "
                 if not output_text.endswith(t):
                     sub_index = output_text.find(t)
-                    delta = delta[:-(len(output_text) - sub_index - len(t))]
-                    output_text = output_text[:output_text.find(t) + len(t)]
+                    delta = delta[: -(len(output_text) - sub_index - len(t))]
+                    output_text = output_text[: output_text.find(t) + len(t)]
                 is_ended = True
                 break
     return output_text, delta, is_ended
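As an illustration of the truncation logic in `check_stopping_sequences`: when a stop string lands strictly inside the accumulated text, both `output_text` and the latest `delta` are trimmed so the output ends right after the stop string. A hypothetical example (at this point in the PR the helper still lives in `base.py`; commit fd39416 later moves it to `engine_common.py`):

```python
from mlc_serve.engine import StoppingCriteria
from mlc_serve.engine.base import check_stopping_sequences

output_text = "I like to eat apples"  # text accumulated so far
delta = " eat apples"                 # chunk produced by the latest decode step
criteria = StoppingCriteria(stop_sequences=["eat"])

output_text, delta, is_ended = check_stopping_sequences(
    criteria, output_text, delta, is_ended=False
)

assert output_text == "I like to eat"  # trimmed to end right after the stop string
assert delta == " eat"                 # the delta is trimmed by the same amount
assert is_ended
```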