From 0632d19a6f010e5a44ff2917c3adc4556317fb65 Mon Sep 17 00:00:00 2001 From: Amey Agrawal Date: Sun, 21 Jul 2024 00:50:10 -0700 Subject: [PATCH 01/24] minor --- .vscode/settings.json | 2 + data/device_configs/a100.yml | 4 - data/device_configs/a40.yml | 3 - data/device_configs/h100.yml | 3 - data/model_configs/Qwen/Qwen-72B.yml | 15 - .../codellama/CodeLlama-34b-Instruct-hf.yml | 16 - data/model_configs/internlm/internlm-20b.yml | 15 - data/model_configs/internlm/internlm2-20b.yml | 15 - .../meta-llama/Llama-2-70b-hf.yml | 16 - .../meta-llama/Llama-2-7b-hf.yml | 16 - data/model_configs/microsoft/phi-2.yml | 17 - data/model_configs/openai/gpt3.yml | 7 - data/model_configs/tiiuae/falcon-180B.yml | 8 - .../all_reduce.csv | 0 .../send_recv.csv | 0 .../attention.csv | 0 .../send_recv.csv | 0 .../all_reduce.csv | 0 .../send_recv.csv | 0 vidur/config/__init__.py | 4 +- vidur/config/base_poly_config.py | 16 + vidur/config/config.py | 446 ++++++++++++++---- vidur/config/default.yml | 162 ------- vidur/config/device_sku_config.py | 45 ++ vidur/config/flat_dataclass.py | 204 ++++++++ vidur/config/model_config.py | 151 ++++++ vidur/config/node_sku_config.py | 62 +++ vidur/config/utils.py | 60 +++ .../dummy_execution_time_predictor.py | 33 -- .../execution_time_predictor_registry.py | 7 - .../orca_replica_scheduler.py | 8 - .../sarathi_replica_scheduler.py | 26 +- vidur/types/__init__.py | 11 + vidur/types/activation_type.py | 6 + vidur/types/device_sku_type.py | 7 + vidur/types/model_type.py | 10 + vidur/types/node_sku_type.py | 9 + vidur/types/norm_type.py | 7 + 38 files changed, 939 insertions(+), 472 deletions(-) delete mode 100644 data/device_configs/a100.yml delete mode 100644 data/device_configs/a40.yml delete mode 100644 data/device_configs/h100.yml delete mode 100644 data/model_configs/Qwen/Qwen-72B.yml delete mode 100644 data/model_configs/codellama/CodeLlama-34b-Instruct-hf.yml delete mode 100644 data/model_configs/internlm/internlm-20b.yml delete mode 100644 data/model_configs/internlm/internlm2-20b.yml delete mode 100644 data/model_configs/meta-llama/Llama-2-70b-hf.yml delete mode 100644 data/model_configs/meta-llama/Llama-2-7b-hf.yml delete mode 100644 data/model_configs/microsoft/phi-2.yml delete mode 100644 data/model_configs/openai/gpt3.yml delete mode 100644 data/model_configs/tiiuae/falcon-180B.yml rename data/profiling/network/{a100_pair_nvlink => a100_pairwise_nvlink}/all_reduce.csv (100%) rename data/profiling/network/{a100_pair_nvlink => a100_pairwise_nvlink}/send_recv.csv (100%) rename data/profiling/network/{a40_pair_nvlink => a40_pairwise_nvlink}/attention.csv (100%) rename data/profiling/network/{a40_pair_nvlink => a40_pairwise_nvlink}/send_recv.csv (100%) rename data/profiling/network/{h100_pair_nvlink => h100_pairwise_nvlink}/all_reduce.csv (100%) rename data/profiling/network/{h100_pair_nvlink => h100_pairwise_nvlink}/send_recv.csv (100%) create mode 100644 vidur/config/base_poly_config.py delete mode 100644 vidur/config/default.yml create mode 100644 vidur/config/device_sku_config.py create mode 100644 vidur/config/flat_dataclass.py create mode 100644 vidur/config/model_config.py create mode 100644 vidur/config/node_sku_config.py create mode 100644 vidur/config/utils.py delete mode 100644 vidur/execution_time_predictor/dummy_execution_time_predictor.py create mode 100644 vidur/types/activation_type.py create mode 100644 vidur/types/device_sku_type.py create mode 100644 vidur/types/model_type.py create mode 100644 vidur/types/node_sku_type.py create mode 100644 
vidur/types/norm_type.py diff --git a/.vscode/settings.json b/.vscode/settings.json index be38d2c..fda9355 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,5 +1,7 @@ { "cSpell.words": [ + "INTERNLM", + "QWEN", "vidur" ] } \ No newline at end of file diff --git a/data/device_configs/a100.yml b/data/device_configs/a100.yml deleted file mode 100644 index b2ba170..0000000 --- a/data/device_configs/a100.yml +++ /dev/null @@ -1,4 +0,0 @@ -fp16_tflops: 312 -total_memory_gb: 80 -num_devices_per_node: 4 - diff --git a/data/device_configs/a40.yml b/data/device_configs/a40.yml deleted file mode 100644 index 17e0446..0000000 --- a/data/device_configs/a40.yml +++ /dev/null @@ -1,3 +0,0 @@ -fp16_tflops: 150 -total_memory_gb: 45 -num_devices_per_node: 8 diff --git a/data/device_configs/h100.yml b/data/device_configs/h100.yml deleted file mode 100644 index cd17475..0000000 --- a/data/device_configs/h100.yml +++ /dev/null @@ -1,3 +0,0 @@ -fp16_tflops: 1000 -total_memory_gb: 80 -num_devices_per_node: 4 diff --git a/data/model_configs/Qwen/Qwen-72B.yml b/data/model_configs/Qwen/Qwen-72B.yml deleted file mode 100644 index 0acd50d..0000000 --- a/data/model_configs/Qwen/Qwen-72B.yml +++ /dev/null @@ -1,15 +0,0 @@ -num_layers: 80 -num_q_heads: 64 -num_kv_heads: 64 -embedding_dim: 8192 -mlp_hidden_dim: 24576 -max_position_embeddings: 32768 -use_gated_mlp: true -use_bias: false -use_qkv_bias: true -activation: silu -norm: rms_norm -post_attn_norm: true -rope_theta: 1000000 -vocab_size: 152064 -is_neox_style: true \ No newline at end of file diff --git a/data/model_configs/codellama/CodeLlama-34b-Instruct-hf.yml b/data/model_configs/codellama/CodeLlama-34b-Instruct-hf.yml deleted file mode 100644 index e07dde4..0000000 --- a/data/model_configs/codellama/CodeLlama-34b-Instruct-hf.yml +++ /dev/null @@ -1,16 +0,0 @@ -num_layers: 48 -num_q_heads: 64 -num_kv_heads: 8 -embedding_dim: 8192 -mlp_hidden_dim: 22016 -max_position_embeddings: 16384 -use_gated_mlp: true -use_bias: false -use_qkv_bias: false -activation: silu -norm: rms_norm -post_attn_norm: true -rope_scaling: null -rope_theta: 1000000 -vocab_size: 32768 -is_neox_style: true \ No newline at end of file diff --git a/data/model_configs/internlm/internlm-20b.yml b/data/model_configs/internlm/internlm-20b.yml deleted file mode 100644 index d93b467..0000000 --- a/data/model_configs/internlm/internlm-20b.yml +++ /dev/null @@ -1,15 +0,0 @@ -num_layers: 60 -num_q_heads: 40 -num_kv_heads: 40 -embedding_dim: 5120 -mlp_hidden_dim: 13824 -max_position_embeddings: 4096 -use_gated_mlp: true -use_bias: false -use_qkv_bias: false -activation: silu -norm: rms_norm -rope_scaling: null -rope_theta: 10000 -post_attn_norm: true -vocab_size: 103168 diff --git a/data/model_configs/internlm/internlm2-20b.yml b/data/model_configs/internlm/internlm2-20b.yml deleted file mode 100644 index f8a0337..0000000 --- a/data/model_configs/internlm/internlm2-20b.yml +++ /dev/null @@ -1,15 +0,0 @@ -num_layers: 48 -num_q_heads: 48 -num_kv_heads: 8 -embedding_dim: 6144 -mlp_hidden_dim: 16384 -max_position_embeddings: 32768 -use_gated_mlp: true -use_bias: false -use_qkv_bias: false -act: silu -norm: rms_norm -post_attn_norm: true -rope_scaling: null -rope_theta: 1000000 -vocab_size: 92544 diff --git a/data/model_configs/meta-llama/Llama-2-70b-hf.yml b/data/model_configs/meta-llama/Llama-2-70b-hf.yml deleted file mode 100644 index f74d844..0000000 --- a/data/model_configs/meta-llama/Llama-2-70b-hf.yml +++ /dev/null @@ -1,16 +0,0 @@ -num_layers: 80 -num_q_heads: 64 -num_kv_heads: 8 
-embedding_dim: 8192 -mlp_hidden_dim: 28672 -max_position_embeddings: 4096 -use_gated_mlp: true -use_bias: false -use_qkv_bias: false -activation: silu -norm: rms_norm -post_attn_norm: true -rope_theta: 10000.0 -rope_scaling: null -vocab_size: 32768 -is_neox_style: true \ No newline at end of file diff --git a/data/model_configs/meta-llama/Llama-2-7b-hf.yml b/data/model_configs/meta-llama/Llama-2-7b-hf.yml deleted file mode 100644 index a49b2bb..0000000 --- a/data/model_configs/meta-llama/Llama-2-7b-hf.yml +++ /dev/null @@ -1,16 +0,0 @@ -num_layers: 32 -num_q_heads: 32 -num_kv_heads: 32 -embedding_dim: 4096 -mlp_hidden_dim: 11008 -max_position_embeddings: 4096 -use_gated_mlp: true -use_bias: false -use_qkv_bias: false -activation: silu -norm: rms_norm -post_attn_norm: true -rope_theta: 10000.0 -rope_scaling: null -vocab_size: 32768 -is_neox_style: true diff --git a/data/model_configs/microsoft/phi-2.yml b/data/model_configs/microsoft/phi-2.yml deleted file mode 100644 index 77c7776..0000000 --- a/data/model_configs/microsoft/phi-2.yml +++ /dev/null @@ -1,17 +0,0 @@ -num_layers: 32 -num_q_heads: 32 -num_kv_heads: 32 -embedding_dim: 2560 -mlp_hidden_dim: 10240 -max_position_embeddings: 2048 -use_gated_mlp: false -use_bias: true -use_qkv_bias: true -activation: gelu -norm: layer_norm -post_attn_norm: false -vocab_size: 51200 -rope_scaling: null -rope_theta: 10000.0 -partial_rotary_factor: 0.4 -no_tensor_parallel: true diff --git a/data/model_configs/openai/gpt3.yml b/data/model_configs/openai/gpt3.yml deleted file mode 100644 index a34b2b3..0000000 --- a/data/model_configs/openai/gpt3.yml +++ /dev/null @@ -1,7 +0,0 @@ -num_layers: 96 -num_q_heads: 96 -num_kv_heads: 96 -embedding_dim: 12288 -mlp_hidden_dim: 49152 -use_gated_mlp: false -vocab_size: 50257 diff --git a/data/model_configs/tiiuae/falcon-180B.yml b/data/model_configs/tiiuae/falcon-180B.yml deleted file mode 100644 index 4b8cdb4..0000000 --- a/data/model_configs/tiiuae/falcon-180B.yml +++ /dev/null @@ -1,8 +0,0 @@ -num_layers: 80 -num_q_heads: 232 -num_kv_heads: 8 -embedding_dim: 14848 -mlp_hidden_dim: 59392 -use_gated_mlp: false -vocab_size: 65024 -is_neox_style: true \ No newline at end of file diff --git a/data/profiling/network/a100_pair_nvlink/all_reduce.csv b/data/profiling/network/a100_pairwise_nvlink/all_reduce.csv similarity index 100% rename from data/profiling/network/a100_pair_nvlink/all_reduce.csv rename to data/profiling/network/a100_pairwise_nvlink/all_reduce.csv diff --git a/data/profiling/network/a100_pair_nvlink/send_recv.csv b/data/profiling/network/a100_pairwise_nvlink/send_recv.csv similarity index 100% rename from data/profiling/network/a100_pair_nvlink/send_recv.csv rename to data/profiling/network/a100_pairwise_nvlink/send_recv.csv diff --git a/data/profiling/network/a40_pair_nvlink/attention.csv b/data/profiling/network/a40_pairwise_nvlink/attention.csv similarity index 100% rename from data/profiling/network/a40_pair_nvlink/attention.csv rename to data/profiling/network/a40_pairwise_nvlink/attention.csv diff --git a/data/profiling/network/a40_pair_nvlink/send_recv.csv b/data/profiling/network/a40_pairwise_nvlink/send_recv.csv similarity index 100% rename from data/profiling/network/a40_pair_nvlink/send_recv.csv rename to data/profiling/network/a40_pairwise_nvlink/send_recv.csv diff --git a/data/profiling/network/h100_pair_nvlink/all_reduce.csv b/data/profiling/network/h100_pairwise_nvlink/all_reduce.csv similarity index 100% rename from data/profiling/network/h100_pair_nvlink/all_reduce.csv rename to 
data/profiling/network/h100_pairwise_nvlink/all_reduce.csv diff --git a/data/profiling/network/h100_pair_nvlink/send_recv.csv b/data/profiling/network/h100_pairwise_nvlink/send_recv.csv similarity index 100% rename from data/profiling/network/h100_pair_nvlink/send_recv.csv rename to data/profiling/network/h100_pairwise_nvlink/send_recv.csv diff --git a/vidur/config/__init__.py b/vidur/config/__init__.py index 83c25e2..27c9ec6 100644 --- a/vidur/config/__init__.py +++ b/vidur/config/__init__.py @@ -1,3 +1 @@ -from vidur.config.config import Config - -__all__ = [Config] +from .config import * diff --git a/vidur/config/base_poly_config.py b/vidur/config/base_poly_config.py new file mode 100644 index 0000000..fbdd0e1 --- /dev/null +++ b/vidur/config/base_poly_config.py @@ -0,0 +1,16 @@ +from abc import ABC +from dataclasses import dataclass +from typing import Any + +from vidur.config.utils import get_all_subclasses + + +@dataclass +class BasePolyConfig(ABC): + + @classmethod + def create_from_type(cls, type_: Any) -> Any: + for subclass in get_all_subclasses(cls): + if subclass.get_type() == type_: + return subclass() + raise ValueError(f"Invalid type: {type_}") diff --git a/vidur/config/config.py b/vidur/config/config.py index 5a16203..898cd88 100644 --- a/vidur/config/config.py +++ b/vidur/config/config.py @@ -1,109 +1,353 @@ -import argparse -import datetime -import os +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from datetime import datetime +from typing import Optional, List -import yaml - -from vidur.constants import DEFAULT_CONFIG_FILE, DEVICE_CONFIG_DIR, MODEL_CONFIG_DIR +from vidur.config.base_poly_config import BasePolyConfig +from vidur.config.flat_dataclass import create_flat_dataclass from vidur.logger import init_logger +from vidur.types import ReplicaSchedulerType, GlobalSchedulerType, ExecutionTimePredictorType, RequestGeneratorType, RequestIntervalGeneratorType, RequestLengthGeneratorType logger = init_logger(__name__) -class Config: - def __init__(self, config_file=DEFAULT_CONFIG_FILE): - self._parser = argparse.ArgumentParser() - self._args = None - self._load_yaml(config_file) - self._parse_args() - self._add_derived_args() - self._write_yaml_to_file() - logger.info(f"Config: {self.get_yaml()}") - - def _load_yaml(self, filename): - with open(filename, "r") as file: - yaml_config = yaml.safe_load(file) - self._update_namespace(yaml_config) - - def _parse_args(self): - self._args = self._parser.parse_args() - - def _add_derived_args(self): - self._args.output_dir = f"{self._args.output_dir}/{datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S-%f')}" - os.makedirs(self._args.output_dir, exist_ok=True) - self._load_model_config() - self._load_device_config() - self._substitute_variables_in_args() - - def _update_namespace(self, config_dict, parent_key=""): - for key, value in config_dict.items(): - if isinstance(value, dict): - new_key = f"{parent_key}{key}_" if parent_key else f"{key}_" - self._update_namespace(value, new_key) - else: - arg_name = f"{parent_key}{key}" - - if type(value) == bool: - self._parser.add_argument( - f"--{arg_name}", - default=value, - action=argparse.BooleanOptionalAction, - ) - elif arg_name in [ - "simulator_time_limit", - "metrics_store_subsamples", - "replica_scheduler_num_blocks", - ]: - self._parser.add_argument(f"--{arg_name}", default=value, type=int) - else: - self._parser.add_argument( - f"--{arg_name}", default=value, type=type(value) - ) - - def __getattr__(self, name): - return getattr(self._args, 
name, None) - - def get_yaml(self): - return yaml.dump(self._args.__dict__, default_flow_style=False) - - def _write_yaml_to_file(self): - with open(f"{self._args.output_dir}/config.yml", "w") as file: - file.write(self.get_yaml()) +@dataclass +class BaseRequestIntervalGeneratorConfig(BasePolyConfig): + seed: int = 42 + + +@dataclass +class BaseRequestLengthGeneratorConfig(BasePolyConfig): + seed: int = 42 + + +@dataclass +class TraceRequestIntervalGeneratorConfig(BaseRequestIntervalGeneratorConfig): + trace_file: str = ( + "data/processed_traces/AzureFunctionsInvocationTraceForTwoWeeksJan2021Processed.csv" + ) + start_time: str = "1970-01-04 12:00:00" + end_time: str = "1970-01-04 15:00:00" + time_scale_factor: float = 0.3 + + @staticmethod + def get_type(): + return RequestIntervalGeneratorType.TRACE + + +@dataclass +class PoissonRequestIntervalGeneratorConfig(BaseRequestIntervalGeneratorConfig): + qps: float = 1.0 + + @staticmethod + def get_type(): + return RequestIntervalGeneratorType.POISSON + + +@dataclass +class GammaRequestIntervalGeneratorConfig(BaseRequestIntervalGeneratorConfig): + qps: float = 1.0 + cv: float = 0.5 + + @staticmethod + def get_type(): + return RequestIntervalGeneratorType.GAMMA + + +@dataclass +class StaticRequestIntervalGeneratorConfig(BaseRequestIntervalGeneratorConfig): + @staticmethod + def get_type(): + return RequestIntervalGeneratorType.STATIC + + +@dataclass +class TraceRequestLengthGeneratorConfig(BaseRequestLengthGeneratorConfig): + trace_file: str = ( + "data/processed_traces/sharegpt_8k_filtered_stats_llama2_tokenizer.csv" + ) + prefill_scale_factor: float = 1 + decode_scale_factor: float = 1 + max_tokens: int = 4096 + + @staticmethod + def get_type(): + return RequestLengthGeneratorType.TRACE + + +@dataclass +class ZipfRequestLengthGeneratorConfig(BaseRequestLengthGeneratorConfig): + theta: float = 0.6 + scramble: bool = False + min_tokens: int = 1024 + max_tokens: int = 4096 + prefill_to_decode_ratio: float = 20.0 + + @staticmethod + def get_type(): + return RequestLengthGeneratorType.ZIPF + + +@dataclass +class UniformRequestLengthGeneratorConfig(BaseRequestLengthGeneratorConfig): + min_tokens: int = 1024 + max_tokens: int = 4096 + prefill_to_decode_ratio: float = 20.0 + + @staticmethod + def get_type(): + return RequestLengthGeneratorType.UNIFORM + + +@dataclass +class FixedRequestLengthGeneratorConfig(BaseRequestLengthGeneratorConfig): + prefill_tokens: int = 4096 + decode_tokens: int = 512 + + @staticmethod + def get_type(): + return RequestLengthGeneratorType.FIXED + + +@dataclass +class BaseRequestGeneratorConfig(BasePolyConfig): + seed: int = 42 + + +@dataclass +class SyntheticRequestGeneratorConfig(BaseRequestGeneratorConfig): + length_generator_config: BaseRequestLengthGeneratorConfig = field( + default_factory=FixedRequestLengthGeneratorConfig + ) + interval_generator_config: BaseRequestIntervalGeneratorConfig = field( + default_factory=PoissonRequestIntervalGeneratorConfig + ) + num_requests: int = 64 + duration: float = None + + @staticmethod + def get_type(): + return RequestGeneratorType.SYNTHETIC + + +@dataclass +class TraceRequestGeneratorConfig(BaseRequestGeneratorConfig): + trace_file: str = "data/processed_traces/sydney_enterprise.csv" + date: str = "2023-08-21" + prefill_scale_factor: float = 0.3 + decode_scale_factor: float = 1 + time_scale_factor: float = 0.04 + max_tokens: int = 4096 + + @staticmethod + def get_type(): + return RequestGeneratorType.TRACE + + +@dataclass +class BaseReplicaSchedulerConfig(BasePolyConfig): + 
max_num_seqs: int = 128 + num_pipeline_stages: int = 1 + watermark_blocks_fraction: float = 0.01 + block_size: int = 16 + num_blocks: Optional[int] = None + + @abstractmethod + def get_max_num_batched_tokens(self): + pass + + +@dataclass +class VllmSchedulerConfig(BaseReplicaSchedulerConfig): + max_batched_tokens: int = None + + @staticmethod + def get_type(): + return ReplicaSchedulerType.VLLM + + +@dataclass +class LightLLMSchedulerConfig(BaseReplicaSchedulerConfig): + max_batched_tokens: int = None + max_tokens_in_batch: int = None + + @staticmethod + def get_type(): + return ReplicaSchedulerType.SIMPLE_CHUNKING + + +@dataclass +class OrcaSchedulerConfig(BaseReplicaSchedulerConfig): + + @staticmethod + def get_type(): + return ReplicaSchedulerType.ORCA + + +@dataclass +class FasterTransformerSchedulerConfig(BaseReplicaSchedulerConfig): + + @staticmethod + def get_type(): + return ReplicaSchedulerType.FASTER_TRANSFORMER + + +@dataclass +class SarathiSchedulerConfig(BaseReplicaSchedulerConfig): + chunk_size: int = 512 + + @staticmethod + def get_type(): + return ReplicaSchedulerType.SARATHI + + +@dataclass +class MetricsConfig: + """Metric configuration.""" + + write_metrics: bool = True + write_json_trace: bool = False + wandb_project: Optional[str] = None + wandb_group: Optional[str] = None + wandb_run_name: Optional[str] = None + wandb_sweep_id: Optional[str] = None + wandb_run_id: Optional[str] = None + enable_chrome_trace: bool = True + save_table_to_wandb: bool = False + store_plots: bool = True + store_operation_metrics: bool = False + store_token_completion_metrics: bool = False + store_request_metrics: bool = True + store_batch_metrics: bool = True + store_utilization_metrics: bool = True + keep_individual_batch_metrics: bool = False + + +@dataclass +class ReplicaConfig: + model_name: str = "meta-llama/Llama-2-7b-hf" + gpu_memory_utilization: float = 0.8 + pipeline_parallel_size: int = 1 + tensor_parallel_size: int = 1 + device: str = "a100" + network_device: str = "a100_pair_nvlink" + + def __post_init__(self): + self.world_size = self.pipeline_parallel_size * self.tensor_parallel_size + + +@dataclass +class BaseGlobalSchedulerConfig(BasePolyConfig): + pass + + +@dataclass +class RandomGlobalSchedulerConfig(BaseGlobalSchedulerConfig): + @staticmethod + def get_type(): + return GlobalSchedulerType.RANDOM + + +@dataclass +class RoundRobinGlobalSchedulerConfig(BaseGlobalSchedulerConfig): + @staticmethod + def get_type(): + return GlobalSchedulerType.ROUND_ROBIN + + +@dataclass +class LORGlobalSchedulerConfig(BaseGlobalSchedulerConfig): + @staticmethod + def get_type(): + return GlobalSchedulerType.LOR + + +@dataclass +class BaseExecutionTimePredictorConfig(BasePolyConfig): + compute_input_file: str = "./data/profiling/compute/{DEVICE}/{MODEL}/mlp.csv" + attention_input_file: str = "./data/profiling/compute/{DEVICE}/{MODEL}/attention.csv" + all_reduce_input_file: str = "./data/profiling/network/{NETWORK_DEVICE}/all_reduce.csv" + send_recv_input_file: str = "./data/profiling/network/{NETWORK_DEVICE}/send_recv.csv" + cpu_overhead_input_file: str = "./data/profiling/cpu_overhead/{NETWORK_DEVICE}/{MODEL}/cpu_overheads.csv" + k_fold_cv_splits: int = 10 + no_cache: bool = False + kv_cache_prediction_granularity: int = 64 + prediction_max_prefill_chunk_size: int = 4096 + prediction_max_batch_size: int = 128 + prediction_max_tokens_per_request: int = 4096 + attention_decode_batching_overhead_fraction: float = 0.1 + attention_prefill_batching_overhead_fraction: float = 0.1 + 
nccl_cpu_launch_overhead_ms: float = 0.02 + nccl_cpu_skew_overhead_per_device_ms: float = 0.0 + num_training_job_threads: int = -1 + skip_cpu_overhead_modeling: bool = True + + +@dataclass +class LinearRegressionExecutionTimePredictorConfig(BaseExecutionTimePredictorConfig): + polynomial_degree: List[int] = field(default_factory=lambda: list(range(1, 6))) + polynomial_include_bias: List[bool] = field(default_factory=lambda: [True, False]) + polynomial_interaction_only: List[bool] = field(default_factory=lambda: [True, False]) + fit_intercept: List[bool] = field(default_factory=lambda: [True, False]) + + @staticmethod + def get_type(): + return ExecutionTimePredictorType.LINEAR_REGRESSION + + +@dataclass +class RandomForrestExecutionTimePredictorConfig(BaseExecutionTimePredictorConfig): + num_estimators: List[int] = field(default_factory=lambda: [250, 500, 750]) + max_depth: List[int] = field(default_factory=lambda: [8, 16, 32]) + min_samples_split: List[int] = field(default_factory=lambda: [2, 5, 10]) + + @staticmethod + def get_type(): + return ExecutionTimePredictorType.RANDOM_FORREST + + +@dataclass +class ClusterConfig: + num_replicas: int = 1 + replica_config: ReplicaConfig = field(default_factory=ReplicaConfig) + execution_time_predictor_config: BaseExecutionTimePredictorConfig = field( + default_factory=RandomForrestExecutionTimePredictorConfig + ) + global_scheduler_config: BaseGlobalSchedulerConfig = field( + default_factory=RoundRobinGlobalSchedulerConfig + ) + replica_scheduler_config: BaseReplicaSchedulerConfig = field( + default_factory=SarathiSchedulerConfig + ) + metrics_config: MetricsConfig = field(default_factory=MetricsConfig) + + +@dataclass +class SimulationConfig(ABC): + log_level: str = "info" + output_dir: str = "simulator_output" + cache_dir: str = "cache" + time_limit: int = 0 # in seconds, 0 is no limit + cluster_config: ClusterConfig = field(default_factory=ClusterConfig) + request_generator_config: BaseRequestGeneratorConfig = field( + default_factory=SyntheticRequestGeneratorConfig + ) + + def __post_init__(self): + self.output_dir = ( + f"{self.output_dir}/{datetime.now().strftime('%Y-%m-%d_%H-%M-%S-%f')}" + ) + + @classmethod + def create_from_cli_args(cls): + flat_config = create_flat_dataclass(cls).create_from_cli_args() + instance = flat_config.reconstruct_original_dataclass() + instance.__flat_config__ = flat_config + return instance def to_dict(self): - return self._args.__dict__ - - def _add_to_args(self, new_args_dict, parent_key=""): - for key, value in new_args_dict.items(): - arg_name = f"{parent_key}{key}" - setattr(self._args, arg_name, value) - - def _load_model_config(self): - assert self.replica_model_name is not None - - config_file = f"{MODEL_CONFIG_DIR}/{self.replica_model_name}.yml" - with open(config_file, "r") as file: - yaml_config = yaml.safe_load(file) - self._add_to_args(yaml_config, "replica_") - - def _load_device_config(self): - assert self.replica_device is not None - - config_file = f"{DEVICE_CONFIG_DIR}/{self.replica_device}.yml" - with open(config_file, "r") as file: - yaml_config = yaml.safe_load(file) - self._add_to_args(yaml_config, "replica_") - - def _substitute_variables_in_args(self): - assert self.replica_model_name is not None - assert self.replica_device is not None - assert self.replica_network_device is not None - - # update names of sklearn config files - for key, value in self._args.__dict__.items(): - if isinstance(value, str): - self._args.__dict__[key] = ( - value.replace("{MODEL}", self.replica_model_name) - 
.replace("{DEVICE}", self.replica_device) - .replace("{NETWORK_DEVICE}", self.replica_network_device) - ) + if not hasattr(self, "__flat_config__"): + logger.warning("Flat config not found. Returning the original config.") + return self.__dict__ + + return self.__flat_config__.__dict__ diff --git a/vidur/config/default.yml b/vidur/config/default.yml deleted file mode 100644 index 1913499..0000000 --- a/vidur/config/default.yml +++ /dev/null @@ -1,162 +0,0 @@ -seed: 42 -log_level: info -output_dir: ./simulator_output/ -cache_dir: ./cache -write_json_trace: false -write_chrome_trace: true -write_metrics: true - -cluster: - num_replicas: 1 - -replica: - block_size: 16 - memory_margin_fraction: 0.1 - num_pipeline_stages: 4 - num_tensor_parallel_workers: 1 - model_name: meta-llama/Llama-2-7b-hf - device: a100 - network_device: a100_pair_nvlink - -request_generator: - provider: synthetic - max_tokens: 4096 - -synthetic_request_generator: - length_provider: trace - interval_provider: static - min_tokens: 1024 - prefill_to_decode_ratio: 10 - num_requests: 128 - -trace_request_generator: - trace_file: ./data/processed_traces/sydney_enterprise.csv - date: '2023-08-21' - prefill_scale_factor: 0.3 - decode_scale_factor: 1 - time_scale_factor: 0.04 - -# Config for synthetic trace generator -trace_request_length_generator: - trace_file: ./data/processed_traces/arxiv_summarization_stats_llama2_tokenizer_filtered_v2.csv - prefill_scale_factor: 1 - decode_scale_factor: 1 - -trace_request_interval_generator: - trace_file: ./data/processed_traces/AzureFunctionsInvocationTraceForTwoWeeksJan2021Processed.csv - start_time: "1970-01-04 12:00:00" - end_time: "1970-01-04 15:00:00" - time_scale_factor: 0.3 - -poisson_request_interval_generator: - qps: 0.5 - -gamma_request_interval_generator: - cv: 0.5 - qps: 0.2 - -zipf_request_length_generator: - theta: 0.4 - scramble: false - -fixed_request_generator: - prefill_tokens: 2048 - decode_tokens: 512 - -execution_time_predictor: - provider: random_forrest - # provider: linear_regression - -sklearn_execution_time_predictor: - compute_input_file: ./data/profiling/compute/{DEVICE}/{MODEL}/mlp.csv - attention_input_file: ./data/profiling/compute/{DEVICE}/{MODEL}/attention.csv - all_reduce_input_file: ./data/profiling/network/{NETWORK_DEVICE}/all_reduce.csv - send_recv_input_file: ./data/profiling/network/{NETWORK_DEVICE}/send_recv.csv - cpu_overhead_input_file: ./data/profiling/cpu_overhead/{NETWORK_DEVICE}/{MODEL}/cpu_overheads.csv - k_fold_cv_splits: 10 - no_cache: false - kv_cache_prediction_granularity: 64 - prediction_max_prefill_chunk_size: 4096 - prediction_max_batch_size: 128 - prediction_max_tokens_per_request: 4096 - attention_decode_batching_overhead_fraction: 0.1 - attention_prefill_batching_overhead_fraction: 0.1 - nccl_cpu_launch_overhead_ms: 0.020 - nccl_cpu_skew_overhead_per_device_ms: 0.0 - num_training_job_threads: -1 - skip_cpu_overhead_modeling: true - -random_forrest_execution_time_predictor: - num_estimators: - # - 250 - - 500 - - 750 - max_depth: - # - 8 - # - 16 - - 32 - min_samples_split: - - 2 - - 5 - - 10 - -linear_regression_execution_time_predictor: - polynomial_degree: - - 1 - - 2 - - 3 - - 4 - - 5 - polynomial_include_bias: - - true - - false - polynomial_interaction_only: - - true - - false - fit_intercept: - - true - - false - -simulator: - time_limit: 0 # in seconds, 0 is no limit - -global_scheduler: - provider: round_robin - -replica_scheduler: - provider: sarathi - batch_size_cap: 128 - num_blocks: null - -orca_scheduler: - 
use_single_prefill_per_batch: false - -vllm_scheduler: - watermark_blocks_fraction: 0.01 - max_tokens_in_batch: 4096 - -sarathi_scheduler: - chunk_size: 512 - enable_rolling_prefills: true - prefill_fitting_tolerance: 0.0 - watermark_blocks_fraction: 0.01 - -lightllm_scheduler: - max_tokens_in_batch: 4096 - max_waiting_iters: 10 - -metrics_store: - wandb_project: "llm-simulator-v2" - wandb_group: "" - wandb_run_name: "" - subsamples: null - save_table_to_wandb: false - store_plots: true - store_operation_metrics: false - store_token_completion_metrics: false - store_request_metrics: true - store_batch_metrics: true - store_utilization_metrics: true - keep_individual_batch_metrics: false - # min_batch_idx: 2000 - # max_batch_idx: 5000 diff --git a/vidur/config/device_sku_config.py b/vidur/config/device_sku_config.py new file mode 100644 index 0000000..be4411f --- /dev/null +++ b/vidur/config/device_sku_config.py @@ -0,0 +1,45 @@ +from dataclasses import dataclass + +from vidur.config.base_poly_config import BasePolyConfig +from vidur.logger import init_logger +from vidur.types import DeviceSKUType + +logger = init_logger(__name__) + + +@dataclass +class BaseDeviceSKUConfig(BasePolyConfig): + fp16_tflops: int + total_memory_gb: int + num_devices_per_node: int + + +@dataclass +class A100DeviceSKUConfig(BaseDeviceSKUConfig): + fp16_tflops: int = 312 + total_memory_gb: int = 80 + + @staticmethod + def get_type(): + return DeviceSKUType.A100 + + +@dataclass +class A40DeviceSKUConfig(BaseDeviceSKUConfig): + fp16_tflops: int = 150 + total_memory_gb: int = 45 + + @staticmethod + def get_type(): + return DeviceSKUType.A40 + + +@dataclass +class H100DeviceSKUConfig(BaseDeviceSKUConfig): + fp16_tflops: int = 1000 + total_memory_gb: int = 80 + + @staticmethod + def get_type(): + return DeviceSKUType.H100 + diff --git a/vidur/config/flat_dataclass.py b/vidur/config/flat_dataclass.py new file mode 100644 index 0000000..82cb9fa --- /dev/null +++ b/vidur/config/flat_dataclass.py @@ -0,0 +1,204 @@ +import json +from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser +from collections import defaultdict, deque +from dataclasses import MISSING, fields, make_dataclass +from typing import Any, get_args + +from vidur.config.base_poly_config import BasePolyConfig +from vidur.config.utils import ( + get_all_subclasses, + get_inner_type, + is_composed_of_primitives, + is_dict, + is_list, + is_optional, + is_primitive_type, + is_subclass, + to_snake_case, +) + + +def topological_sort(dataclass_dependencies: dict) -> list: + in_degree = defaultdict(int) + for cls, dependencies in dataclass_dependencies.items(): + for dep in dependencies: + in_degree[dep] += 1 + + zero_in_degree_classes = deque( + [cls for cls in dataclass_dependencies if in_degree[cls] == 0] + ) + sorted_classes = [] + + while zero_in_degree_classes: + cls = zero_in_degree_classes.popleft() + sorted_classes.append(cls) + for dep in dataclass_dependencies[cls]: + in_degree[dep] -= 1 + if in_degree[dep] == 0: + zero_in_degree_classes.append(dep) + + return sorted_classes + + +def reconstruct_original_dataclass(self) -> Any: + """ + This function is dynamically mapped to FlatClass as an instance method. 
+ """ + sorted_classes = topological_sort(self.dataclass_dependencies) + instances = {} + + for _cls in reversed(sorted_classes): + args = {} + + for prefixed_filed_name, original_field_name, field_type in self.dataclass_args[ + _cls + ]: + if is_subclass(field_type, BasePolyConfig): + config_type = getattr(self, f"{original_field_name}_type") + # find all subclasses of field_type and check which one matches the config_type + for subclass in get_all_subclasses(field_type): + if subclass.get_type() == config_type: + args[original_field_name] = instances[subclass] + break + elif hasattr(field_type, "__dataclass_fields__"): + args[original_field_name] = instances[field_type] + else: + args[original_field_name] = getattr(self, prefixed_filed_name) + + instances[_cls] = _cls(**args) + + return instances[sorted_classes[0]] + + +@classmethod +def create_from_cli_args(cls) -> Any: + """ + This function is dynamically mapped to FlatClass as a class method. + """ + parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) + + for field in fields(cls): + nargs = None + field_type = field.type + + if is_list(field.type): + assert is_composed_of_primitives(field.type) + field_type = get_args(field.type)[0] + if is_primitive_type(field_type): + nargs = "+" + else: + field_type = json.loads + elif is_dict(field.type): + assert is_composed_of_primitives(field.type) + field_type = json.loads + + # handle cases with default and default factory args + if field.default is not MISSING: + parser.add_argument( + f"--{field.name}", type=field_type, default=field.default, nargs=nargs + ) + elif field.default_factory is not MISSING: + parser.add_argument( + f"--{field.name}", + type=field_type, + default=field.default_factory(), + nargs=nargs, + ) + else: + parser.add_argument( + f"--{field.name}", type=field_type, required=True, nargs=nargs + ) + + args = parser.parse_args() + + return cls(**vars(args)) + + +def create_flat_dataclass(input_dataclass: Any) -> Any: + """ + Creates a new FlatClass type by recursively flattening the input dataclass. + This allows for easy parsing of command line arguments along with storing/loading the configuration to/from a file. 
+ """ + meta_fields_with_defaults = [] + meta_fields_without_defaults = [] + processed_classes = set() + dataclass_args = defaultdict(list) + dataclass_dependencies = defaultdict(set) + + def process_dataclass(_input_dataclass, prefix=""): + if _input_dataclass in processed_classes: + return + + processed_classes.add(_input_dataclass) + + for field in fields(_input_dataclass): + prefixed_name = f"{prefix}{field.name}" + + if is_optional(field.type): + field_type = get_inner_type(field.type) + else: + field_type = field.type + + # # if field is a BasePolyConfig, add a type argument and process it as a dataclass + if is_subclass(field_type, BasePolyConfig): + dataclass_args[_input_dataclass].append( + (field.name, field.name, field_type) + ) + + type_field_name = f"{field.name}_type" + default_value = field.default_factory().get_type() + meta_fields_with_defaults.append( + (type_field_name, type(default_value), default_value) + ) + + assert hasattr(field_type, "__dataclass_fields__") + for subclass in get_all_subclasses(field_type): + dataclass_dependencies[_input_dataclass].add(subclass) + process_dataclass(subclass, f"{to_snake_case(subclass.__name__)}_") + continue + + # if field is a dataclass, recursively process it + if hasattr(field_type, "__dataclass_fields__"): + dataclass_dependencies[_input_dataclass].add(field_type) + dataclass_args[_input_dataclass].append( + (field.name, field.name, field_type) + ) + process_dataclass(field_type, f"{to_snake_case(field_type.__name__)}_") + continue + + field_default = field.default if field.default is not MISSING else MISSING + field_default_factory = ( + field.default_factory + if field.default_factory is not MISSING + else MISSING + ) + + if field_default is not MISSING: + meta_fields_with_defaults.append( + (prefixed_name, field_type, field_default) + ) + elif field_default_factory is not MISSING: + meta_fields_with_defaults.append( + (prefixed_name, field_type, field_default_factory()) + ) + else: + meta_fields_without_defaults.append((prefixed_name, field_type)) + + dataclass_args[_input_dataclass].append( + (prefixed_name, field.name, field_type) + ) + + process_dataclass(input_dataclass) + + meta_fields = meta_fields_without_defaults + meta_fields_with_defaults + FlatClass = make_dataclass("FlatClass", meta_fields) + + # Metadata fields + FlatClass.dataclass_args = dataclass_args + FlatClass.dataclass_dependencies = dataclass_dependencies + + # Helper methods + FlatClass.reconstruct_original_dataclass = reconstruct_original_dataclass + FlatClass.create_from_cli_args = create_from_cli_args + + return FlatClass diff --git a/vidur/config/model_config.py b/vidur/config/model_config.py new file mode 100644 index 0000000..815d746 --- /dev/null +++ b/vidur/config/model_config.py @@ -0,0 +1,151 @@ +from dataclasses import dataclass +from typing import Any, Dict, Optional + +from vidur.config.base_poly_config import BasePolyConfig +from vidur.logger import init_logger +from vidur.types import NormType, ActivationType, ModelType + +logger = init_logger(__name__) + + +@dataclass +class BaseModelConfig(BasePolyConfig): + num_layers: int + num_q_heads: int + num_kv_heads: int + embedding_dim: int + mlp_hidden_dim: int + max_position_embeddings: int + use_gated_mlp: bool + use_bias: bool + use_qkv_bias: bool + activation: ActivationType + norm: NormType + post_attn_norm: bool + vocab_size: int + is_neox_style: Optional[bool] = True + rope_theta: Optional[int] = None + rope_scaling: Optional[Dict[str, Any]] = None + partial_rotary_factor: float = 1.0 + 
no_tensor_parallel: bool = False + + +@dataclass +class Llama2ModelConfig(BaseModelConfig): + max_position_embeddings: int = 16384 + use_gated_mlp: bool = True + use_bias: bool = False + use_qkv_bias: bool = False + activation: ActivationType = ActivationType.SILU + norm: NormType = NormType.RMS_NORM + post_attn_norm: bool = True + vocab_size: int = 32768 + is_neox_style: Optional[bool] = True + rope_theta: Optional[int] = 10000.0 + rope_scaling: Optional[Dict[str, Any]] = None + partial_rotary_factor: float = 1.0 + no_tensor_parallel: bool = False + + +@dataclass +class CodeLlama34BModelConfig(Llama2ModelConfig): + num_layers: int = 48 + num_q_heads: int = 64 + num_kv_heads: int = 8 + embedding_dim: int = 8192 + mlp_hidden_dim: int = 22016 + + @staticmethod + def get_type(): + return ModelType.CODE_LLAMA_34B + + +@dataclass +class Llama2_7BModelConfig(Llama2ModelConfig): + num_layers: int = 32 + num_q_heads: int = 32 + num_kv_heads: int = 32 + embedding_dim: int = 4096 + mlp_hidden_dim: int = 11008 + + @staticmethod + def get_type(): + return ModelType.LLAMA_2_7B + + +@dataclass +class Llama2_70BModelConfig(Llama2ModelConfig): + num_layers: int = 80 + num_q_heads: int = 64 + num_kv_heads: int = 8 + embedding_dim: int = 8192 + mlp_hidden_dim: int = 28672 + + @staticmethod + def get_type(): + return ModelType.LLAMA_2_70B + + +@dataclass +class InternLM2ModelConfig(Llama2ModelConfig): + max_position_embeddings: int = 32768 + vocab_size: int = 92544 + + +@dataclass +class InternLM2_20BModelConfig(InternLM2ModelConfig): + num_layers: int = 48 + num_q_heads: int = 48 + num_kv_heads: int = 8 + embedding_dim: int = 6144 + mlp_hidden_dim: int = 16384 + + @staticmethod + def get_type(): + return ModelType.INTERNLM_2_20B + + +@dataclass +class Phi2ModelConfig(Llama2ModelConfig): + num_layers: int = 32 + num_q_heads: int = 32 + num_kv_heads: int = 32 + embedding_dim: int = 2560 + mlp_hidden_dim: int = 10240 + max_position_embeddings: int = 2048 + use_gated_mlp: bool = False + use_bias: bool = True + use_qkv_bias: bool = True + activation: ActivationType = ActivationType.GELU + norm: NormType = NormType.LAYER_NORM + post_attn_norm: bool = False + vocab_size: int = 51200 + rope_scaling: Optional[Dict[str, Any]] = None + rope_theta: Optional[int] = 10000.0 + partial_rotary_factor: float = 0.4 + no_tensor_parallel: bool = True + is_neox_style: bool = True + + @staticmethod + def get_type(): + return ModelType.PHI2 + + +@dataclass +class QwenModelConfig(Llama2ModelConfig): + use_qkv_bias: bool = True + max_position_embeddings: int = 32768 + vocab_size: int = 152064 + + +@dataclass +class Qwen72BModelConfig(QwenModelConfig): + num_layers: int = 80 + num_q_heads: int = 64 + num_kv_heads: int = 64 + embedding_dim: int = 8192 + mlp_hidden_dim: int = 24576 + + @staticmethod + def get_type(): + return ModelType.QWEN_72B diff --git a/vidur/config/node_sku_config.py b/vidur/config/node_sku_config.py new file mode 100644 index 0000000..412f532 --- /dev/null +++ b/vidur/config/node_sku_config.py @@ -0,0 +1,62 @@ +from dataclasses import dataclass + +from vidur.config.base_poly_config import BasePolyConfig +from vidur.logger import init_logger +from vidur.types import NodeSKUType, DeviceSKUType + +logger = init_logger(__name__) + + +@dataclass +class BaseNodeSKUConfig(BasePolyConfig): + num_devices_per_node: int + + +@dataclass +class A40PairwiseNvlinkNodeSKUConfig(BaseNodeSKUConfig): + device_sku_type: DeviceSKUType = DeviceSKUType.A40 + num_devices_per_node: int = 8 + + @staticmethod + def get_type(): + return 
NodeSKUType.A40_PAIRWISE_NVLINK + + +@dataclass +class A100PairwiseNvlinkNodeSKUConfig(BaseNodeSKUConfig): + device_sku_type: DeviceSKUType = DeviceSKUType.A100 + num_devices_per_node: int = 8 + + @staticmethod + def get_type(): + return NodeSKUType.A100_PAIRWISE_NVLINK + + +@dataclass +class H100PairwiseNvlinkNodeSKUConfig(BaseNodeSKUConfig): + device_sku_type: DeviceSKUType = DeviceSKUType.H100 + num_devices_per_node: int = 8 + + @staticmethod + def get_type(): + return NodeSKUType.H100_PAIRWISE_NVLINK + + +@dataclass +class A100DgxNodeSKUConfig(BaseNodeSKUConfig): + device_sku_type: DeviceSKUType = DeviceSKUType.A100 + num_devices_per_node: int = 8 + + @staticmethod + def get_type(): + return NodeSKUType.A100_DGX + + +@dataclass +class H100DgxNodeSKUConfig(BaseNodeSKUConfig): + device_sku_type: DeviceSKUType = DeviceSKUType.H100 + num_devices_per_node: int = 8 + + @staticmethod + def get_type(): + return NodeSKUType.H100_DGX diff --git a/vidur/config/utils.py b/vidur/config/utils.py new file mode 100644 index 0000000..e0ed107 --- /dev/null +++ b/vidur/config/utils.py @@ -0,0 +1,60 @@ +from typing import Union, get_args, get_origin + +primitive_types = {int, str, float, bool, type(None)} + + +def get_all_subclasses(cls): + subclasses = cls.__subclasses__() + return subclasses + [g for s in subclasses for g in get_all_subclasses(s)] + + +def is_primitive_type(field_type: type) -> bool: + # Check if the type is a primitive type + return field_type in primitive_types + + +def is_generic_composed_of_primitives(field_type: type) -> bool: + origin = get_origin(field_type) + if origin in {list, dict, tuple, Union}: + # Check all arguments of the generic type + args = get_args(field_type) + return all(is_composed_of_primitives(arg) for arg in args) + return False + + +def is_composed_of_primitives(field_type: type) -> bool: + # Check if the type is a primitive type + if is_primitive_type(field_type): + return True + + # Check if the type is a generic type composed of primitives + if is_generic_composed_of_primitives(field_type): + return True + + return False + + +def to_snake_case(name: str) -> str: + return "".join(["_" + i.lower() if i.isupper() else i for i in name]).lstrip("_") + + +def is_optional(field_type: type) -> bool: + return get_origin(field_type) is Union and type(None) in get_args(field_type) + + +def is_list(field_type: type) -> bool: + # Check if the field type is a List + return get_origin(field_type) is list + + +def is_dict(field_type: type) -> bool: + # Check if the field type is a Dict + return get_origin(field_type) is dict + + +def get_inner_type(field_type: type) -> type: + return next(t for t in get_args(field_type) if t is not type(None)) + + +def is_subclass(cls, parent: type) -> bool: + return hasattr(cls, "__bases__") and parent in cls.__bases__ diff --git a/vidur/execution_time_predictor/dummy_execution_time_predictor.py b/vidur/execution_time_predictor/dummy_execution_time_predictor.py deleted file mode 100644 index 66d1590..0000000 --- a/vidur/execution_time_predictor/dummy_execution_time_predictor.py +++ /dev/null @@ -1,33 +0,0 @@ -import random -from typing import List - -from vidur.entities import Request -from vidur.execution_time_predictor.base_execution_time_predictor import ( - BaseExecutionTimePredictor, -) - - -class DummyExecutionTimePredictor(BaseExecutionTimePredictor): - def _get_attention_layer_pre_proj_execution_time( - self, batch: List[Request] - ) -> float: - return random.uniform(0.1, 0.2) - - def _get_attention_layer_post_proj_execution_time( 
- self, batch: List[Request] - ) -> float: - return random.uniform(0.1, 0.2) - - def _get_attention_layer_flash_attention_execution_time( - self, batch: List[Request] - ) -> float: - return random.uniform(0.1, 0.2) - - def _get_mlp_layer_mlp_execution_time(self, batch: List[Request]) -> float: - return random.uniform(0.1, 0.2) - - def _get_tensor_parallel_communication_time(self, batch: List[Request]) -> float: - return random.uniform(0.1, 0.2) - - def _get_pipeline_parallel_communication_time(self, batch: List[Request]) -> float: - return random.uniform(0.1, 0.2) diff --git a/vidur/execution_time_predictor/execution_time_predictor_registry.py b/vidur/execution_time_predictor/execution_time_predictor_registry.py index ed03828..71b3221 100644 --- a/vidur/execution_time_predictor/execution_time_predictor_registry.py +++ b/vidur/execution_time_predictor/execution_time_predictor_registry.py @@ -1,6 +1,3 @@ -from vidur.execution_time_predictor.dummy_execution_time_predictor import ( - DummyExecutionTimePredictor, -) from vidur.execution_time_predictor.linear_regression_execution_time_predictor import ( LinearRegressionExecutionTimePredictor, ) @@ -16,10 +13,6 @@ class ExecutionTimePredictorRegistry(BaseRegistry): def get_key_from_str(cls, key_str: str) -> ExecutionTimePredictorType: return ExecutionTimePredictorType.from_str(key_str) - -ExecutionTimePredictorRegistry.register( - ExecutionTimePredictorType.DUMMY, DummyExecutionTimePredictor -) ExecutionTimePredictorRegistry.register( ExecutionTimePredictorType.RANDOM_FORREST, RandomForrestExecutionTimePredictor ) diff --git a/vidur/scheduler/replica_scheduler/orca_replica_scheduler.py b/vidur/scheduler/replica_scheduler/orca_replica_scheduler.py index d434f81..20352bc 100644 --- a/vidur/scheduler/replica_scheduler/orca_replica_scheduler.py +++ b/vidur/scheduler/replica_scheduler/orca_replica_scheduler.py @@ -10,9 +10,6 @@ def __init__(self, *args, **kwargs): self._preempted_requests = [] self._num_running_batches = 0 - self._use_single_prefill_per_batch = ( - self._config.orca_scheduler_use_single_prefill_per_batch - ) def on_batch_end(self, batch: Batch) -> None: self._num_running_batches -= 1 @@ -26,7 +23,6 @@ def on_batch_end(self, batch: Batch) -> None: def _get_next_batch(self) -> Batch: requests = [] num_tokens = [] - contains_prefill = False # all preempted_requests will have prefill completed while self._preempted_requests: @@ -45,16 +41,12 @@ def _get_next_batch(self) -> Batch: if not self.can_allocate(self._max_blocks_per_sequence): break - if self._use_single_prefill_per_batch and contains_prefill: - break - request = self._request_queue.pop(0) self.allocate(request.id, self._max_blocks_per_sequence) next_num_tokens = self._get_request_next_num_tokens(request) requests.append(request) num_tokens.append(next_num_tokens) - contains_prefill = True if not requests: return diff --git a/vidur/scheduler/replica_scheduler/sarathi_replica_scheduler.py b/vidur/scheduler/replica_scheduler/sarathi_replica_scheduler.py index 637d25e..60288ef 100644 --- a/vidur/scheduler/replica_scheduler/sarathi_replica_scheduler.py +++ b/vidur/scheduler/replica_scheduler/sarathi_replica_scheduler.py @@ -14,16 +14,6 @@ def __init__(self, *args, **kwargs): self._num_running_batches = 0 self._preempted_requests = [] self._chunk_size = self._config.sarathi_scheduler_chunk_size - # club multiple prefills to ensure uniform chunk size - self._enable_rolling_prefills = ( - self._config.sarathi_scheduler_enable_rolling_prefills - ) - # when we are packing multiple 
prefills in a batch, we need to ensure - # that we don't end up packing a very small prefill chunk just to make batch full - # because that will lead to reduced number of schedulable prefill requests - self._prefill_fitting_tolerance = ( - self._config.sarathi_scheduler_prefill_fitting_tolerance - ) # vLLM config self._watermark_blocks_fraction = ( self._config.sarathi_scheduler_watermark_blocks_fraction @@ -93,19 +83,9 @@ def _get_request_next_num_tokens( self._chunk_size - num_batch_tokens, ) - if not batch_contains_prefill: - return next_num_tokens - - if self._enable_rolling_prefills and num_batch_tokens < self._chunk_size * ( - 1 - self._prefill_fitting_tolerance - ): - # we can have multiple prefills per batch - # but the total number of tokens should not exceed - # the max batch size - return next_num_tokens - else: - # we will only allow one prefill per batch - return 0 + next_num_tokens = max(0, next_num_tokens) + + return next_num_tokens def _get_next_batch(self) -> Batch: requests = [] diff --git a/vidur/types/__init__.py b/vidur/types/__init__.py index 5e64c3e..a432aa9 100644 --- a/vidur/types/__init__.py +++ b/vidur/types/__init__.py @@ -6,6 +6,12 @@ from vidur.types.request_generator_type import RequestGeneratorType from vidur.types.request_interval_generator_type import RequestIntervalGeneratorType from vidur.types.request_length_generator_type import RequestLengthGeneratorType +from vidur.types.device_sku_type import DeviceSKUType +from vidur.types.node_sku_type import NodeSKUType +from vidur.types.norm_type import NormType +from vidur.types.activation_type import ActivationType +from vidur.types.model_type import ModelType + __all__ = [ EventType, @@ -15,5 +21,10 @@ RequestLengthGeneratorType, RequestIntervalGeneratorType, ReplicaSchedulerType, + DeviceSKUType, + NodeSKUType, + NormType, + ActivationType, + ModelType, BaseIntEnum, ] diff --git a/vidur/types/activation_type.py b/vidur/types/activation_type.py new file mode 100644 index 0000000..d622f2f --- /dev/null +++ b/vidur/types/activation_type.py @@ -0,0 +1,6 @@ +from vidur.types.base_int_enum import BaseIntEnum + + +class ActivationType(BaseIntEnum): + GELU = 0 + SILU = 1 diff --git a/vidur/types/device_sku_type.py b/vidur/types/device_sku_type.py new file mode 100644 index 0000000..9077efe --- /dev/null +++ b/vidur/types/device_sku_type.py @@ -0,0 +1,7 @@ +from vidur.types.base_int_enum import BaseIntEnum + + +class DeviceSKUType(BaseIntEnum): + A40 = 1 + A100 = 2 + H100 = 3 diff --git a/vidur/types/model_type.py b/vidur/types/model_type.py new file mode 100644 index 0000000..e5e1bab --- /dev/null +++ b/vidur/types/model_type.py @@ -0,0 +1,10 @@ +from vidur.types.base_int_enum import BaseIntEnum + + +class ModelType(BaseIntEnum): + CODE_LLAMA_34B = 0 + LLAMA_2_7B = 1 + LLAMA_2_70B = 2 + INTERNLM_2_20B = 3 + PHI2 = 4 + QWEN_72B = 5 diff --git a/vidur/types/node_sku_type.py b/vidur/types/node_sku_type.py new file mode 100644 index 0000000..4bcaabf --- /dev/null +++ b/vidur/types/node_sku_type.py @@ -0,0 +1,9 @@ +from vidur.types.base_int_enum import BaseIntEnum + + +class NodeSKUType(BaseIntEnum): + A40_PAIRWISE_NVLINK = 1 + A100_PAIRWISE_NVLINK = 2 + H100_PAIRWISE_NVLINK = 3 + A100_DGX = 4 + H100_DGX = 5 diff --git a/vidur/types/norm_type.py b/vidur/types/norm_type.py new file mode 100644 index 0000000..b3191e4 --- /dev/null +++ b/vidur/types/norm_type.py @@ -0,0 +1,7 @@ +from vidur.types.base_int_enum import BaseIntEnum + + +class NormType(BaseIntEnum): + LAYER_NORM = 0 + RMS_NORM = 1 + From 
01f216fe8559c0eac076b157b19125a9c370d100 Mon Sep 17 00:00:00 2001 From: Anmol Agarwal Date: Wed, 24 Jul 2024 03:23:08 -0400 Subject: [PATCH 02/24] config --- vidur/config/config.py | 535 ++++++++++++++---- vidur/config/device_sku_config.py | 44 +- vidur/config/flat_dataclass.py | 10 +- vidur/config/model_config.py | 384 ++++++++++--- vidur/config/node_sku_config.py | 56 +- vidur/entities/cluster.py | 10 +- vidur/main.py | 6 +- vidur/metrics/metrics_store.py | 43 +- .../base_request_generator.py | 16 +- .../base_request_interval_generator.py | 7 +- .../base_request_length_generator.py | 7 +- .../fixed_request_length_generator.py | 5 +- .../gamma_request_interval_generator.py | 16 +- .../poisson_request_interval_generator.py | 17 +- .../request_generator_registry.py | 4 +- .../request_interval_generator_registry.py | 4 +- .../request_length_generator_registry.py | 4 +- .../static_request_interval_generator.py | 1 + .../synthetic_request_generator.py | 56 +- .../trace_replay_request_generator.py | 61 +- .../trace_request_interval_generator.py | 49 +- .../trace_request_length_generator.py | 88 ++- .../uniform_request_length_generator.py | 8 +- .../zipf_request_length_generator.py | 26 +- .../global_scheduler/base_global_scheduler.py | 12 +- .../base_replica_scheduler.py | 12 +- .../lightllm_replica_scheduler.py | 12 +- .../replica_scheduler_registry.py | 4 +- .../sarathi_replica_scheduler.py | 8 +- .../vllm_replica_scheduler.py | 8 +- vidur/scheduler/utils/memory_planner.py | 4 +- vidur/simulator.py | 22 +- 32 files changed, 1067 insertions(+), 472 deletions(-) diff --git a/vidur/config/config.py b/vidur/config/config.py index 898cd88..074ea6f 100644 --- a/vidur/config/config.py +++ b/vidur/config/config.py @@ -13,22 +13,42 @@ @dataclass class BaseRequestIntervalGeneratorConfig(BasePolyConfig): - seed: int = 42 + seed: int = field( + default=42, + metadata={"help": "Seed for the random number generator."}, + ) @dataclass class BaseRequestLengthGeneratorConfig(BasePolyConfig): - seed: int = 42 + seed: int = field( + default=42, + metadata={"help": "Seed for the random number generator."}, + ) + max_tokens: int = field( + default=4096, + metadata={"help": "Maximum tokens."}, + ) @dataclass class TraceRequestIntervalGeneratorConfig(BaseRequestIntervalGeneratorConfig): - trace_file: str = ( - "data/processed_traces/AzureFunctionsInvocationTraceForTwoWeeksJan2021Processed.csv" + trace_file: str = field( + default="data/processed_traces/AzureFunctionsInvocationTraceForTwoWeeksJan2021Processed.csv", + metadata={"help": "Path to the trace request interval generator file."}, + ) + start_time: str = field( + default="1970-01-04 12:00:00", + metadata={"help": "Start time of the trace request interval generator."}, + ) + end_time: str = field( + default="1970-01-04 15:00:00", + metadata={"help": "End time of the trace request interval generator."}, + ) + time_scale_factor: float = field( + default=0.3, + metadata={"help": "Time scale factor for the trace request interval generator."}, ) - start_time: str = "1970-01-04 12:00:00" - end_time: str = "1970-01-04 15:00:00" - time_scale_factor: float = 0.3 @staticmethod def get_type(): @@ -37,7 +57,10 @@ def get_type(): @dataclass class PoissonRequestIntervalGeneratorConfig(BaseRequestIntervalGeneratorConfig): - qps: float = 1.0 + qps: float = field( + default=1.0, + metadata={"help": "Queries per second for Poisson Request Interval Generator."}, + ) @staticmethod def get_type(): @@ -46,8 +69,14 @@ def get_type(): @dataclass class 
GammaRequestIntervalGeneratorConfig(BaseRequestIntervalGeneratorConfig): - qps: float = 1.0 - cv: float = 0.5 + qps: float = field( + default=1.0, + metadata={"help": "Queries per second for Gamma Request Interval Generator."}, + ) + cv: float = field( + default=0.5, + metadata={"help": "Coefficient of variation for Gamma Request Interval Generator."}, + ) @staticmethod def get_type(): @@ -63,12 +92,22 @@ def get_type(): @dataclass class TraceRequestLengthGeneratorConfig(BaseRequestLengthGeneratorConfig): - trace_file: str = ( - "data/processed_traces/sharegpt_8k_filtered_stats_llama2_tokenizer.csv" + trace_file: str = field( + default="data/processed_traces/sharegpt_8k_filtered_stats_llama2_tokenizer.csv", + metadata={"help": "Path to the trace request length generator file."}, + ) + prefill_scale_factor: float = field( + default=1, + metadata={"help": "Prefill scale factor for the trace request length generator."}, + ) + decode_scale_factor: float = field( + default=1, + metadata={"help": "Decode scale factor for the trace request length generator."}, + ) + max_tokens: int = field( + default=4096, + metadata={"help": "Maximum tokens for the trace request length generator."}, ) - prefill_scale_factor: float = 1 - decode_scale_factor: float = 1 - max_tokens: int = 4096 @staticmethod def get_type(): @@ -77,11 +116,26 @@ def get_type(): @dataclass class ZipfRequestLengthGeneratorConfig(BaseRequestLengthGeneratorConfig): - theta: float = 0.6 - scramble: bool = False - min_tokens: int = 1024 - max_tokens: int = 4096 - prefill_to_decode_ratio: float = 20.0 + theta: float = field( + default=0.6, + metadata={"help": "Theta for Zipf Request Length Generator."}, + ) + scramble: bool = field( + default=False, + metadata={"help": "Scramble for Zipf Request Length Generator."}, + ) + min_tokens: int = field( + default=1024, + metadata={"help": "Minimum tokens for Zipf Request Length Generator."}, + ) + max_tokens: int = field( + default=4096, + metadata={"help": "Maximum tokens for Zipf Request Length Generator."}, + ) + prefill_to_decode_ratio: float = field( + default=20.0, + metadata={"help": "Prefill to decode ratio for Zipf Request Length Generator."}, + ) @staticmethod def get_type(): @@ -90,9 +144,18 @@ def get_type(): @dataclass class UniformRequestLengthGeneratorConfig(BaseRequestLengthGeneratorConfig): - min_tokens: int = 1024 - max_tokens: int = 4096 - prefill_to_decode_ratio: float = 20.0 + min_tokens: int = field( + default=1024, + metadata={"help": "Minimum tokens for Uniform Request Length Generator."}, + ) + max_tokens: int = field( + default=4096, + metadata={"help": "Maximum tokens for Uniform Request Length Generator."}, + ) + prefill_to_decode_ratio: float = field( + default=20.0, + metadata={"help": "Prefill to decode ratio for Uniform Request Length Generator."}, + ) @staticmethod def get_type(): @@ -101,8 +164,14 @@ def get_type(): @dataclass class FixedRequestLengthGeneratorConfig(BaseRequestLengthGeneratorConfig): - prefill_tokens: int = 4096 - decode_tokens: int = 512 + prefill_tokens: int = field( + default=4096, + metadata={"help": "Prefill tokens for Fixed Request Length Generator."}, + ) + decode_tokens: int = field( + default=512, + metadata={"help": "Decode tokens for Fixed Request Length Generator."}, + ) @staticmethod def get_type(): @@ -111,19 +180,43 @@ def get_type(): @dataclass class BaseRequestGeneratorConfig(BasePolyConfig): - seed: int = 42 + seed: int = field( + default=42, + metadata={"help": "Seed for the random number generator."}, + ) + max_tokens: int = 
field( + default=4096, + metadata={"help": "Maximum tokens."}, + ) + @dataclass class SyntheticRequestGeneratorConfig(BaseRequestGeneratorConfig): length_generator_config: BaseRequestLengthGeneratorConfig = field( - default_factory=FixedRequestLengthGeneratorConfig + default_factory=FixedRequestLengthGeneratorConfig, + metadata={"help": "Length generator config for Synthetic Request Generator."}, ) interval_generator_config: BaseRequestIntervalGeneratorConfig = field( - default_factory=PoissonRequestIntervalGeneratorConfig + default_factory=PoissonRequestIntervalGeneratorConfig, + metadata={"help": "Interval generator config for Synthetic Request Generator."}, + ) + num_requests: int = field( + default=64, + metadata={"help": "Number of requests for Synthetic Request Generator."}, + ) + duration: float = field( + default=None, + metadata={"help": "Duration of the synthetic request generator."}, ) - num_requests: int = 64 - duration: float = None + max_tokens: int = field( + init=False, + default=4096, + metadata={"help": "Maximum tokens for the synthetic request generator."}, + ) + + def __post_init__(self): + self.max_tokens = self.length_generator_config.max_tokens @staticmethod def get_type(): @@ -132,12 +225,30 @@ def get_type(): @dataclass class TraceRequestGeneratorConfig(BaseRequestGeneratorConfig): - trace_file: str = "data/processed_traces/sydney_enterprise.csv" - date: str = "2023-08-21" - prefill_scale_factor: float = 0.3 - decode_scale_factor: float = 1 - time_scale_factor: float = 0.04 - max_tokens: int = 4096 + trace_file: str = field( + default="data/processed_traces/sydney_enterprise.csv", + metadata={"help": "Path to the trace request generator file."}, + ) + date: str = field( + default="2023-08-21", + metadata={"help": "Date for the trace request generator."}, + ) + prefill_scale_factor: float = field( + default=0.3, + metadata={"help": "Prefill scale factor for the trace request generator."}, + ) + decode_scale_factor: float = field( + default=1, + metadata={"help": "Decode scale factor for the trace request generator."}, + ) + time_scale_factor: float = field( + default=0.04, + metadata={"help": "Time scale factor for the trace request generator."}, + ) + max_tokens: int = field( + default=4096, + metadata={"help": "Maximum tokens for the trace request generator."}, + ) @staticmethod def get_type(): @@ -146,11 +257,30 @@ def get_type(): @dataclass class BaseReplicaSchedulerConfig(BasePolyConfig): - max_num_seqs: int = 128 - num_pipeline_stages: int = 1 - watermark_blocks_fraction: float = 0.01 - block_size: int = 16 - num_blocks: Optional[int] = None + max_num_seqs: int = field( + default=128, + metadata={"help": "Maximum number of sequences."}, + ) + num_pipeline_stages: int = field( + default=1, + metadata={"help": "Number of pipeline stages."}, + ) + watermark_blocks_fraction: float = field( + default=0.01, + metadata={"help": "Watermark blocks fraction."}, + ) + block_size: int = field( + default=16, + metadata={"help": "Block size."}, + ) + num_blocks: Optional[int] = field( + default=None, + metadata={"help": "Number of blocks."}, + ) + batch_size_cap: int = field( + default=128, + metadata={"help": "Maximum batch size cap."}, + ) @abstractmethod def get_max_num_batched_tokens(self): @@ -159,7 +289,14 @@ def get_max_num_batched_tokens(self): @dataclass class VllmSchedulerConfig(BaseReplicaSchedulerConfig): - max_batched_tokens: int = None + max_batched_tokens: int = field( + default=None, + metadata={"help": "Maximum batched tokens for vLLM."}, + ) + 
max_tokens_in_batch: int = field( + default=4096, + metadata={"help": "Maximum tokens in batch for vLLM."}, + ) @staticmethod def get_type(): @@ -168,12 +305,22 @@ def get_type(): @dataclass class LightLLMSchedulerConfig(BaseReplicaSchedulerConfig): - max_batched_tokens: int = None - max_tokens_in_batch: int = None + max_batched_tokens: int = field( + default=None, + metadata={"help": "Maximum batched tokens for LightLLM."}, + ) + max_tokens_in_batch: int = field( + default=4096, + metadata={"help": "Maximum tokens in batch for LightLLM."}, + ) + max_waiting_iters: int = field( + default=10, + metadata={"help": "Maximum waiting iterations for LightLLM."}, + ) @staticmethod def get_type(): - return ReplicaSchedulerType.SIMPLE_CHUNKING + return ReplicaSchedulerType.LIGHTLLM @dataclass @@ -194,7 +341,10 @@ def get_type(): @dataclass class SarathiSchedulerConfig(BaseReplicaSchedulerConfig): - chunk_size: int = 512 + chunk_size: int = field( + default=512, + metadata={"help": "Chunk size for Sarathi."}, + ) @staticmethod def get_type(): @@ -205,32 +355,110 @@ def get_type(): class MetricsConfig: """Metric configuration.""" - write_metrics: bool = True - write_json_trace: bool = False - wandb_project: Optional[str] = None - wandb_group: Optional[str] = None - wandb_run_name: Optional[str] = None - wandb_sweep_id: Optional[str] = None - wandb_run_id: Optional[str] = None - enable_chrome_trace: bool = True - save_table_to_wandb: bool = False - store_plots: bool = True - store_operation_metrics: bool = False - store_token_completion_metrics: bool = False - store_request_metrics: bool = True - store_batch_metrics: bool = True - store_utilization_metrics: bool = True - keep_individual_batch_metrics: bool = False + write_metrics: bool = field( + default=True, + metadata={"help": "Whether to write metrics."}, + ) + write_json_trace: bool = field( + default=False, + metadata={"help": "Whether to write json trace."}, + ) + wandb_project: Optional[str] = field( + default=None, + metadata={"help": "Weights & Biases project name."}, + ) + wandb_group: Optional[str] = field( + default=None, + metadata={"help": "Weights & Biases group name."}, + ) + wandb_run_name: Optional[str] = field( + default=None, + metadata={"help": "Weights & Biases run name."}, + ) + wandb_sweep_id: Optional[str] = field( + default=None, + metadata={"help": "Weights & Biases sweep id."}, + ) + wandb_run_id: Optional[str] = field( + default=None, + metadata={"help": "Weights & Biases run id."}, + ) + enable_chrome_trace: bool = field( + default=True, + metadata={"help": "Enable Chrome tracing."}, + ) + save_table_to_wandb: bool = field( + default=False, + metadata={"help": "Whether to save table to wandb."}, + ) + store_plots: bool = field( + default=True, + metadata={"help": "Whether to store plots."}, + ) + store_operation_metrics: bool = field( + default=False, + metadata={"help": "Whether to store operation metrics."}, + ) + store_token_completion_metrics: bool = field( + default=False, + metadata={"help": "Whether to store token completion metrics."}, + ) + store_request_metrics: bool = field( + default=True, + metadata={"help": "Whether to store request metrics."}, + ) + store_batch_metrics: bool = field( + default=True, + metadata={"help": "Whether to store batch metrics."}, + ) + store_utilization_metrics: bool = field( + default=True, + metadata={"help": "Whether to store utilization metrics."}, + ) + keep_individual_batch_metrics: bool = field( + default=False, + metadata={"help": "Whether to keep individual batch 
metrics."}, + ) + subsamples: Optional[int] = field( + default=None, + metadata={"help": "Subsamples."}, + ) + min_batch_index: Optional[int] = field( + default=None, + metadata={"help": "Minimum batch index."}, + ) + max_batch_index: Optional[int] = field( + default=None, + metadata={"help": "Maximum batch index."}, + ) @dataclass class ReplicaConfig: - model_name: str = "meta-llama/Llama-2-7b-hf" - gpu_memory_utilization: float = 0.8 - pipeline_parallel_size: int = 1 - tensor_parallel_size: int = 1 - device: str = "a100" - network_device: str = "a100_pair_nvlink" + model_name: str = field( + default="meta-llama/Llama-2-7b-hf", + metadata={"help": "Model name."}, + ) + gpu_memory_utilization: float = field( + default=0.8, + metadata={"help": "GPU memory utilization."}, + ) + pipeline_parallel_size: int = field( + default=1, + metadata={"help": "Pipeline parallel size."}, + ) + tensor_parallel_size: int = field( + default=1, + metadata={"help": "Tensor parallel size."}, + ) + device: str = field( + default="a100", + metadata={"help": "Device."}, + ) + network_device: str = field( + default="a100_pair_nvlink", + metadata={"help": "Network device."}, + ) def __post_init__(self): self.world_size = self.pipeline_parallel_size * self.tensor_parallel_size @@ -264,31 +492,94 @@ def get_type(): @dataclass class BaseExecutionTimePredictorConfig(BasePolyConfig): - compute_input_file: str = "./data/profiling/compute/{DEVICE}/{MODEL}/mlp.csv" - attention_input_file: str = "./data/profiling/compute/{DEVICE}/{MODEL}/attention.csv" - all_reduce_input_file: str = "./data/profiling/network/{NETWORK_DEVICE}/all_reduce.csv" - send_recv_input_file: str = "./data/profiling/network/{NETWORK_DEVICE}/send_recv.csv" - cpu_overhead_input_file: str = "./data/profiling/cpu_overhead/{NETWORK_DEVICE}/{MODEL}/cpu_overheads.csv" - k_fold_cv_splits: int = 10 - no_cache: bool = False - kv_cache_prediction_granularity: int = 64 - prediction_max_prefill_chunk_size: int = 4096 - prediction_max_batch_size: int = 128 - prediction_max_tokens_per_request: int = 4096 - attention_decode_batching_overhead_fraction: float = 0.1 - attention_prefill_batching_overhead_fraction: float = 0.1 - nccl_cpu_launch_overhead_ms: float = 0.02 - nccl_cpu_skew_overhead_per_device_ms: float = 0.0 - num_training_job_threads: int = -1 - skip_cpu_overhead_modeling: bool = True + compute_input_file: str = field( + default="./data/profiling/compute/{DEVICE}/{MODEL}/mlp.csv", + metadata={"help": "Path to the compute input file."}, + ) + attention_input_file: str = field( + default="./data/profiling/compute/{DEVICE}/{MODEL}/attention.csv", + metadata={"help": "Path to the attention input file."}, + ) + all_reduce_input_file: str = field( + default="./data/profiling/network/{NETWORK_DEVICE}/all_reduce.csv", + metadata={"help": "Path to the all reduce input file."}, + ) + send_recv_input_file: str = field( + default="./data/profiling/network/{NETWORK_DEVICE}/send_recv.csv", + metadata={"help": "Path to the send recv input file."}, + ) + cpu_overhead_input_file: str = field( + default="./data/profiling/cpu_overhead/{NETWORK_DEVICE}/{MODEL}/cpu_overheads.csv", + metadata={"help": "Path to the cpu overhead input file."}, + ) + k_fold_cv_splits: int = field( + default=10, + metadata={"help": "Number of k fold cross validation splits."}, + ) + no_cache: bool = field( + default=False, + metadata={"help": "Whether to cache prediction models."}, + ) + kv_cache_prediction_granularity: int = field( + default=64, + metadata={"help": "KV cache prediction granularity."}, 
+ ) + prediction_max_prefill_chunk_size: int = field( + default=4096, + metadata={"help": "Max prefill chunk size for prediction."}, + ) + prediction_max_batch_size: int = field( + default=128, + metadata={"help": "Max batch size for prediction."}, + ) + prediction_max_tokens_per_request: int = field( + default=4096, + metadata={"help": "Max tokens per request for prediction."}, + ) + attention_decode_batching_overhead_fraction: float = field( + default=0.1, + metadata={"help": "Attention decode batching overhead fraction."}, + ) + attention_prefill_batching_overhead_fraction: float = field( + default=0.1, + metadata={"help": "Attention prefill batching overhead fraction."}, + ) + nccl_cpu_launch_overhead_ms: float = field( + default=0.02, + metadata={"help": "NCCL CPU launch overhead in ms."}, + ) + nccl_cpu_skew_overhead_per_device_ms: float = field( + default=0.0, + metadata={"help": "NCCL CPU skew overhead per device in ms."}, + ) + num_training_job_threads: int = field( + default=-1, + metadata={"help": "Number of training job threads."}, + ) + skip_cpu_overhead_modeling: bool = field( + default=True, + metadata={"help": "Whether to skip CPU overhead modeling."}, + ) @dataclass class LinearRegressionExecutionTimePredictorConfig(BaseExecutionTimePredictorConfig): - polynomial_degree: List[int] = field(default_factory=lambda: list(range(1, 6))) - polynomial_include_bias: List[bool] = field(default_factory=lambda: [True, False]) - polynomial_interaction_only: List[bool] = field(default_factory=lambda: [True, False]) - fit_intercept: List[bool] = field(default_factory=lambda: [True, False]) + polynomial_degree: List[int] = field( + default_factory=lambda: list(range(1, 6)), + metadata={"help": "Polynomial degree for linear regression."}, + ) + polynomial_include_bias: List[bool] = field( + default_factory=lambda: [True, False], + metadata={"help": "Polynomial include bias for linear regression."}, + ) + polynomial_interaction_only: List[bool] = field( + default_factory=lambda: [True, False], + metadata={"help": "Polynomial interaction only for linear regression."}, + ) + fit_intercept: List[bool] = field( + default_factory=lambda: [True, False], + metadata={"help": "Fit intercept for linear regression."}, + ) @staticmethod def get_type(): @@ -297,9 +588,18 @@ def get_type(): @dataclass class RandomForrestExecutionTimePredictorConfig(BaseExecutionTimePredictorConfig): - num_estimators: List[int] = field(default_factory=lambda: [250, 500, 750]) - max_depth: List[int] = field(default_factory=lambda: [8, 16, 32]) - min_samples_split: List[int] = field(default_factory=lambda: [2, 5, 10]) + num_estimators: List[int] = field( + default_factory=lambda: [250, 500, 750], + metadata={"help": "Number of estimators for random forest."}, + ) + max_depth: List[int] = field( + default_factory=lambda: [8, 16, 32], + metadata={"help": "Maximum depth for random forest."}, + ) + min_samples_split: List[int] = field( + default_factory=lambda: [2, 5, 10], + metadata={"help": "Minimum samples split for random forest."}, + ) @staticmethod def get_type(): @@ -308,29 +608,58 @@ def get_type(): @dataclass class ClusterConfig: - num_replicas: int = 1 + num_replicas: int = field( + default=1, + metadata={"help": "Number of replicas."}, + ) replica_config: ReplicaConfig = field(default_factory=ReplicaConfig) execution_time_predictor_config: BaseExecutionTimePredictorConfig = field( - default_factory=RandomForrestExecutionTimePredictorConfig + default_factory=RandomForrestExecutionTimePredictorConfig, + 
metadata={"help": "Execution time predictor config."}, ) global_scheduler_config: BaseGlobalSchedulerConfig = field( - default_factory=RoundRobinGlobalSchedulerConfig + default_factory=RoundRobinGlobalSchedulerConfig, + metadata={"help": "Global scheduler config."}, ) replica_scheduler_config: BaseReplicaSchedulerConfig = field( - default_factory=SarathiSchedulerConfig + default_factory=SarathiSchedulerConfig, + metadata={"help": "Replica scheduler config."}, + ) + metrics_config: MetricsConfig = field( + default_factory=MetricsConfig, + metadata={"help": "Metrics config."}, ) - metrics_config: MetricsConfig = field(default_factory=MetricsConfig) @dataclass class SimulationConfig(ABC): - log_level: str = "info" - output_dir: str = "simulator_output" - cache_dir: str = "cache" - time_limit: int = 0 # in seconds, 0 is no limit - cluster_config: ClusterConfig = field(default_factory=ClusterConfig) + seed: int = field( + default=42, + metadata={"help": "Seed for the random number generator."}, + ) + log_level: str = field( + default="info", + metadata={"help": "Logging level."}, + ) + output_dir: str = field( + default="simulator_output", + metadata={"help": "Output directory."}, + ) + cache_dir: str = field( + default="cache", + metadata={"help": "Cache directory."}, + ) + time_limit: int = field( + default=0, # in seconds, 0 is no limit + metadata={"help": "Time limit for simulation in seconds. 0 means no limit."}, + ) + cluster_config: ClusterConfig = field( + default_factory=ClusterConfig, + metadata={"help": "Cluster config."}, + ) request_generator_config: BaseRequestGeneratorConfig = field( - default_factory=SyntheticRequestGeneratorConfig + default_factory=SyntheticRequestGeneratorConfig, + metadata={"help": "Request generator config."}, ) def __post_init__(self): diff --git a/vidur/config/device_sku_config.py b/vidur/config/device_sku_config.py index be4411f..73c0cb7 100644 --- a/vidur/config/device_sku_config.py +++ b/vidur/config/device_sku_config.py @@ -1,4 +1,4 @@ -from dataclasses import dataclass +from dataclasses import dataclass, field from vidur.config.base_poly_config import BasePolyConfig from vidur.logger import init_logger @@ -9,15 +9,27 @@ @dataclass class BaseDeviceSKUConfig(BasePolyConfig): - fp16_tflops: int - total_memory_gb: int - num_devices_per_node: int + fp16_tflops: int = field( + metadata={"help": "The number of TFLOPS the device can achieve in FP16"}, + ) + total_memory_gb: int = field( + metadata={"help": "The total memory of the device in GB"}, + ) + num_devices_per_node: int = field( + metadata={"help": "The number of devices per node"}, + ) @dataclass class A100DeviceSKUConfig(BaseDeviceSKUConfig): - fp16_tflops: int = 312 - total_memory_gb: int = 80 + fp16_tflops: int = field( + default=312, + metadata={"help": "The number of TFLOPS the device can achieve in FP16"}, + ) + total_memory_gb: int = field( + default=80, + metadata={"help": "The total memory of the device in GB"}, + ) @staticmethod def get_type(): @@ -26,8 +38,14 @@ def get_type(): @dataclass class A40DeviceSKUConfig(BaseDeviceSKUConfig): - fp16_tflops: int = 150 - total_memory_gb: int = 45 + fp16_tflops: int = field( + default=150, + metadata={"help": "The number of TFLOPS the device can achieve in FP16"}, + ) + total_memory_gb: int = field( + default=45, + metadata={"help": "The total memory of the device in GB"}, + ) @staticmethod def get_type(): @@ -36,8 +54,14 @@ def get_type(): @dataclass class H100DeviceSKUConfig(BaseDeviceSKUConfig): - fp16_tflops: int = 1000 - total_memory_gb: int = 
80 + fp16_tflops: int = field( + default=1000, + metadata={"help": "The number of TFLOPS the device can achieve in FP16"}, + ) + total_memory_gb: int = field( + default=80, + metadata={"help": "The total memory of the device in GB"}, + ) @staticmethod def get_type(): diff --git a/vidur/config/flat_dataclass.py b/vidur/config/flat_dataclass.py index 82cb9fa..c1ae755 100644 --- a/vidur/config/flat_dataclass.py +++ b/vidur/config/flat_dataclass.py @@ -50,7 +50,7 @@ def reconstruct_original_dataclass(self) -> Any: for _cls in reversed(sorted_classes): args = {} - for prefixed_filed_name, original_field_name, field_type in self.dataclass_args[ + for prefixed_field_name, original_field_name, field_type in self.dataclass_args[ _cls ]: if is_subclass(field_type, BasePolyConfig): @@ -63,7 +63,7 @@ def reconstruct_original_dataclass(self) -> Any: elif hasattr(field_type, "__dataclass_fields__"): args[original_field_name] = instances[field_type] else: - args[original_field_name] = getattr(self, prefixed_filed_name) + args[original_field_name] = getattr(self, prefixed_field_name) instances[_cls] = _cls(**args) @@ -80,6 +80,7 @@ def create_from_cli_args(cls) -> Any: for field in fields(cls): nargs = None field_type = field.type + help_text = field.metadata.get("help", None) if is_list(field.type): assert is_composed_of_primitives(field.type) @@ -95,7 +96,7 @@ def create_from_cli_args(cls) -> Any: # handle cases with default and default factory args if field.default is not MISSING: parser.add_argument( - f"--{field.name}", type=field_type, default=field.default, nargs=nargs + f"--{field.name}", type=field_type, default=field.default, nargs=nargs, help=help_text ) elif field.default_factory is not MISSING: parser.add_argument( @@ -103,10 +104,11 @@ def create_from_cli_args(cls) -> Any: type=field_type, default=field.default_factory(), nargs=nargs, + help=help_text, ) else: parser.add_argument( - f"--{field.name}", type=field_type, required=True, nargs=nargs + f"--{field.name}", type=field_type, required=True, nargs=nargs, help=help_text ) args = parser.parse_args() diff --git a/vidur/config/model_config.py b/vidur/config/model_config.py index 815d746..ce83610 100644 --- a/vidur/config/model_config.py +++ b/vidur/config/model_config.py @@ -1,4 +1,4 @@ -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Any, Dict, Optional from vidur.config.base_poly_config import BasePolyConfig @@ -10,50 +10,145 @@ @dataclass class BaseModelConfig(BasePolyConfig): - num_layers: int - num_q_heads: int - num_kv_heads: int - embedding_dim: int - mlp_hidden_dim: int - max_position_embeddings: int - use_gated_mlp: bool - use_bias: bool - use_qkv_bias: bool - activation: ActivationType - norm: NormType - post_attn_norm: bool - vocab_size: int - is_neox_style: Optional[bool] = True - rope_theta: Optional[int] = None - rope_scaling: Optional[Dict[str, Any]] = None - partial_rotary_factor: float = 1.0 - no_tensor_parallel: bool = False + num_layers: int = field( + metadata={"help": "The number of layers in the model"}, + ) + num_q_heads: int = field( + metadata={"help": "The number of query heads in the model"}, + ) + num_kv_heads: int = field( + metadata={"help": "The number of key-value heads in the model"}, + ) + embedding_dim: int = field( + metadata={"help": "The embedding dimension of the model"}, + ) + mlp_hidden_dim: int = field( + metadata={"help": "The hidden dimension of the MLP in the model"}, + ) + max_position_embeddings: int = field( + metadata={"help": "The maximum 
position embeddings in the model"}, + ) + use_gated_mlp: bool = field( + metadata={"help": "Whether to use gated MLP in the model"}, + ) + use_bias: bool = field( + metadata={"help": "Whether to use bias in the model"}, + ) + use_qkv_bias: bool = field( + metadata={"help": "Whether to use bias in the QKV in the model"}, + ) + activation: ActivationType = field( + metadata={"help": "The activation function in the model"}, + ) + norm: NormType = field( + metadata={"help": "The normalization function in the model"}, + ) + post_attn_norm: bool = field( + metadata={"help": "Whether to use post-attention normalization in the model"}, + ) + vocab_size: int = field( + metadata={"help": "The vocabulary size of the model"}, + ) + is_neox_style: Optional[bool] = field( + default=True, + metadata={"help": "Whether to use the Neox style in the model"}, + ) + rope_theta: Optional[int] = field( + default=None, + metadata={"help": "The rope theta in the model"}, + ) + rope_scaling: Optional[Dict[str, Any]] = field( + default=None, + metadata={"help": "The rope scaling config for the model"}, + ) + partial_rotary_factor: float = field( + default=1.0, + metadata={"help": "The partial rotary factor in the model"}, + ) + no_tensor_parallel: bool = field( + default=False, + metadata={"help": "Whether to use tensor parallelism in the model"}, + ) @dataclass class Llama2ModelConfig(BaseModelConfig): - max_position_embeddings: int = 16384 - use_gated_mlp: bool = True - use_bias: bool = False - use_qkv_bias: bool = False - activation: ActivationType = ActivationType.SILU - norm: NormType = NormType.RMS_NORM - post_attn_norm: bool = True - vocab_size: int = 32768 - is_neox_style: Optional[bool] = True - rope_theta: Optional[int] = 10000.0 - rope_scaling: Optional[Dict[str, Any]] = None - partial_rotary_factor: float = 1.0 - no_tensor_parallel: bool = False + max_position_embeddings: int = field( + default=16384, + metadata={"help": "The maximum position embeddings in the model"}, + ) + use_gated_mlp: bool = field( + default=True, + metadata={"help": "Whether to use gated MLP in the model"}, + ) + use_bias: bool = field( + default=False, + metadata={"help": "Whether to use bias in the model"}, + ) + use_qkv_bias: bool = field( + default=False, + metadata={"help": "Whether to use bias in the QKV in the model"}, + ) + activation: ActivationType = field( + default=ActivationType.SILU, + metadata={"help": "The activation function in the model"}, + ) + norm: NormType = field( + default=NormType.RMS_NORM, + metadata={"help": "The normalization function in the model"}, + ) + post_attn_norm: bool = field( + default=True, + metadata={"help": "Whether to use post-attention normalization in the model"}, + ) + vocab_size: int = field( + default=32768, + metadata={"help": "The vocabulary size of the model"}, + ) + is_neox_style: Optional[bool] = field( + default=True, + metadata={"help": "Whether to use the Neox style in the model"}, + ) + rope_theta: Optional[int] = field( + default=10000.0, + metadata={"help": "The rope theta in the model"}, + ) + rope_scaling: Optional[Dict[str, Any]] = field( + default=None, + metadata={"help": "The rope scaling config for the model"}, + ) + partial_rotary_factor: float = field( + default=1.0, + metadata={"help": "The partial rotary factor in the model"}, + ) + no_tensor_parallel: bool = field( + default=False, + metadata={"help": "Whether to use tensor parallelism in the model"}, + ) @dataclass class CodeLlama34BModelConfig(Llama2ModelConfig): - num_layers: int = 48 - num_q_heads: int = 64 
- num_kv_heads: int = 8 - embedding_dim: int = 8192 - mlp_hidden_dim: int = 22016 + num_layers: int = field( + default=48, + metadata={"help": "The number of layers in the model"}, + ) + num_q_heads: int = field( + default=64, + metadata={"help": "The number of query heads in the model"}, + ) + num_kv_heads: int = field( + default=8, + metadata={"help": "The number of key-value heads in the model"}, + ) + embedding_dim: int = field( + default=8192, + metadata={"help": "The embedding dimension of the model"}, + ) + mlp_hidden_dim: int = field( + default=22016, + metadata={"help": "The hidden dimension of the MLP in the model"}, + ) @staticmethod def get_type(): @@ -62,11 +157,26 @@ def get_type(): @dataclass class Llama2_7BModelConfig(Llama2ModelConfig): - num_layers: int = 32 - num_q_heads: int = 32 - num_kv_heads: int = 32 - embedding_dim: int = 4096 - mlp_hidden_dim: int = 11008 + num_layers: int = field( + default=32, + metadata={"help": "The number of layers in the model"}, + ) + num_q_heads: int = field( + default=32, + metadata={"help": "The number of query heads in the model"}, + ) + num_kv_heads: int = field( + default=32, + metadata={"help": "The number of key-value heads in the model"}, + ) + embedding_dim: int = field( + default=4096, + metadata={"help": "The embedding dimension of the model"}, + ) + mlp_hidden_dim: int = field( + default=11008, + metadata={"help": "The hidden dimension of the MLP in the model"}, + ) @staticmethod def get_type(): @@ -75,11 +185,26 @@ def get_type(): @dataclass class Llama2_70BModelConfig(Llama2ModelConfig): - num_layers: int = 80 - num_q_heads: int = 64 - num_kv_heads: int = 8 - embedding_dim: int = 8192 - mlp_hidden_dim: int = 28672 + num_layers: int = field( + default=80, + metadata={"help": "The number of layers in the model"}, + ) + num_q_heads: int = field( + default=64, + metadata={"help": "The number of query heads in the model"}, + ) + num_kv_heads: int = field( + default=8, + metadata={"help": "The number of key-value heads in the model"}, + ) + embedding_dim: int = field( + default=8192, + metadata={"help": "The embedding dimension of the model"}, + ) + mlp_hidden_dim: int = field( + default=28672, + metadata={"help": "The hidden dimension of the MLP in the model"}, + ) @staticmethod def get_type(): @@ -88,17 +213,38 @@ def get_type(): @dataclass class InternLM2ModelConfig(Llama2ModelConfig): - max_position_embeddings: int = 32768 - vocab_size: int = 92544 + max_position_embeddings: int = field( + default=32768, + metadata={"help": "The maximum position embeddings in the model"}, + ) + vocab_size: int = field( + default=92544, + metadata={"help": "The vocabulary size of the model"}, + ) @dataclass class InternLM2_20BModelConfig(InternLM2ModelConfig): - num_layers: int = 48 - num_q_heads: int = 48 - num_kv_heads: int = 8 - embedding_dim: int = 6144 - mlp_hidden_dim: int = 16384 + num_layers: int = field( + default=48, + metadata={"help": "The number of layers in the model"}, + ) + num_q_heads: int = field( + default=48, + metadata={"help": "The number of query heads in the model"}, + ) + num_kv_heads: int = field( + default=8, + metadata={"help": "The number of key-value heads in the model"}, + ) + embedding_dim: int = field( + default=6144, + metadata={"help": "The embedding dimension of the model"}, + ) + mlp_hidden_dim: int = field( + default=16384, + metadata={"help": "The hidden dimension of the MLP in the model"}, + ) @staticmethod def get_type(): @@ -107,24 +253,78 @@ def get_type(): @dataclass class 
Phi2ModelConfig(Llama2ModelConfig): - num_layers: int = 32 - num_q_heads: int = 32 - num_kv_heads: int = 32 - embedding_dim: int = 2560 - mlp_hidden_dim: int = 10240 - max_position_embeddings: int = 2048 - use_gated_mlp: bool = False - use_bias: bool = True - use_qkv_bias: bool = True - activation: ActivationType = ActivationType.GELU - norm: NormType = NormType.LAYER_NORM - post_attn_norm: bool = False - vocab_size: int = 51200 - rope_scaling: Optional[Dict[str, Any]] = None - rope_theta: Optional[int] = 10000.0 - partial_rotary_factor: float = 0.4 - no_tensor_parallel: bool = True - is_neox_style: bool = True + num_layers: int = field( + default=32, + metadata={"help": "The number of layers in the model"}, + ) + num_q_heads: int = field( + default=32, + metadata={"help": "The number of query heads in the model"}, + ) + num_kv_heads: int = field( + default=32, + metadata={"help": "The number of key-value heads in the model"}, + ) + embedding_dim: int = field( + default=2560, + metadata={"help": "The embedding dimension of the model"}, + ) + mlp_hidden_dim: int = field( + default=10240, + metadata={"help": "The hidden dimension of the MLP in the model"}, + ) + max_position_embeddings: int = field( + default=2048, + metadata={"help": "The maximum position embeddings in the model"}, + ) + use_gated_mlp: bool = field( + default=False, + metadata={"help": "Whether to use gated MLP in the model"}, + ) + use_bias: bool = field( + default=True, + metadata={"help": "Whether to use bias in the model"}, + ) + use_qkv_bias: bool = field( + default=True, + metadata={"help": "Whether to use bias in the QKV in the model"}, + ) + activation: ActivationType = field( + default=ActivationType.GELU, + metadata={"help": "The activation function in the model"}, + ) + norm: NormType = field( + default=NormType.LAYER_NORM, + metadata={"help": "The normalization function in the model"}, + ) + post_attn_norm: bool = field( + default=False, + metadata={"help": "Whether to use post-attention normalization in the model"}, + ) + vocab_size: int = field( + default=51200, + metadata={"help": "The vocabulary size of the model"}, + ) + rope_scaling: Optional[Dict[str, Any]] = field( + default=None, + metadata={"help": "The rope scaling config for the model"}, + ) + rope_theta: Optional[int] = field( + default=10000.0, + metadata={"help": "The rope theta in the model"}, + ) + partial_rotary_factor: float = field( + default=0.4, + metadata={"help": "The partial rotary factor in the model"}, + ) + no_tensor_parallel: bool = field( + default=True, + metadata={"help": "Whether to use tensor parallelism in the model"}, + ) + is_neox_style: bool = field( + default=True, + metadata={"help": "Whether to use the Neox style in the model"}, + ) @staticmethod def get_type(): @@ -133,18 +333,42 @@ def get_type(): @dataclass class QwenModelConfig(Llama2ModelConfig): - use_qkv_bias: bool = True - max_position_embeddings: int = 32768 - vocab_size: int = 152064 + use_qkv_bias: bool = field( + default=True, + metadata={"help": "Whether to use bias in the QKV in the model"}, + ) + max_position_embeddings: int = field( + default=32768, + metadata={"help": "The maximum position embeddings in the model"}, + ) + vocab_size: int = field( + default=152064, + metadata={"help": "The vocabulary size of the model"}, + ) @dataclass class Qwen72BModelConfig(QwenModelConfig): - num_layers: int = 80 - num_q_heads: int = 64 - num_kv_heads: int = 64 - embedding_dim: int = 8192 - mlp_hidden_dim: int = 24576 + num_layers: int = field( + default=80, + 
metadata={"help": "The number of layers in the model"}, + ) + num_q_heads: int = field( + default=64, + metadata={"help": "The number of query heads in the model"}, + ) + num_kv_heads: int = field( + default=64, + metadata={"help": "The number of key-value heads in the model"}, + ) + embedding_dim: int = field( + default=8192, + metadata={"help": "The embedding dimension of the model"}, + ) + mlp_hidden_dim: int = field( + default=24576, + metadata={"help": "The hidden dimension of the MLP in the model"}, + ) @staticmethod def get_type(): diff --git a/vidur/config/node_sku_config.py b/vidur/config/node_sku_config.py index 412f532..ababc1a 100644 --- a/vidur/config/node_sku_config.py +++ b/vidur/config/node_sku_config.py @@ -1,4 +1,4 @@ -from dataclasses import dataclass +from dataclasses import dataclass, field from vidur.config.base_poly_config import BasePolyConfig from vidur.logger import init_logger @@ -9,13 +9,21 @@ @dataclass class BaseNodeSKUConfig(BasePolyConfig): - num_devices_per_node: int + num_devices_per_node: int = field( + metadata={"help": "The number of devices per node"}, + ) @dataclass class A40PairwiseNvlinkNodeSKUConfig(BaseNodeSKUConfig): - device_sku_type: DeviceSKUType = DeviceSKUType.A40 - num_devices_per_node: int = 8 + device_sku_type: DeviceSKUType = field( + default=DeviceSKUType.A40, + metadata={"help": "The device SKU type"}, + ) + num_devices_per_node: int = field( + default=8, + metadata={"help": "The number of devices per node"}, + ) @staticmethod def get_type(): @@ -24,8 +32,14 @@ def get_type(): @dataclass class A100PairwiseNvlinkNodeSKUConfig(BaseNodeSKUConfig): - device_sku_type: DeviceSKUType = DeviceSKUType.A100 - num_devices_per_node: int = 8 + device_sku_type: DeviceSKUType = field( + default=DeviceSKUType.A100, + metadata={"help": "The device SKU type"}, + ) + num_devices_per_node: int = field( + default=8, + metadata={"help": "The number of devices per node"}, + ) @staticmethod def get_type(): @@ -34,8 +48,14 @@ def get_type(): @dataclass class H100PairwiseNvlinkNodeSKUConfig(BaseNodeSKUConfig): - device_sku_type: DeviceSKUType = DeviceSKUType.H100 - num_devices_per_node: int = 8 + device_sku_type: DeviceSKUType = field( + default=DeviceSKUType.H100, + metadata={"help": "The device SKU type"}, + ) + num_devices_per_node: int = field( + default=8, + metadata={"help": "The number of devices per node"}, + ) @staticmethod def get_type(): @@ -44,8 +64,14 @@ def get_type(): @dataclass class A100DgxNodeSKUConfig(BaseNodeSKUConfig): - device_sku_type: DeviceSKUType = DeviceSKUType.A100 - num_devices_per_node: int = 8 + device_sku_type: DeviceSKUType = field( + default=DeviceSKUType.A100, + metadata={"help": "The device SKU type"}, + ) + num_devices_per_node: int = field( + default=8, + metadata={"help": "The number of devices per node"}, + ) @staticmethod def get_type(): @@ -54,8 +80,14 @@ def get_type(): @dataclass class H100DgxNodeSKUConfig(BaseNodeSKUConfig): - device_sku_type: DeviceSKUType = DeviceSKUType.H100 - num_devices_per_node: int = 8 + device_sku_type: DeviceSKUType = field( + default=DeviceSKUType.H100, + metadata={"help": "The device SKU type"}, + ) + num_devices_per_node: int = field( + default=8, + metadata={"help": "The number of devices per node"}, + ) @staticmethod def get_type(): diff --git a/vidur/entities/cluster.py b/vidur/entities/cluster.py index 34a8f48..0250c5e 100644 --- a/vidur/entities/cluster.py +++ b/vidur/entities/cluster.py @@ -1,6 +1,6 @@ import json -from vidur.config import Config +from vidur.config import 
SimulationConfig
 from vidur.entities.base_entity import BaseEntity
 from vidur.entities.replica import Replica
 from vidur.logger import init_logger
@@ -9,18 +9,18 @@ class Cluster(BaseEntity):
-    def __init__(self, config: Config):
+    def __init__(self, config: SimulationConfig):
         self._id = Cluster.generate_id()
-        self._config = config
+        self._config: SimulationConfig = config
 
         # Init replica object handles
         self._replicas = {}
 
-        for _ in range(self._config.cluster_num_replicas):
+        for _ in range(self._config.cluster_config.num_replicas):
             replica = Replica(config)
             self._replicas[replica.id] = replica
 
-        if self._config.write_json_trace:
+        if self._config.cluster_config.metrics_config.write_json_trace:
             self._write_cluster_info_to_file()
 
     @property
diff --git a/vidur/main.py b/vidur/main.py
index e5def45..18406fb 100644
--- a/vidur/main.py
+++ b/vidur/main.py
@@ -1,10 +1,10 @@
-from vidur.config import Config
+from vidur.config import SimulationConfig
 from vidur.simulator import Simulator
 from vidur.utils.random import set_seeds
 
 
-def main():
-    config = Config()
+def main() -> None:
+    config: SimulationConfig = SimulationConfig.create_from_cli_args()
 
     set_seeds(config.seed)
 
diff --git a/vidur/metrics/metrics_store.py b/vidur/metrics/metrics_store.py
index 66c060e..38a5668 100644
--- a/vidur/metrics/metrics_store.py
+++ b/vidur/metrics/metrics_store.py
@@ -6,7 +6,7 @@
 import plotly_express as px
 import wandb
 
-from vidur.config import Config
+from vidur.config import SimulationConfig, MetricsConfig
 from vidur.entities import Batch, BatchStage, ExecutionTime, Request
 from vidur.logger import init_logger
 from vidur.metrics.cdf_sketch import CDFSketch
@@ -48,37 +48,40 @@ def wrapper(self, *args, **kwargs):
 
 class MetricsStore:
-    def __init__(self, config: Config):
-        self._config = config
-        self._num_replicas = config.cluster_num_replicas
-        self._num_stages = config.replica_num_pipeline_stages
-        self._should_write_metrics = config.write_metrics
-        self._subsamples = config.metrics_store_subsamples
-        self._save_table_to_wandb = config.metrics_store_save_table_to_wandb
-        self._save_plots = config.metrics_store_store_plots
+
+    def __init__(self, config: SimulationConfig):
+        self._config: SimulationConfig = config
+        metrics_config: MetricsConfig = config.cluster_config.metrics_config
+
+        self._num_replicas = config.cluster_config.num_replicas
+        self._num_stages = config.cluster_config.replica_scheduler_config.num_pipeline_stages
+        self._should_write_metrics = metrics_config.write_metrics
+        self._subsamples = metrics_config.subsamples
+        self._save_table_to_wandb = metrics_config.save_table_to_wandb
+        self._save_plots = metrics_config.store_plots
         self._keep_individual_batch_metrics = (
-            config.metrics_store_keep_individual_batch_metrics
+            metrics_config.keep_individual_batch_metrics
         )
-        self._wandb_project = config.metrics_store_wandb_project
-        self._wandb_group = config.metrics_store_wandb_group
-        self._wandb_run_name = config.metrics_store_wandb_run_name
+        self._wandb_project = metrics_config.wandb_project
+        self._wandb_group = metrics_config.wandb_group
+        self._wandb_run_name = metrics_config.wandb_run_name
 
-        self._min_batch_idx = config.metrics_store_min_batch_idx
-        self._max_batch_idx = config.metrics_store_max_batch_idx
+        self._min_batch_idx = metrics_config.min_batch_index
+        self._max_batch_idx = metrics_config.max_batch_index
 
         self._last_request_arrived_at = None
 
         self._should_store_token_completion_metrics = (
-            config.metrics_store_store_token_completion_metrics
+            metrics_config.store_token_completion_metrics
         )
self._should_store_utilization_metrics = ( - config.metrics_store_store_utilization_metrics + metrics_config.store_utilization_metrics ) - self._should_store_batch_metrics = config.metrics_store_store_batch_metrics + self._should_store_batch_metrics = metrics_config.store_batch_metrics self._should_store_operation_metrics = ( - config.metrics_store_store_operation_metrics + metrics_config.store_operation_metrics ) - self._should_store_request_metrics = config.metrics_store_store_request_metrics + self._should_store_request_metrics = metrics_config.store_request_metrics # Initialise request metrics self._request_metrics_time_distributions: Dict[ diff --git a/vidur/request_generator/base_request_generator.py b/vidur/request_generator/base_request_generator.py index e04b34b..912e7a6 100644 --- a/vidur/request_generator/base_request_generator.py +++ b/vidur/request_generator/base_request_generator.py @@ -2,20 +2,14 @@ from abc import ABC, abstractmethod from typing import List -from vidur.config import Config +from vidur.config import BaseRequestGeneratorConfig from vidur.entities import Request class BaseRequestGenerator(ABC): - def __init__(self, config: Config): - self._config = config - self._should_write_json_trace = config.write_json_trace - def _write_requests_to_file(self, requests: List[Request]) -> None: - request_dicts = [request.to_dict() for request in requests] - request_file = f"{self._config.output_dir}/requests.json" - with open(request_file, "w") as f: - json.dump(request_dicts, f) + def __init__(self, config: BaseRequestGeneratorConfig): + self.config = config @abstractmethod def generate_requests(self) -> List[Request]: @@ -23,8 +17,4 @@ def generate_requests(self) -> List[Request]: def generate(self) -> List[Request]: requests = self.generate_requests() - - if self._should_write_json_trace: - self._write_requests_to_file(requests) - return requests diff --git a/vidur/request_generator/base_request_interval_generator.py b/vidur/request_generator/base_request_interval_generator.py index ca8a175..d0370e8 100644 --- a/vidur/request_generator/base_request_interval_generator.py +++ b/vidur/request_generator/base_request_interval_generator.py @@ -1,11 +1,12 @@ from abc import ABC, abstractmethod -from vidur.config import Config +from vidur.config import BaseRequestIntervalGeneratorConfig class BaseRequestIntervalGenerator(ABC): - def __init__(self, config: Config): - self._config = config + + def __init__(self, config: BaseRequestIntervalGeneratorConfig): + self.config = config @abstractmethod def get_next_inter_request_time(self) -> float: diff --git a/vidur/request_generator/base_request_length_generator.py b/vidur/request_generator/base_request_length_generator.py index 3d31028..7162ffc 100644 --- a/vidur/request_generator/base_request_length_generator.py +++ b/vidur/request_generator/base_request_length_generator.py @@ -1,12 +1,13 @@ from abc import ABC, abstractmethod from typing import Tuple -from vidur.config import Config +from vidur.config import BaseRequestLengthGeneratorConfig class BaseRequestLengthGenerator(ABC): - def __init__(self, config: Config): - self._config = config + + def __init__(self, config: BaseRequestLengthGeneratorConfig): + self.config = config @abstractmethod def get_next_num_tokens(self) -> Tuple[float, float]: diff --git a/vidur/request_generator/fixed_request_length_generator.py b/vidur/request_generator/fixed_request_length_generator.py index 4d52a94..05eb11e 100644 --- a/vidur/request_generator/fixed_request_length_generator.py +++ 
b/vidur/request_generator/fixed_request_length_generator.py @@ -6,8 +6,9 @@ class FixedRequestLengthGenerator(BaseRequestLengthGenerator): + def get_next_num_tokens(self) -> Tuple[float, float]: return ( - self._config.fixed_request_generator_prefill_tokens, - self._config.fixed_request_generator_decode_tokens, + self.config.prefill_tokens, + self.config.decode_tokens, ) diff --git a/vidur/request_generator/gamma_request_interval_generator.py b/vidur/request_generator/gamma_request_interval_generator.py index fc02a50..f85ca13 100644 --- a/vidur/request_generator/gamma_request_interval_generator.py +++ b/vidur/request_generator/gamma_request_interval_generator.py @@ -1,18 +1,20 @@ from scipy.stats import gamma +from vidur.config import GammaRequestIntervalGeneratorConfig from vidur.request_generator.base_request_interval_generator import ( BaseRequestIntervalGenerator, ) class GammaRequestIntervalGenerator(BaseRequestIntervalGenerator): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - cv = self._config.gamma_request_interval_generator_cv - self._qps = self._config.gamma_request_interval_generator_qps - self._gamma_shape = 1.0 / (cv**2) + def __init__(self, config: GammaRequestIntervalGeneratorConfig): + super().__init__(config) + + cv = self.config.cv + self.qps = self.config.qps + self.gamma_shape = 1.0 / (cv**2) def get_next_inter_request_time(self) -> float: - gamma_scale = 1.0 / (self._qps * self._gamma_shape) - return gamma.rvs(self._gamma_shape, scale=gamma_scale) + gamma_scale = 1.0 / (self.qps * self.gamma_shape) + return gamma.rvs(self.gamma_shape, scale=gamma_scale) diff --git a/vidur/request_generator/poisson_request_interval_generator.py b/vidur/request_generator/poisson_request_interval_generator.py index 2be7b31..53a067c 100644 --- a/vidur/request_generator/poisson_request_interval_generator.py +++ b/vidur/request_generator/poisson_request_interval_generator.py @@ -1,20 +1,23 @@ import math import random +from vidur.config import PoissonRequestIntervalGeneratorConfig from vidur.request_generator.base_request_interval_generator import ( BaseRequestIntervalGenerator, ) class PoissonRequestIntervalGenerator(BaseRequestIntervalGenerator): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._qps = self._config.poisson_request_interval_generator_qps - self._std = 1.0 / self._qps - self._max_interval = self._std * 3.0 + def __init__(self, config: PoissonRequestIntervalGeneratorConfig): + super().__init__(config) + + self.qps = self.config.qps + self.std = 1.0 / self.qps + self.max_interval = self.std * 3.0 def get_next_inter_request_time(self) -> float: - next_interval = -math.log(1.0 - random.random()) / self._qps - next_interval = min(next_interval, self._max_interval) + next_interval = -math.log(1.0 - random.random()) / self.qps + next_interval = min(next_interval, self.max_interval) + return next_interval diff --git a/vidur/request_generator/request_generator_registry.py b/vidur/request_generator/request_generator_registry.py index 9e79008..44c920f 100644 --- a/vidur/request_generator/request_generator_registry.py +++ b/vidur/request_generator/request_generator_registry.py @@ -9,9 +9,7 @@ class RequestGeneratorRegistry(BaseRegistry): - @classmethod - def get_key_from_str(cls, key_str: str) -> RequestGeneratorType: - return RequestGeneratorType.from_str(key_str) + pass RequestGeneratorRegistry.register( diff --git a/vidur/request_generator/request_interval_generator_registry.py 
b/vidur/request_generator/request_interval_generator_registry.py index 4d1e570..4296161 100644 --- a/vidur/request_generator/request_interval_generator_registry.py +++ b/vidur/request_generator/request_interval_generator_registry.py @@ -15,9 +15,7 @@ class RequestIntervalGeneratorRegistry(BaseRegistry): - @classmethod - def get_key_from_str(cls, key_str: str) -> RequestIntervalGeneratorType: - return RequestIntervalGeneratorType.from_str(key_str) + pass RequestIntervalGeneratorRegistry.register( diff --git a/vidur/request_generator/request_length_generator_registry.py b/vidur/request_generator/request_length_generator_registry.py index 7cdec9a..12775cc 100644 --- a/vidur/request_generator/request_length_generator_registry.py +++ b/vidur/request_generator/request_length_generator_registry.py @@ -15,9 +15,7 @@ class RequestLengthGeneratorRegistry(BaseRegistry): - @classmethod - def get_key_from_str(cls, key_str: str) -> RequestLengthGeneratorType: - return RequestLengthGeneratorType.from_str(key_str) + pass RequestLengthGeneratorRegistry.register( diff --git a/vidur/request_generator/static_request_interval_generator.py b/vidur/request_generator/static_request_interval_generator.py index 87ad49a..57eae72 100644 --- a/vidur/request_generator/static_request_interval_generator.py +++ b/vidur/request_generator/static_request_interval_generator.py @@ -4,5 +4,6 @@ class StaticRequestIntervalGenerator(BaseRequestIntervalGenerator): + def get_next_inter_request_time(self) -> float: return 0 diff --git a/vidur/request_generator/synthetic_request_generator.py b/vidur/request_generator/synthetic_request_generator.py index aa8684c..596f768 100644 --- a/vidur/request_generator/synthetic_request_generator.py +++ b/vidur/request_generator/synthetic_request_generator.py @@ -1,7 +1,10 @@ from typing import List +from vidur.config import SyntheticRequestGeneratorConfig from vidur.entities import Request -from vidur.request_generator.base_request_generator import BaseRequestGenerator +from vidur.request_generator.base_request_generator import ( + BaseRequestGenerator, +) from vidur.request_generator.request_interval_generator_registry import ( RequestIntervalGeneratorRegistry, ) @@ -12,23 +15,22 @@ class SyntheticRequestGenerator(BaseRequestGenerator): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._seed = self._config.seed + def __init__(self, config: SyntheticRequestGeneratorConfig): + super().__init__(config) - self._request_length_generator = RequestLengthGeneratorRegistry.get_from_str( - self._config.synthetic_request_generator_length_provider, self._config + self.request_length_generator = RequestLengthGeneratorRegistry.get( + self.config.length_generator_config.get_type(), + self.config.length_generator_config, ) - self._request_interval_generator = ( - RequestIntervalGeneratorRegistry.get_from_str( - self._config.synthetic_request_generator_interval_provider, self._config - ) + self.request_interval_generator = RequestIntervalGeneratorRegistry.get( + self.config.interval_generator_config.get_type(), + self.config.interval_generator_config, ) def _generate_next_request(self, last_arrived_at: float) -> Request: inter_request_time = ( - self._request_interval_generator.get_next_inter_request_time() + self.request_interval_generator.get_next_inter_request_time() ) if inter_request_time is None: return None @@ -37,7 +39,7 @@ def _generate_next_request(self, last_arrived_at: float) -> Request: ( prefill_tokens, decode_tokens, - ) = 
self._request_length_generator.get_next_num_tokens() + ) = self.request_length_generator.get_next_num_tokens() if prefill_tokens is None or decode_tokens is None: return None @@ -54,18 +56,22 @@ def _generate_requests(self) -> List[Request]: current_time = 0 # first priority is duration - if self._config.synthetic_request_generator_duration is not None: - while current_time < self._config.synthetic_request_generator_duration: + if self.config.duration is not None: + while current_time < self.config.duration: request = self._generate_next_request(current_time) current_time = request.arrived_at requests.append(request) - elif self._config.synthetic_request_generator_num_requests is not None: - for _ in range(self._config.synthetic_request_generator_num_requests): + elif self.config.num_requests is not None: + for _ in range(self.config.num_requests): request = self._generate_next_request(current_time) current_time = request.arrived_at requests.append(request) else: - assert self._config.synthetic_request_generator_interval_provider == "trace" + assert ( + self.config.interval_generator_config.get_type() + == RequestLengthGeneratorRegistry.TRACE + ) + while True: request = self._generate_next_request(current_time) if request is None: @@ -77,24 +83,24 @@ def _generate_requests(self) -> List[Request]: def generate_requests(self) -> List[Request]: assert ( - self._config.synthetic_request_generator_num_requests - or self._config.synthetic_request_generator_duration - or self._config.synthetic_request_generator_interval_provider == "trace" + self.config.num_requests + or self.config.duration + or self.config.interval_generator_config.get_type() + == RequestLengthGeneratorRegistry.TRACE ) - set_seeds(self._seed) + set_seeds(self.config.seed) requests = self._generate_requests() # sort requests by arrival time - requests.sort(key=lambda x: (x.arrived_at, x.id)) + requests.sort(key=lambda x: x.arrived_at) # remove any requests that arrived after the time limit - if self._config.synthetic_request_generator_duration is not None: + if self.config.duration is not None: requests = [ request for request in requests - if request.arrived_at - < self._config.synthetic_request_generator_duration + if request.arrived_at < self.config.duration ] return requests diff --git a/vidur/request_generator/trace_replay_request_generator.py b/vidur/request_generator/trace_replay_request_generator.py index 8283b29..971f2da 100644 --- a/vidur/request_generator/trace_replay_request_generator.py +++ b/vidur/request_generator/trace_replay_request_generator.py @@ -1,12 +1,15 @@ +import logging from typing import List import pandas as pd +from vidur.config import TraceRequestGeneratorConfig from vidur.entities import Request -from vidur.logger import init_logger -from vidur.request_generator.base_request_generator import BaseRequestGenerator +from vidur.request_generator.base_request_generator import ( + BaseRequestGenerator, +) -logger = init_logger(__name__) +logger = logging.getLogger(__name__) class TraceReplayRequestGenerator(BaseRequestGenerator): @@ -15,70 +18,62 @@ class TraceReplayRequestGenerator(BaseRequestGenerator): inter-request times, number of tokens. 
""" - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + def __init__(self, config: TraceRequestGeneratorConfig): + super().__init__(config) - self._trace_file = self._config.trace_request_generator_trace_file # load into a pd dataframe - self._trace_df = pd.read_csv(self._trace_file) + self.trace_df = pd.read_csv(config.trace_file) # restrict trace_df to be a subset of rows that have the same date - self._trace_df = self._trace_df[ - self._trace_df["Date"] == self._config.trace_request_generator_date - ] + self.trace_df = self.trace_df[self.trace_df["Date"] == config.date] # scale prefill and decode tokens - self._trace_df["PromptTokenCount"] = ( - self._trace_df["PromptTokenCount"] - * self._config.trace_request_generator_prefill_scale_factor + self.trace_df["PromptTokenCount"] = ( + self.trace_df["PromptTokenCount"] * config.prefill_scale_factor ) - self._trace_df["CompletionTokenCount"] = ( - self._trace_df["CompletionTokenCount"] - * self._config.trace_request_generator_decode_scale_factor + self.trace_df["CompletionTokenCount"] = ( + self.trace_df["CompletionTokenCount"] * config.decode_scale_factor ) # make sure all the prefill and decode counts are integers - self._trace_df["PromptTokenCount"] = self._trace_df["PromptTokenCount"].astype( + self.trace_df["PromptTokenCount"] = self.trace_df["PromptTokenCount"].astype( int ) - self._trace_df["CompletionTokenCount"] = self._trace_df[ + self.trace_df["CompletionTokenCount"] = self.trace_df[ "CompletionTokenCount" ].astype(int) # make sure that there is at least one prefill and decode token - self._trace_df["PromptTokenCount"] = self._trace_df["PromptTokenCount"].clip( + self.trace_df["PromptTokenCount"] = self.trace_df["PromptTokenCount"].clip( lower=1 ) - self._trace_df["CompletionTokenCount"] = self._trace_df[ + self.trace_df["CompletionTokenCount"] = self.trace_df[ "CompletionTokenCount" ].clip(lower=1) # make sure the total does not exceed the max tokens, adjust the prefill tokens if needed total_tokens = ( - self._trace_df["PromptTokenCount"] + self._trace_df["CompletionTokenCount"] + self.trace_df["PromptTokenCount"] + self.trace_df["CompletionTokenCount"] ) - diff_tokens = total_tokens - self._config.request_generator_max_tokens + diff_tokens = total_tokens - config.max_tokens diff_tokens = diff_tokens.clip(lower=0) - self._trace_df["PromptTokenCount"] = ( - self._trace_df["PromptTokenCount"] - diff_tokens + self.trace_df["PromptTokenCount"] = ( + self.trace_df["PromptTokenCount"] - diff_tokens ) assert all( - self._trace_df["PromptTokenCount"] + self._trace_df["CompletionTokenCount"] - <= self._config.request_generator_max_tokens + self.trace_df["PromptTokenCount"] + self.trace_df["CompletionTokenCount"] + <= config.max_tokens ) # rescale the time to change QPS - self._trace_df["Time"] = ( - self._trace_df["Time"] - * self._config.trace_request_generator_time_scale_factor - ) + self.trace_df["Time"] = self.trace_df["Time"] * config.time_scale_factor # compute pd ratio and log the 25, 50, 75, 90, 95, 99 percentiles pd_ratio = ( - self._trace_df["PromptTokenCount"] / self._trace_df["CompletionTokenCount"] + self.trace_df["PromptTokenCount"] / self.trace_df["CompletionTokenCount"] ) logger.info( - f"Loaded trace file {self._trace_file} with {len(self._trace_df)} requests" + f"Loaded trace file {config.trace_file} with {len(self.trace_df)} requests" ) logger.info( f"Prompt/decode token ratio stats\n:{pd_ratio.describe(percentiles=[0.25, 0.5, 0.75, 0.9, 0.95, 0.99])}" @@ -87,7 +82,7 @@ def __init__(self, 
*args, **kwargs): def generate_requests(self) -> List[Request]: requests = [] - for _, row in self._trace_df.iterrows(): + for _, row in self.trace_df.iterrows(): request = Request( arrived_at=row["Time"], num_prefill_tokens=row["PromptTokenCount"], diff --git a/vidur/request_generator/trace_request_interval_generator.py b/vidur/request_generator/trace_request_interval_generator.py index 59ddf39..f5ad0ff 100644 --- a/vidur/request_generator/trace_request_interval_generator.py +++ b/vidur/request_generator/trace_request_interval_generator.py @@ -1,11 +1,13 @@ +import logging + import pandas as pd -from vidur.logger import init_logger +from vidur.config import TraceRequestIntervalGeneratorConfig from vidur.request_generator.base_request_interval_generator import ( BaseRequestIntervalGenerator, ) -logger = init_logger(__name__) +logger = logging.getLogger(__name__) class TraceRequestIntervalGenerator(BaseRequestIntervalGenerator): @@ -14,52 +16,45 @@ class TraceRequestIntervalGenerator(BaseRequestIntervalGenerator): inter-request times, number of tokens. """ - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + def __init__(self, config: TraceRequestIntervalGeneratorConfig): + super().__init__(config) - trace_file = self._config.trace_request_interval_generator_trace_file # load into a pd dataframe - self._trace_df = pd.read_csv(trace_file) + self.trace_df = pd.read_csv(config.trace_file) - self._trace_df["arrival_time"] = pd.to_datetime(self._trace_df["arrival_time"]) + self.trace_df["arrival_time"] = pd.to_datetime(self.trace_df["arrival_time"]) # restrict trace_df to be a subset of rows that have the same date - self._trace_df = self._trace_df[ - ( - self._trace_df["arrival_time"] - > self._config.trace_request_interval_generator_start_time - ) - & ( - self._trace_df["arrival_time"] - < self._config.trace_request_interval_generator_end_time - ) + self.trace_df = self.trace_df[ + (self.trace_df["arrival_time"] > config.start_time) + & (self.trace_df["arrival_time"] < config.end_time) ] # change back to seconds - self._trace_df["arrival_time"] = ( - self._trace_df["arrival_time"] - self._trace_df["arrival_time"].min() + self.trace_df["arrival_time"] = ( + self.trace_df["arrival_time"] - self.trace_df["arrival_time"].min() ) // pd.Timedelta("1s") # rescale the time to change QPS - self._trace_df["arrival_time"] = ( - self._trace_df["arrival_time"] - * self._config.trace_request_interval_generator_time_scale_factor + self.trace_df["arrival_time"] = ( + self.trace_df["arrival_time"] * config.time_scale_factor ) # compute the inter-request time - self._trace_df["inter_request_time"] = self._trace_df["arrival_time"].diff() + self.trace_df["inter_request_time"] = self.trace_df["arrival_time"].diff() - self._next_request_idx = 1 + self.next_request_idx = 1 logger.info( - f"Loaded interval trace file {trace_file} with {len(self._trace_df)} requests" + f"Loaded interval trace file {config.trace_file} with {len(self.trace_df)} requests" ) def get_next_inter_request_time(self) -> float: - if self._next_request_idx >= len(self._trace_df): + if self.next_request_idx >= len(self.trace_df): return None - inter_request_time = self._trace_df.iloc[self._next_request_idx][ + inter_request_time = self.trace_df.iloc[self.next_request_idx][ "inter_request_time" ] - self._next_request_idx += 1 + self.next_request_idx += 1 + return inter_request_time diff --git a/vidur/request_generator/trace_request_length_generator.py b/vidur/request_generator/trace_request_length_generator.py index 
6931558..e21c6e9 100644 --- a/vidur/request_generator/trace_request_length_generator.py +++ b/vidur/request_generator/trace_request_length_generator.py @@ -1,106 +1,98 @@ +import logging from typing import Tuple import numpy as np import pandas as pd -from vidur.logger import init_logger +from vidur.config import TraceRequestLengthGeneratorConfig from vidur.request_generator.base_request_length_generator import ( BaseRequestLengthGenerator, ) -logger = init_logger(__name__) +logger = logging.getLogger(__name__) class TraceRequestLengthGenerator(BaseRequestLengthGenerator): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - trace_file = self._config.trace_request_length_generator_trace_file - self._trace_df = pd.read_csv(trace_file) + def __init__(self, config: TraceRequestLengthGeneratorConfig): + super().__init__(config) + + self.trace_df = pd.read_csv(config.trace_file) # scale prefill and decode tokens - self._trace_df["num_prefill_tokens"] = ( - self._trace_df["num_prefill_tokens"] - * self._config.trace_request_length_generator_prefill_scale_factor + self.trace_df["num_prefill_tokens"] = ( + self.trace_df["num_prefill_tokens"] * config.prefill_scale_factor ) - self._trace_df["num_decode_tokens"] = ( - self._trace_df["num_decode_tokens"] - * self._config.trace_request_length_generator_decode_scale_factor + self.trace_df["num_decode_tokens"] = ( + self.trace_df["num_decode_tokens"] * config.decode_scale_factor ) # make sure all the prefill and decode counts are integers - self._trace_df["num_prefill_tokens"] = self._trace_df[ + self.trace_df["num_prefill_tokens"] = self.trace_df[ "num_prefill_tokens" ].astype(int) - self._trace_df["num_decode_tokens"] = self._trace_df[ - "num_decode_tokens" - ].astype(int) + self.trace_df["num_decode_tokens"] = self.trace_df["num_decode_tokens"].astype( + int + ) # make sure the total does not exceed the max tokens, adjust the prefill tokens if needed total_tokens = ( - self._trace_df["num_prefill_tokens"] + self._trace_df["num_decode_tokens"] + self.trace_df["num_prefill_tokens"] + self.trace_df["num_decode_tokens"] ) - diff_tokens = total_tokens - self._config.request_generator_max_tokens + diff_tokens = total_tokens - config.max_tokens diff_tokens = diff_tokens.clip(lower=0) - # dedcut the diff tokens from the prefill and decode tokens proportionally - prefill_tokens_ratio = self._trace_df["num_prefill_tokens"] / total_tokens - decode_tokens_ratio = self._trace_df["num_decode_tokens"] / total_tokens + # deduct the diff tokens from the prefill and decode tokens proportionally + prefill_tokens_ratio = self.trace_df["num_prefill_tokens"] / total_tokens + decode_tokens_ratio = self.trace_df["num_decode_tokens"] / total_tokens - self._trace_df["num_prefill_tokens"] -= ( + self.trace_df["num_prefill_tokens"] -= ( np.ceil(diff_tokens * prefill_tokens_ratio) ).astype(int) - self._trace_df["num_decode_tokens"] -= ( + self.trace_df["num_decode_tokens"] -= ( np.ceil(diff_tokens * decode_tokens_ratio) ).astype(int) # make sure that there is at least one prefill and decode token - self._trace_df["num_prefill_tokens"] = self._trace_df[ - "num_prefill_tokens" - ].clip(lower=1) - self._trace_df["num_decode_tokens"] = self._trace_df["num_decode_tokens"].clip( + self.trace_df["num_prefill_tokens"] = self.trace_df["num_prefill_tokens"].clip( + lower=1 + ) + self.trace_df["num_decode_tokens"] = self.trace_df["num_decode_tokens"].clip( lower=1 ) assert all( - self._trace_df["num_prefill_tokens"] + self._trace_df["num_decode_tokens"] - <= 
self._config.request_generator_max_tokens + self.trace_df["num_prefill_tokens"] + self.trace_df["num_decode_tokens"] + <= self.config.max_tokens ) - assert all(self._trace_df["num_prefill_tokens"] > 0) + assert all(self.trace_df["num_prefill_tokens"] > 0) - assert all(self._trace_df["num_decode_tokens"] > 0) + assert all(self.trace_df["num_decode_tokens"] > 0) # compute pd ratio and log the 25, 50, 75, 90, 95, 99 percentiles pd_ratio = ( - self._trace_df["num_prefill_tokens"] / self._trace_df["num_decode_tokens"] + self.trace_df["num_prefill_tokens"] / self.trace_df["num_decode_tokens"] ) - percentiles = [0.25, 0.5, 0.75, 0.9, 0.95, 0.99] - logger.info( - f"Loaded request length trace file {trace_file} with {len(self._trace_df)} requests" - ) - logger.debug( - f"Prompt token stats\n:{self._trace_df['num_prefill_tokens'].describe(percentiles=percentiles)}" - ) - logger.debug( - f"Decode token stats\n:{self._trace_df['num_decode_tokens'].describe(percentiles=percentiles)}" + f"Loaded request length trace file {config.trace_file} with {len(self.trace_df)} requests" ) - logger.debug( - f"Prompt/decode token ratio stats\n:{pd_ratio.describe(percentiles=percentiles)}" + pd_distribution = pd_ratio.describe( + percentiles=[0.25, 0.5, 0.75, 0.9, 0.95, 0.99] ) + logger.debug(f"Prompt/decode token ratio stats\n: {pd_distribution}") # randomly shuffle the df based on the seed - self._trace_df = self._trace_df.sample(frac=1, random_state=self._config.seed) - self._next_request_idx = 0 + self.trace_df = self.trace_df.sample(frac=1, random_state=self.config.seed) + self.next_request_idx = 0 def get_next_num_tokens(self) -> Tuple[float, float]: - if self._next_request_idx >= len(self._trace_df): + if self.next_request_idx >= len(self.trace_df): return None, None - row = self._trace_df.iloc[self._next_request_idx] - self._next_request_idx += 1 + row = self.trace_df.iloc[self.next_request_idx] + self.next_request_idx += 1 return ( row["num_prefill_tokens"], diff --git a/vidur/request_generator/uniform_request_length_generator.py b/vidur/request_generator/uniform_request_length_generator.py index 8ad5368..4f1e6c6 100644 --- a/vidur/request_generator/uniform_request_length_generator.py +++ b/vidur/request_generator/uniform_request_length_generator.py @@ -8,15 +8,15 @@ class UniformRequestLengthGenerator(BaseRequestLengthGenerator): + def get_next_num_tokens(self) -> Tuple[float, float]: total_tokens = random.uniform( - self._config.synthetic_request_generator_min_tokens, - self._config.request_generator_max_tokens, + self.config.min_tokens, + self.config.max_tokens, ) decode_tokens = math.ceil( - total_tokens - / (1 + self._config.synthetic_request_generator_prefill_to_decode_ratio) + total_tokens / (1 + self.config.prefill_to_decode_ratio) ) prefill_tokens = total_tokens - decode_tokens assert prefill_tokens > 0 and decode_tokens > 0 diff --git a/vidur/request_generator/zipf_request_length_generator.py b/vidur/request_generator/zipf_request_length_generator.py index 43a68ba..fc08ccd 100644 --- a/vidur/request_generator/zipf_request_length_generator.py +++ b/vidur/request_generator/zipf_request_length_generator.py @@ -1,5 +1,6 @@ from typing import Tuple +from vidur.config import ZipfRequestLengthGeneratorConfig from vidur.request_generator.base_request_length_generator import ( BaseRequestLengthGenerator, ) @@ -7,23 +8,22 @@ class ZipfRequestLengthGenerator(BaseRequestLengthGenerator): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - self._zipf_generator = ZipfGenerator( - 
self._config.synthetic_request_generator_min_tokens, - self._config.request_generator_max_tokens, - self._config.zipf_request_length_generator_theta, - self._config.zipf_request_length_generator_scramble, - self._config.seed, + + def __init__(self, config: ZipfRequestLengthGeneratorConfig): + super().__init__(config) + + self.zipf_generator = ZipfGenerator( + config.min_tokens, + config.max_tokens, + config.theta, + config.scramble, + config.seed, ) def get_next_num_tokens(self) -> Tuple[float, float]: - total_tokens = self._zipf_generator.next() + total_tokens = self.zipf_generator.next() - decode_tokens = total_tokens / ( - 1 + self._config.synthetic_request_generator_prefill_to_decode_ratio - ) + decode_tokens = total_tokens / (1 + self.config.prefill_to_decode_ratio) prefill_tokens = total_tokens - decode_tokens return prefill_tokens, decode_tokens diff --git a/vidur/scheduler/global_scheduler/base_global_scheduler.py b/vidur/scheduler/global_scheduler/base_global_scheduler.py index be047c9..6e35e21 100644 --- a/vidur/scheduler/global_scheduler/base_global_scheduler.py +++ b/vidur/scheduler/global_scheduler/base_global_scheduler.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from typing import Dict, List, Tuple -from vidur.config import Config +from vidur.config import SimulationConfig from vidur.entities import Replica, Request from vidur.execution_time_predictor import ExecutionTimePredictorRegistry from vidur.scheduler.replica_scheduler.replica_scheduler_registry import ( @@ -10,19 +10,19 @@ class BaseGlobalScheduler(ABC): - def __init__(self, config: Config, replicas: Dict[int, Replica]): + def __init__(self, config: SimulationConfig, replicas: Dict[int, Replica]): self._config = config self._replicas = replicas self._num_replicas = len(self._replicas) - execution_time_predictor = ExecutionTimePredictorRegistry.get_from_str( - self._config.execution_time_predictor_provider, + execution_time_predictor = ExecutionTimePredictorRegistry.get( + self._config.cluster_config.execution_time_predictor_config.get_type(), self._config, ) self._replica_schedulers = { - replica_id: ReplicaSchedulerRegistry.get_from_str( - config.replica_scheduler_provider, + replica_id: ReplicaSchedulerRegistry.get( + config.cluster_config.replica_scheduler_config.get_type(), config, replica, replica.num_pipeline_stages, diff --git a/vidur/scheduler/replica_scheduler/base_replica_scheduler.py b/vidur/scheduler/replica_scheduler/base_replica_scheduler.py index a0b2f6c..c1b40fd 100644 --- a/vidur/scheduler/replica_scheduler/base_replica_scheduler.py +++ b/vidur/scheduler/replica_scheduler/base_replica_scheduler.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from typing import List -from vidur.config import Config +from vidur.config import SimulationConfig from vidur.entities import Batch, Replica, Request from vidur.execution_time_predictor import BaseExecutionTimePredictor from vidur.logger import init_logger @@ -14,7 +14,7 @@ class BaseReplicaScheduler(ABC): def __init__( self, - config: Config, + config: SimulationConfig, replica: Replica, num_stages: int, execution_time_predictor: BaseExecutionTimePredictor, @@ -24,15 +24,15 @@ def __init__( self._num_stages = num_stages # store config variables - self._block_size = self._config.replica_block_size + self._block_size = self._config.cluster_config.replica_scheduler_config.block_size self._max_blocks_per_sequence = ( - self._config.request_generator_max_tokens // self._block_size + self._config.request_generator_config.max_tokens // 
self._block_size ) memory_planner = MemoryPlanner(config, replica) - self._num_total_blocks = config.replica_scheduler_num_blocks + self._num_total_blocks = config.cluster_config.replica_scheduler_config.num_blocks if not self._num_total_blocks: self._num_total_blocks = ( @@ -40,7 +40,7 @@ def __init__( ) self._max_batch_size = min( memory_planner.get_max_batch_size(), - config.replica_scheduler_batch_size_cap, + config.cluster_config.replica_scheduler_config.batch_size_cap, ) logger.debug( diff --git a/vidur/scheduler/replica_scheduler/lightllm_replica_scheduler.py b/vidur/scheduler/replica_scheduler/lightllm_replica_scheduler.py index 526d58c..87db3b8 100644 --- a/vidur/scheduler/replica_scheduler/lightllm_replica_scheduler.py +++ b/vidur/scheduler/replica_scheduler/lightllm_replica_scheduler.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Tuple import numpy as np @@ -14,11 +14,11 @@ def __init__(self, *args, **kwargs): self._preempted_requests: List[Request] = [] self._num_running_batches = 0 - self._max_tokens_in_batch = self._config.lightllm_scheduler_max_tokens_in_batch - self._max_waiting_iters = self._config.lightllm_scheduler_max_waiting_iters - self._max_batch_size = self._config.replica_scheduler_batch_size_cap + self._max_tokens_in_batch = self._config.cluster_config.replica_scheduler_config.max_tokens_in_batch + self._max_waiting_iters = self._config.cluster_config.replica_scheduler_config.max_waiting_iters + self._max_batch_size = self._config.cluster_config.replica_scheduler_config.batch_size_cap self._max_micro_batch_size = ( - self._config.replica_scheduler_batch_size_cap // self._num_stages + self._max_batch_size // self._num_stages ) assert ( self._block_size == 1 @@ -39,7 +39,7 @@ def on_batch_end(self, batch: Batch) -> None: else: self._preempted_requests.append(request) - def _get_tuple_tokens(self, request: Request) -> (int, int): + def _get_tuple_tokens(self, request: Request) -> Tuple[int, int]: if request.scheduled: num_processed_tokens = request.num_processed_tokens remaining_tokens = ( diff --git a/vidur/scheduler/replica_scheduler/replica_scheduler_registry.py b/vidur/scheduler/replica_scheduler/replica_scheduler_registry.py index 85d15fb..6a9eb9b 100644 --- a/vidur/scheduler/replica_scheduler/replica_scheduler_registry.py +++ b/vidur/scheduler/replica_scheduler/replica_scheduler_registry.py @@ -18,9 +18,7 @@ class ReplicaSchedulerRegistry(BaseRegistry): - @classmethod - def get_key_from_str(cls, key_str: str) -> ReplicaSchedulerType: - return ReplicaSchedulerType.from_str(key_str) + pass ReplicaSchedulerRegistry.register( diff --git a/vidur/scheduler/replica_scheduler/sarathi_replica_scheduler.py b/vidur/scheduler/replica_scheduler/sarathi_replica_scheduler.py index 60288ef..274bdb5 100644 --- a/vidur/scheduler/replica_scheduler/sarathi_replica_scheduler.py +++ b/vidur/scheduler/replica_scheduler/sarathi_replica_scheduler.py @@ -13,16 +13,16 @@ def __init__(self, *args, **kwargs): # sarathi config self._num_running_batches = 0 self._preempted_requests = [] - self._chunk_size = self._config.sarathi_scheduler_chunk_size + self._chunk_size = self._config.cluster_config.replica_scheduler_config.chunk_size # vLLM config self._watermark_blocks_fraction = ( - self._config.sarathi_scheduler_watermark_blocks_fraction + self._config.cluster_config.replica_scheduler_config.watermark_blocks_fraction ) # For vLLM and its derivatives, we only need to set a loose max batch size # Memory requirements are handled explicitly by the scheduler - 
self._max_batch_size = self._config.replica_scheduler_batch_size_cap + self._max_batch_size = self._config.cluster_config.replica_scheduler_config.batch_size_cap self._max_micro_batch_size = ( - self._config.replica_scheduler_batch_size_cap // self._num_stages + self._max_batch_size // self._num_stages ) self._watermark_blocks = int( self._watermark_blocks_fraction * self._num_total_blocks diff --git a/vidur/scheduler/replica_scheduler/vllm_replica_scheduler.py b/vidur/scheduler/replica_scheduler/vllm_replica_scheduler.py index c84f33b..5b6f6cc 100644 --- a/vidur/scheduler/replica_scheduler/vllm_replica_scheduler.py +++ b/vidur/scheduler/replica_scheduler/vllm_replica_scheduler.py @@ -14,14 +14,14 @@ def __init__(self, *args, **kwargs): self._preempted_requests: List[Request] = [] self._num_running_batches = 0 self._watermark_blocks_fraction = ( - self._config.vllm_scheduler_watermark_blocks_fraction + self._config.cluster_config.replica_scheduler_config.watermark_blocks_fraction ) - self._max_tokens_in_batch = self._config.vllm_scheduler_max_tokens_in_batch + self._max_tokens_in_batch = self._config.cluster_config.replica_scheduler_config.max_tokens_in_batch # For vLLM and its derivatives, we only need to set a loose max batch size # Memory requirements are handled explicitly by the scheduler - self._max_batch_size = self._config.replica_scheduler_batch_size_cap + self._max_batch_size = self._config.cluster_config.replica_scheduler_config.batch_size_cap self._max_micro_batch_size = ( - self._config.replica_scheduler_batch_size_cap // self._num_stages + self._max_batch_size // self._num_stages ) self._watermark_blocks = int( self._watermark_blocks_fraction * self._num_total_blocks diff --git a/vidur/scheduler/utils/memory_planner.py b/vidur/scheduler/utils/memory_planner.py index 874deaf..ae79b21 100644 --- a/vidur/scheduler/utils/memory_planner.py +++ b/vidur/scheduler/utils/memory_planner.py @@ -1,10 +1,10 @@ -from vidur.config import Config +from vidur.config import SimulationConfig from vidur.entities.replica import Replica from vidur.utils.param_counter import ParamCounter class MemoryPlanner: - def __init__(self, config: Config, replica: Replica) -> None: + def __init__(self, config: SimulationConfig, replica: Replica) -> None: self._param_counter = ParamCounter(config) self._replica = replica diff --git a/vidur/simulator.py b/vidur/simulator.py index 02d77ad..311dd30 100644 --- a/vidur/simulator.py +++ b/vidur/simulator.py @@ -3,7 +3,7 @@ import json from typing import List -from vidur.config import Config +from vidur.config import SimulationConfig from vidur.entities import Cluster from vidur.events import BaseEvent, RequestArrivalEvent from vidur.logger import init_logger @@ -15,30 +15,32 @@ class Simulator: - def __init__(self, config: Config) -> None: - self._config = config + def __init__(self, config: SimulationConfig) -> None: + self._config: SimulationConfig = config self._time = 0 self._terminate = False - self._time_limit = self._config.simulator_time_limit + self._time_limit = self._config.time_limit if not self._time_limit: self._time_limit = float("inf") self._event_queue = [] - self._should_write_json_trace = self._config.write_json_trace - self._should_write_chrome_trace = self._config.write_chrome_trace + self._should_write_json_trace = self._config.cluster_config.metrics_config.write_json_trace + self._should_write_chrome_trace = self._config.cluster_config.metrics_config.enable_chrome_trace self._event_trace = [] self._event_chrome_trace = [] self._cluster = 
Cluster(self._config) self._metric_store = MetricsStore(self._config) - self._request_generator = RequestGeneratorRegistry.get_from_str( - self._config.request_generator_provider, self._config + self._request_generator = RequestGeneratorRegistry.get( + self._config.request_generator_config.get_type(), + self._config.request_generator_config, ) - self._scheduler = GlobalSchedulerRegistry.get_from_str( - self._config.global_scheduler_provider, self._config, self._cluster.replicas + self._scheduler = GlobalSchedulerRegistry.get( + self._config.cluster_config.global_scheduler_config.get_type(), + self._config, self._cluster.replicas ) self._init_event_queue() From 16525b68278a83656f1c409905888f56d2fb93eb Mon Sep 17 00:00:00 2001 From: Anmol Agarwal Date: Thu, 25 Jul 2024 03:42:02 -0400 Subject: [PATCH 03/24] Almost complete, testing needed --- vidur/config/config.py | 94 ++++++++++++++++--- vidur/config/constants.py | 20 ++++ vidur/config/device_sku_config.py | 7 +- vidur/config/flat_dataclass.py | 2 +- vidur/config/model_config.py | 81 ++++++++++++++++ .../config_explorer/config/config.py | 65 +++++++------ vidur/constants.py | 8 -- vidur/entities/replica.py | 37 ++++---- .../base_execution_time_predictor.py | 17 ++-- ...ear_regression_execution_time_predictor.py | 12 ++- ...random_forrest_execution_time_predictor.py | 12 +-- .../sklearn_execution_time_predictor.py | 72 +++++++------- vidur/profiling/common/model_config.py | 12 +-- vidur/types/model_type.py | 8 +- vidur/utils/mfu_calculator.py | 18 ++-- vidur/utils/param_counter.py | 26 ++--- 16 files changed, 338 insertions(+), 153 deletions(-) create mode 100644 vidur/config/constants.py delete mode 100644 vidur/constants.py diff --git a/vidur/config/config.py b/vidur/config/config.py index 074ea6f..04cf0b1 100644 --- a/vidur/config/config.py +++ b/vidur/config/config.py @@ -4,7 +4,12 @@ from typing import Optional, List from vidur.config.base_poly_config import BasePolyConfig +from vidur.config.constants import MODEL_NAME_MAPPING, NETWORK_DEVICE_MAPPING +from vidur.config.device_sku_config import BaseDeviceSKUConfig from vidur.config.flat_dataclass import create_flat_dataclass +from vidur.config.model_config import BaseModelConfig +from vidur.config.node_sku_config import BaseNodeSKUConfig +from vidur.config.utils import get_all_subclasses from vidur.logger import init_logger from vidur.types import ReplicaSchedulerType, GlobalSchedulerType, ExecutionTimePredictorType, RequestGeneratorType, RequestIntervalGeneratorType, RequestLengthGeneratorType @@ -252,7 +257,7 @@ class TraceRequestGeneratorConfig(BaseRequestGeneratorConfig): @staticmethod def get_type(): - return RequestGeneratorType.TRACE + return RequestGeneratorType.TRACE_REPLAY @dataclass @@ -261,10 +266,6 @@ class BaseReplicaSchedulerConfig(BasePolyConfig): default=128, metadata={"help": "Maximum number of sequences."}, ) - num_pipeline_stages: int = field( - default=1, - metadata={"help": "Number of pipeline stages."}, - ) watermark_blocks_fraction: float = field( default=0.01, metadata={"help": "Watermark blocks fraction."}, @@ -282,10 +283,6 @@ class BaseReplicaSchedulerConfig(BasePolyConfig): metadata={"help": "Maximum batch size cap."}, ) - @abstractmethod - def get_max_num_batched_tokens(self): - pass - @dataclass class VllmSchedulerConfig(BaseReplicaSchedulerConfig): @@ -443,9 +440,13 @@ class ReplicaConfig: default=0.8, metadata={"help": "GPU memory utilization."}, ) - pipeline_parallel_size: int = field( + memory_margin_fraction: float = field( + default=0.1, + 
metadata={"help": "Memory margin fraction."}, +    ) +    num_pipeline_stages: int = field( default=1, -        metadata={"help": "Pipeline parallel size."}, +        metadata={"help": "Number of pipeline stages."}, ) tensor_parallel_size: int = field( default=1, @@ -460,8 +461,40 @@ class ReplicaConfig: metadata={"help": "Network device."}, ) +    @staticmethod +    def get_model_config(model_name: str) -> BaseModelConfig: +        model_configs = get_all_subclasses(BaseModelConfig) +        if model_name not in MODEL_NAME_MAPPING: +            raise ValueError(f"Model name not found: {model_name}") +        model_type = MODEL_NAME_MAPPING[model_name] +        for model_config in model_configs: +            if model_config.get_type() == model_type: +                return model_config +        raise ValueError(f"Model config not found for model name: {model_name}") + + +    @staticmethod +    def get_device_config(device_name: str) -> BaseDeviceSKUConfig: +        device_configs = get_all_subclasses(BaseDeviceSKUConfig) +        for device_config in device_configs: +            if str(device_config.get_type()) == device_name: +                return device_config +        raise ValueError(f"Device config not found for device name: {device_name}") + + +    @staticmethod +    def get_node_config(network_device: str) -> BaseNodeSKUConfig: +        node_configs = get_all_subclasses(BaseNodeSKUConfig) +        if network_device not in NETWORK_DEVICE_MAPPING: +            raise ValueError(f"Network device not found: {network_device}") +        network_type = NETWORK_DEVICE_MAPPING[network_device] +        for node_config in node_configs: +            if node_config.get_type() == network_type: +                return node_config +        raise ValueError(f"Node config not found for network device: {network_device}") +     def __post_init__(self): -        self.world_size = self.pipeline_parallel_size * self.tensor_parallel_size +        self.world_size = self.num_pipeline_stages * self.tensor_parallel_size +        self.model_config: BaseModelConfig = self.get_model_config(self.model_name) +        self.device_config: BaseDeviceSKUConfig = self.get_device_config(self.device) +        self.node_config: BaseNodeSKUConfig = self.get_node_config(self.network_device) @dataclass @@ -630,6 +663,43 @@ class ClusterConfig: metadata={"help": "Metrics config."}, ) +    def update_predictor_config(self): +        if "{DEVICE}" in self.execution_time_predictor_config.compute_input_file: +            self.execution_time_predictor_config.compute_input_file = self.execution_time_predictor_config.compute_input_file.replace( +                "{DEVICE}", self.replica_config.device +            ) +        if "{MODEL}" in self.execution_time_predictor_config.compute_input_file: +            self.execution_time_predictor_config.compute_input_file = self.execution_time_predictor_config.compute_input_file.replace( +                "{MODEL}", self.replica_config.model_name +            ) +        if "{DEVICE}" in self.execution_time_predictor_config.attention_input_file: +            self.execution_time_predictor_config.attention_input_file = self.execution_time_predictor_config.attention_input_file.replace( +                "{DEVICE}", self.replica_config.device +            ) +        if "{MODEL}" in self.execution_time_predictor_config.attention_input_file: +            self.execution_time_predictor_config.attention_input_file = self.execution_time_predictor_config.attention_input_file.replace( +                "{MODEL}", self.replica_config.model_name +            ) +        if "{NETWORK_DEVICE}" in self.execution_time_predictor_config.all_reduce_input_file: +            self.execution_time_predictor_config.all_reduce_input_file = self.execution_time_predictor_config.all_reduce_input_file.replace( +                "{NETWORK_DEVICE}", self.replica_config.network_device +            ) +        if "{NETWORK_DEVICE}" in self.execution_time_predictor_config.send_recv_input_file: +            self.execution_time_predictor_config.send_recv_input_file =
self.execution_time_predictor_config.send_recv_input_file.replace( + "{NETWORK_DEVICE}", self.replica_config.network_device + ) + if "{NETWORK_DEVICE}" in self.execution_time_predictor_config.cpu_overhead_input_file: + self.execution_time_predictor_config.cpu_overhead_input_file = self.execution_time_predictor_config.cpu_overhead_input_file.replace( + "{NETWORK_DEVICE}", self.replica_config.network_device + ) + if "{MODEL}" in self.execution_time_predictor_config.cpu_overhead_input_file: + self.execution_time_predictor_config.cpu_overhead_input_file = self.execution_time_predictor_config.cpu_overhead_input_file.replace( + "{MODEL}", self.replica_config.model_name + ) + + def __post_init__(self): + self.update_predictor_config() + @dataclass class SimulationConfig(ABC): diff --git a/vidur/config/constants.py b/vidur/config/constants.py new file mode 100644 index 0000000..f7ea806 --- /dev/null +++ b/vidur/config/constants.py @@ -0,0 +1,20 @@ +from vidur.types import ModelType, NodeSKUType + +MODEL_NAME_MAPPING = { + "Qwen/Qwen-72B": ModelType.QWEN_72B, + "codellama/CodeLlama-34b-Instruct-hf": ModelType.CODE_LLAMA_34B, + "internlm/internlm2-20b": ModelType.INTERNLM_2_20B, + "meta-llama/Llama-2-7b-hf": ModelType.LLAMA_2_7B, + "meta-llama/Llama-2-70b-hf": ModelType.LLAMA_2_70B, + "meta-llama/Meta-Llama-3-8b": ModelType.LLAMA_3_8B, + "meta-llama/Meta-Llama-3-70B": ModelType.LLAMA_3_70B, + "microsoft/phi-2": ModelType.PHI2, +} + +NETWORK_DEVICE_MAPPING = { + "a40_pair_nvlink": NodeSKUType.A40_PAIRWISE_NVLINK, + "a100_pair_nvlink": NodeSKUType.A100_PAIRWISE_NVLINK, + "h100_pair_nvlink": NodeSKUType.H100_PAIRWISE_NVLINK, + "a100_dgx": NodeSKUType.A100_DGX, + "h100_dgx": NodeSKUType.H100_DGX, +} diff --git a/vidur/config/device_sku_config.py b/vidur/config/device_sku_config.py index 73c0cb7..83e1bb1 100644 --- a/vidur/config/device_sku_config.py +++ b/vidur/config/device_sku_config.py @@ -15,9 +15,6 @@ class BaseDeviceSKUConfig(BasePolyConfig): total_memory_gb: int = field( metadata={"help": "The total memory of the device in GB"}, ) - num_devices_per_node: int = field( - metadata={"help": "The number of devices per node"}, - ) @dataclass @@ -33,7 +30,7 @@ class A100DeviceSKUConfig(BaseDeviceSKUConfig): @staticmethod def get_type(): - return DeviceSKUType.A100 + return DeviceSKUType.A40 @dataclass @@ -65,5 +62,5 @@ class H100DeviceSKUConfig(BaseDeviceSKUConfig): @staticmethod def get_type(): - return DeviceSKUType.A100 + return DeviceSKUType.H100 diff --git a/vidur/config/flat_dataclass.py b/vidur/config/flat_dataclass.py index c1ae755..771a262 100644 --- a/vidur/config/flat_dataclass.py +++ b/vidur/config/flat_dataclass.py @@ -181,7 +181,7 @@ def process_dataclass(_input_dataclass, prefix=""): ) elif field_default_factory is not MISSING: meta_fields_with_defaults.append( - (prefixed_name, field_type, field_default_factory()) + (prefixed_name, field_type, field_default_factory) ) else: meta_fields_without_defaults.append((prefixed_name, field_type)) diff --git a/vidur/config/model_config.py b/vidur/config/model_config.py index ce83610..6a0bf10 100644 --- a/vidur/config/model_config.py +++ b/vidur/config/model_config.py @@ -211,6 +211,87 @@ def get_type(): return ModelType.LLAMA_2_70B +@dataclass +class Llama3_8BModelConfig(Llama2ModelConfig): + num_layers: int = field( + default=32, + metadata={"help": "The number of layers in the model"}, + ) + num_q_heads: int = field( + default=32, + metadata={"help": "The number of query heads in the model"}, + ) + num_kv_heads: int = field( + default=8, + 
metadata={"help": "The number of key-value heads in the model"}, + ) + embedding_dim: int = field( + default=4096, + metadata={"help": "The embedding dimension of the model"}, + ) + mlp_hidden_dim: int = field( + default=14336, + metadata={"help": "The hidden dimension of the MLP in the model"}, + ) + max_position_embeddings: int = field( + default=4096, + metadata={"help": "The maximum position embeddings in the model"}, + ) + rope_theta: Optional[int] = field( + default=500000.0, + metadata={"help": "The rope theta in the model"}, + ) + vocab_size: int = field( + default=128256, + metadata={"help": "The vocabulary size of the model"}, + ) + + + @staticmethod + def get_type(): + return ModelType.LLAMA_3_70B + + +@dataclass +class Llama3_70BModelConfig(Llama2ModelConfig): + num_layers: int = field( + default=80, + metadata={"help": "The number of layers in the model"}, + ) + num_q_heads: int = field( + default=64, + metadata={"help": "The number of query heads in the model"}, + ) + num_kv_heads: int = field( + default=8, + metadata={"help": "The number of key-value heads in the model"}, + ) + embedding_dim: int = field( + default=8192, + metadata={"help": "The embedding dimension of the model"}, + ) + mlp_hidden_dim: int = field( + default=28672, + metadata={"help": "The hidden dimension of the MLP in the model"}, + ) + max_position_embeddings: int = field( + default=8192, + metadata={"help": "The maximum position embeddings in the model"}, + ) + rope_theta: Optional[int] = field( + default=500000.0, + metadata={"help": "The rope theta in the model"}, + ) + vocab_size: int = field( + default=128256, + metadata={"help": "The vocabulary size of the model"}, + ) + + + @staticmethod + def get_type(): + return ModelType.LLAMA_3_70B + @dataclass class InternLM2ModelConfig(Llama2ModelConfig): max_position_embeddings: int = field( diff --git a/vidur/config_optimizer/config_explorer/config/config.py b/vidur/config_optimizer/config_explorer/config/config.py index db4da92..f92296b 100644 --- a/vidur/config_optimizer/config_explorer/config/config.py +++ b/vidur/config_optimizer/config_explorer/config/config.py @@ -15,7 +15,7 @@ def get_key(self): def to_config_dict(self): return { - "replica_model_name": self.identifier, + "replica_config_model_name": self.identifier, } def is_tensor_parallel_degree_valid(self, tp_degree: int): @@ -35,15 +35,20 @@ def get_key(self): def to_config_dict(self): return { - "request_generator_provider": "synthetic", - "synthetic_request_generator_length_provider": "trace", - "synthetic_request_generator_interval_provider": "poisson", - "request_generator_max_tokens": self.max_seq_len, - "trace_request_length_generator_trace_file": self.trace_file, - "trace_request_length_generator_prefill_scale_factor": 1, - "trace_request_length_generator_decode_scale_factor": 1, - "synthetic_request_generator_num_requests": self.num_requests, - "vllm_scheduler_max_tokens_in_batch": self.max_seq_len, + "request_generator_config_type": "synthetic", + "length_generator_config_type": "trace", + "interval_generator_config_type": "poisson", + "synthetic_request_generator_config_max_tokens": self.max_seq_len, + "trace_request_length_generator_config_max_tokens": self.max_seq_len, + "zipf_request_length_generator_config_max_tokens": self.max_seq_len, + "uniform_request_length_generator_config_max_tokens": self.max_seq_len, + "fixed_request_length_generator_config_max_tokens": self.max_seq_len, + "trace_request_generator_config_max_tokens": self.max_seq_len, + 
"trace_request_length_generator_config_trace_file": self.trace_file, + "trace_request_length_generator_config_prefill_scale_factor": 1, + "trace_request_length_generator_config_decode_scale_factor": 1, + "synthetic_request_generator_config_num_requests": self.num_requests, + "vllm_scheduler_config_max_tokens_in_batch": self.max_seq_len, } @@ -58,7 +63,7 @@ def get_key(self): def to_config_dict(self): return { - "replica_device": self.device, + "replica_config_device": self.device, } @@ -78,16 +83,14 @@ def get_key(self): def to_config_dict(self): if self.scheduler == "vllm": return { - "replica_scheduler_provider": "vllm", + "replica_scheduler_config_type": "vllm", } assert self.scheduler == "sarathi" assert self.chunk_size is not None return { - "replica_scheduler_provider": "sarathi", - "sarathi_scheduler_chunk_size": self.chunk_size, - "sarathi_scheduler_enable_rolling_prefills": None, - "sarathi_scheduler_prefill_fitting_tolerance": 0.0, + "replica_scheduler_config_type": "sarathi", + "sarathi_scheduler_config_chunk_size": self.chunk_size, } @@ -145,10 +148,14 @@ def to_config_dict(self): **self.trace_config.to_config_dict(), **self.cluster_config.to_config_dict(), **self.scheduler_config.to_config_dict(), - "replica_num_tensor_parallel_workers": self.num_tensor_parallel_workers, - "replica_num_pipeline_stages": self.num_pipeline_stages, - "replica_scheduler_batch_size_cap": self.batch_size, - "cluster_num_replicas": self.num_replicas, + "replica_config_tensor_parallel_size": self.num_tensor_parallel_workers, + "replica_config_num_pipeline_stages": self.num_pipeline_stages, + "vllm_scheduler_config_batch_size_cap": self.batch_size, + "light_l_l_m_scheduler_config_batch_size_cap": self.batch_size, + "orca_scheduler_config_batch_size_cap": self.batch_size, + "faster_transformer_scheduler_config_batch_size_cap": self.batch_size, + "sarathi_scheduler_config_batch_size_cap": self.batch_size, + "cluster_config_num_replicas": self.num_replicas, } @classmethod @@ -234,14 +241,16 @@ def to_config_dict(self): **self.job_config.to_config_dict(), "output_dir": self.get_run_dir(), "cache_dir": self.cache_dir, - "poisson_request_interval_generator_qps": self.qps, - "simulator_time_limit": self.time_limit * 60, # to seconds - "no-metrics_store_save_table_to_wandb": None, - "no-metrics_store_store_plots": None, - "no-metrics_store_store_operation_metrics": None, - "no-metrics_store_store_token_completion_metrics": None, - "no-write_chrome_trace": None, - "sklearn_execution_time_predictor_skip_cpu_overhead_modeling": None, + "poisson_request_interval_generator_config_qps": self.qps, + "gamma_request_interval_generator_config_qps": self.qps, + "time_limit": self.time_limit * 60, # to seconds + "no-metrics_config_save_table_to_wandb": None, + "no-metrics_config_store_plots": None, + "no-metrics_config_store_operation_metrics": None, + "no-metrics_config_store_token_completion_metrics": None, + "no-metrics_config_enable_chrome_trace": None, + "linear_regression_execution_time_predictor_config_skip_cpu_overhead_modeling": None, + "random_forrest_execution_time_predictor_config_skip_cpu_overhead_modeling": None, } def to_args(self): diff --git a/vidur/constants.py b/vidur/constants.py deleted file mode 100644 index e1dfca7..0000000 --- a/vidur/constants.py +++ /dev/null @@ -1,8 +0,0 @@ -import os - -PY_ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) - -DEFAULT_CONFIG_FILE = f"{PY_ROOT_DIR}/config/default.yml" -MODEL_CONFIG_DIR = f"{PY_ROOT_DIR}/../data/model_configs" -DEVICE_CONFIG_DIR = 
f"{PY_ROOT_DIR}/../data/device_configs" -CACHE_DIR = f"{PY_ROOT_DIR}/../.simulator_cache" diff --git a/vidur/entities/replica.py b/vidur/entities/replica.py index d0ec553..2484574 100644 --- a/vidur/entities/replica.py +++ b/vidur/entities/replica.py @@ -1,6 +1,6 @@ from math import ceil -from vidur.config import Config +from vidur.config import SimulationConfig from vidur.entities.base_entity import BaseEntity from vidur.logger import init_logger @@ -8,28 +8,31 @@ class Replica(BaseEntity): - def __init__(self, config: Config) -> None: - assert config.replica_num_layers % config.replica_num_pipeline_stages == 0 + def __init__(self, config: SimulationConfig) -> None: + assert config.cluster_config.replica_config.model_config.num_layers % config.cluster_config.replica_config.num_pipeline_stages == 0 assert ( - config.replica_embedding_dim % config.replica_num_tensor_parallel_workers + config.cluster_config.replica_config.model_config.embedding_dim % config.cluster_config.replica_config.tensor_parallel_size == 0 ) self._id = Replica.generate_id() - self._num_pipeline_stages = config.replica_num_pipeline_stages - self._num_tensor_parallel_workers = config.replica_num_tensor_parallel_workers - self._num_layers = config.replica_num_layers - self._num_q_heads = config.replica_num_q_heads - self._num_kv_heads = config.replica_num_kv_heads - self._embedding_dim = config.replica_embedding_dim - self._mlp_hidden_dim = config.replica_mlp_hidden_dim - self._use_gated_mlp = config.replica_use_gated_mlp - self._vocab_size = config.replica_vocab_size - self._total_memory_gb = config.replica_total_memory_gb - self._memory_margin_fraction = config.replica_memory_margin_fraction - self._max_request_tokens = config.request_generator_max_tokens - self._per_device_flops = config.replica_fp16_tflops * 2**40 + replica_config = config.cluster_config.replica_config + model_config = replica_config.model_config + + self._num_pipeline_stages = replica_config.num_pipeline_stages + self._num_tensor_parallel_workers = config.cluster_config.replica_config.tensor_parallel_size + self._num_layers = model_config.num_layers + self._num_q_heads = model_config.num_q_heads + self._num_kv_heads = model_config.num_kv_heads + self._embedding_dim = model_config.embedding_dim + self._mlp_hidden_dim = model_config.mlp_hidden_dim + self._use_gated_mlp = model_config.use_gated_mlp + self._vocab_size = model_config.vocab_size + self._total_memory_gb = replica_config.device_config.total_memory_gb + self._memory_margin_fraction = replica_config.memory_margin_fraction + self._max_request_tokens = config.request_generator_config.max_tokens + self._per_device_flops = replica_config.device_config.fp16_tflops * 2**40 @property def num_layers(self) -> int: diff --git a/vidur/execution_time_predictor/base_execution_time_predictor.py b/vidur/execution_time_predictor/base_execution_time_predictor.py index 84a21e2..0dfd7c6 100644 --- a/vidur/execution_time_predictor/base_execution_time_predictor.py +++ b/vidur/execution_time_predictor/base_execution_time_predictor.py @@ -1,18 +1,21 @@ from abc import ABC, abstractmethod -from vidur.config import Config +from vidur.config import SimulationConfig from vidur.entities import Batch, ExecutionTime class BaseExecutionTimePredictor(ABC): - def __init__(self, config: Config) -> None: - self._num_tensor_parallel_workers = config.replica_num_tensor_parallel_workers - self._num_pipeline_stages = config.replica_num_pipeline_stages - self._num_layers = config.replica_num_layers + def __init__(self, config: 
SimulationConfig) -> None: + replica_config = config.cluster_config.replica_config + model_config = replica_config.model_config + + self._num_tensor_parallel_workers = replica_config.tensor_parallel_size + self._num_pipeline_stages = replica_config.num_pipeline_stages + self._num_layers = model_config.num_layers self._num_layers_per_pipeline_stage = ( - config.replica_num_layers // config.replica_num_pipeline_stages + model_config.num_layers // replica_config.num_pipeline_stages ) - self._replica_scheduler_provider = config.replica_scheduler_provider + self._replica_scheduler_provider = str(config.cluster_config.replica_scheduler_config.get_type()) def get_execution_time(self, batch: Batch, pipeline_stage: int) -> ExecutionTime: if pipeline_stage == self._num_pipeline_stages - 1: diff --git a/vidur/execution_time_predictor/linear_regression_execution_time_predictor.py b/vidur/execution_time_predictor/linear_regression_execution_time_predictor.py index 6506e1e..6e0c6f7 100644 --- a/vidur/execution_time_predictor/linear_regression_execution_time_predictor.py +++ b/vidur/execution_time_predictor/linear_regression_execution_time_predictor.py @@ -2,24 +2,26 @@ from sklearn.pipeline import make_pipeline from sklearn.preprocessing import PolynomialFeatures +from vidur.config import SimulationConfig from vidur.execution_time_predictor.sklearn_execution_time_predictor import ( SklearnExecutionTimePredictor, ) class LinearRegressionExecutionTimePredictor(SklearnExecutionTimePredictor): - def __init__(self, config): + def __init__(self, config: SimulationConfig): + predictor_config = config.cluster_config.execution_time_predictor_config self._polynomial_degree = ( - config.linear_regression_execution_time_predictor_polynomial_degree + predictor_config.polynomial_degree ) self._polynomial_include_bias = ( - config.linear_regression_execution_time_predictor_polynomial_include_bias + predictor_config.polynomial_include_bias ) self._polynomial_interaction_only = ( - config.linear_regression_execution_time_predictor_polynomial_interaction_only + predictor_config.polynomial_interaction_only ) self._fit_intercept = ( - config.linear_regression_execution_time_predictor_fit_intercept + predictor_config.fit_intercept ) # will trigger model training diff --git a/vidur/execution_time_predictor/random_forrest_execution_time_predictor.py b/vidur/execution_time_predictor/random_forrest_execution_time_predictor.py index 0898a0c..4f8736e 100644 --- a/vidur/execution_time_predictor/random_forrest_execution_time_predictor.py +++ b/vidur/execution_time_predictor/random_forrest_execution_time_predictor.py @@ -1,19 +1,19 @@ from sklearn.ensemble import RandomForestRegressor +from vidur.config import SimulationConfig from vidur.execution_time_predictor.sklearn_execution_time_predictor import ( SklearnExecutionTimePredictor, ) class RandomForrestExecutionTimePredictor(SklearnExecutionTimePredictor): - def __init__(self, config): + def __init__(self, config: SimulationConfig): + predictor_config = config.cluster_config.execution_time_predictor_config self._num_estimators = ( - config.random_forrest_execution_time_predictor_num_estimators - ) - self._max_depth = config.random_forrest_execution_time_predictor_max_depth - self._min_samples_split = ( - config.random_forrest_execution_time_predictor_min_samples_split + predictor_config.num_estimators ) + self._max_depth = predictor_config.max_depth + self._min_samples_split = predictor_config.min_samples_split # will trigger model training super().__init__(config) diff --git 
a/vidur/execution_time_predictor/sklearn_execution_time_predictor.py b/vidur/execution_time_predictor/sklearn_execution_time_predictor.py index 852720b..0c8bb86 100644 --- a/vidur/execution_time_predictor/sklearn_execution_time_predictor.py +++ b/vidur/execution_time_predictor/sklearn_execution_time_predictor.py @@ -12,7 +12,7 @@ from sklearn.metrics import make_scorer from sklearn.model_selection import GridSearchCV -from vidur.config import Config +from vidur.config import SimulationConfig from vidur.entities import Batch from vidur.execution_time_predictor.base_execution_time_predictor import ( BaseExecutionTimePredictor, @@ -23,94 +23,95 @@ class SklearnExecutionTimePredictor(BaseExecutionTimePredictor): - def __init__(self, config: Config) -> None: + def __init__(self, config: SimulationConfig) -> None: super().__init__(config) self._cache_dir = f"{config.cache_dir}/execution_time_predictor" os.makedirs(self._cache_dir, exist_ok=True) - self._no_cache = config.sklearn_execution_time_predictor_no_cache + predictor_config = config.cluster_config.execution_time_predictor_config + model_config = config.cluster_config.replica_config.model_config + + self._no_cache = predictor_config.no_cache self._k_fold_cv_splits = ( - config.sklearn_execution_time_predictor_k_fold_cv_splits + predictor_config.k_fold_cv_splits ) - self._model_name = config.replica_model_name - self._num_q_heads = config.replica_num_q_heads - self._num_kv_heads = config.replica_num_kv_heads - self._embedding_dim = config.replica_embedding_dim - self._mlp_hidden_dim = config.replica_mlp_hidden_dim - self._use_gated_mlp = config.replica_use_gated_mlp - self._vocab_size = config.replica_vocab_size - self._block_size = config.replica_block_size - self._norm = config.replica_norm - self._post_attn_norm = config.replica_post_attn_norm + self._model_name = str(model_config.get_type()) + self._num_q_heads = model_config.num_q_heads + self._num_kv_heads = model_config.num_kv_heads + self._embedding_dim = model_config.embedding_dim + self._mlp_hidden_dim = model_config.mlp_hidden_dim + self._use_gated_mlp = model_config.use_gated_mlp + self._vocab_size = model_config.vocab_size + self._block_size = config.cluster_config.replica_scheduler_config.block_size + self._norm = model_config.norm + self._post_attn_norm = model_config.post_attn_norm - self._model_provider = config.execution_time_predictor_provider + self._model_provider = str(model_config.get_type()) # These overheads are only for GQA models self._attention_prefill_batching_overhead_fraction = ( ( - config.sklearn_execution_time_predictor_attention_prefill_batching_overhead_fraction + predictor_config.attention_prefill_batching_overhead_fraction ) if self._num_q_heads > self._num_kv_heads else 0 ) self._attention_decode_batching_overhead_fraction = ( ( - config.sklearn_execution_time_predictor_attention_decode_batching_overhead_fraction + predictor_config.attention_decode_batching_overhead_fraction ) if self._num_q_heads > self._num_kv_heads else 0 ) self._nccl_cpu_launch_overhead_ms = ( - config.sklearn_execution_time_predictor_nccl_cpu_launch_overhead_ms + predictor_config.nccl_cpu_launch_overhead_ms ) self._nccl_cpu_skew_overhead_per_device_ms = ( - config.sklearn_execution_time_predictor_nccl_cpu_skew_overhead_per_device_ms + predictor_config.nccl_cpu_skew_overhead_per_device_ms ) self._max_batch_size = ( - config.sklearn_execution_time_predictor_prediction_max_batch_size + predictor_config.prediction_max_batch_size ) self._max_tokens_per_request = ( - 
config.sklearn_execution_time_predictor_prediction_max_tokens_per_request + predictor_config.prediction_max_tokens_per_request ) - if config.replica_scheduler_provider == "orca": + if self._replica_scheduler_provider == "orca": self._max_tokens = self._max_tokens_per_request * self._max_batch_size else: self._max_tokens = self._max_tokens_per_request - self._prefill_chunk_size = config.replica_prefill_chunk_size - self._compute_input_file = ( - config.sklearn_execution_time_predictor_compute_input_file + predictor_config.compute_input_file ) self._attention_input_file = ( - config.sklearn_execution_time_predictor_attention_input_file + predictor_config.attention_input_file ) self._all_reduce_input_file = ( - config.sklearn_execution_time_predictor_all_reduce_input_file + predictor_config.all_reduce_input_file ) self._send_recv_input_file = ( - config.sklearn_execution_time_predictor_send_recv_input_file + predictor_config.send_recv_input_file ) self._cpu_overhead_input_file = ( - config.sklearn_execution_time_predictor_cpu_overhead_input_file + predictor_config.cpu_overhead_input_file ) self._kv_cache_prediction_granularity = ( - config.sklearn_execution_time_predictor_kv_cache_prediction_granularity + predictor_config.kv_cache_prediction_granularity ) self._prediction_max_prefill_chunk_size = ( - config.sklearn_execution_time_predictor_prediction_max_prefill_chunk_size + predictor_config.prediction_max_prefill_chunk_size ) - self._device_memory = config.replica_total_memory_gb + self._device_memory = config.cluster_config.replica_config.device_config.total_memory_gb self._num_training_job_threads = ( - config.sklearn_execution_time_predictor_num_training_job_threads + predictor_config.num_training_job_threads ) - devices_per_node = config.replica_num_devices_per_node + devices_per_node = config.cluster_config.replica_config.node_config.num_devices_per_node num_workers = self._num_pipeline_stages * self._num_tensor_parallel_workers assert ( num_workers < devices_per_node or num_workers % devices_per_node == 0 @@ -118,9 +119,8 @@ def __init__(self, config: Config) -> None: self._is_multi_node = num_workers > devices_per_node - self._max_batch_tokens = config.vllm_scheduler_max_tokens_in_batch self._skip_cpu_overhead_modeling = ( - config.sklearn_execution_time_predictor_skip_cpu_overhead_modeling + predictor_config.skip_cpu_overhead_modeling ) self._models = self._train_models() diff --git a/vidur/profiling/common/model_config.py b/vidur/profiling/common/model_config.py index 7a5359f..dcf3542 100644 --- a/vidur/profiling/common/model_config.py +++ b/vidur/profiling/common/model_config.py @@ -1,8 +1,9 @@ from typing import Any, Dict, Optional -import yaml +from dataclasses import asdict -from vidur.constants import MODEL_CONFIG_DIR +from vidur.config.model_config import BaseModelConfig +from vidur.config.utils import get_model_config class ModelConfig: @@ -58,8 +59,7 @@ def __init__( @staticmethod def from_model_name(model_name: str): - model_config_path = f"{MODEL_CONFIG_DIR}/{model_name}.yml" - with open(model_config_path, "r") as f: - model_config = yaml.safe_load(f) + model_config: BaseModelConfig = get_model_config(model_name) + model_config_dict = asdict(model_config) - return ModelConfig(model_name, **model_config) + return ModelConfig(model_name, **model_config_dict) diff --git a/vidur/types/model_type.py b/vidur/types/model_type.py index e5e1bab..aeaf364 100644 --- a/vidur/types/model_type.py +++ b/vidur/types/model_type.py @@ -5,6 +5,8 @@ class ModelType(BaseIntEnum): 
CODE_LLAMA_34B = 0 LLAMA_2_7B = 1 LLAMA_2_70B = 2 - INTERNLM_2_20B = 3 - PHI2 = 4 - QWEN_72B = 5 + LLAMA_3_8B = 3 + LLAMA_3_70B = 4 + INTERNLM_2_20B = 5 + PHI2 = 6 + QWEN_72B = 7 diff --git a/vidur/utils/mfu_calculator.py b/vidur/utils/mfu_calculator.py index 482385d..ea0c511 100644 --- a/vidur/utils/mfu_calculator.py +++ b/vidur/utils/mfu_calculator.py @@ -1,21 +1,25 @@ -from vidur.config import Config +from vidur.config import SimulationConfig from vidur.entities import BatchStage from vidur.utils.param_counter import ParamCounter class MFUCalculator: - def __init__(self, config: Config): + def __init__(self, config: SimulationConfig): param_counter = ParamCounter(config) self._num_params_per_device = param_counter.get_num_parameters_per_device() + + replica_config = config.cluster_config.replica_config + model_config = replica_config.model_config + self._num_layers_per_device = ( - config.replica_num_layers // config.replica_num_pipeline_stages + model_config.num_layers // replica_config.num_pipeline_stages ) - self._embedding_dim = config.replica_embedding_dim + self._embedding_dim = model_config.embedding_dim self._num_heads_per_device = ( - config.replica_num_q_heads // config.replica_num_tensor_parallel_workers + model_config.num_q_heads // replica_config.tensor_parallel_size ) - self._head_dimension = self._embedding_dim // config.replica_num_q_heads - self._device_flops = config.replica_fp16_tflops * 2**40 + self._head_dimension = self._embedding_dim // model_config.num_q_heads + self._device_flops = replica_config.device_config.fp16_tflops * 2**40 def _get_mlp_flops(self, batch_stage: BatchStage) -> float: num_tokens = sum(batch_stage.num_tokens) diff --git a/vidur/utils/param_counter.py b/vidur/utils/param_counter.py index 1ca5b28..fd7552c 100644 --- a/vidur/utils/param_counter.py +++ b/vidur/utils/param_counter.py @@ -1,20 +1,22 @@ from math import ceil -from vidur.config import Config +from vidur.config import SimulationConfig class ParamCounter: - def __init__(self, config: Config) -> None: - self._embedding_dim = config.replica_embedding_dim - self._num_pipeline_stages = config.replica_num_pipeline_stages - self._num_tensor_parallel_workers = config.replica_num_tensor_parallel_workers - self._num_layers = config.replica_num_layers - self._num_q_heads = config.replica_num_q_heads - self._num_kv_heads = config.replica_num_kv_heads - self._embedding_dim = config.replica_embedding_dim - self._mlp_hidden_dim = config.replica_mlp_hidden_dim - self._use_gated_mlp = config.replica_use_gated_mlp - self._vocab_size = config.replica_vocab_size + def __init__(self, config: SimulationConfig) -> None: + replica_config = config.cluster_config.replica_config + model_config = replica_config.model_config + + self._embedding_dim = model_config.embedding_dim + self._num_pipeline_stages = replica_config.num_pipeline_stages + self._num_tensor_parallel_workers = replica_config.tensor_parallel_size + self._num_layers = model_config.num_layers + self._num_q_heads = model_config.num_q_heads + self._num_kv_heads = model_config.num_kv_heads + self._mlp_hidden_dim = model_config.mlp_hidden_dim + self._use_gated_mlp = model_config.use_gated_mlp + self._vocab_size = model_config.vocab_size assert self._num_q_heads % self._num_tensor_parallel_workers == 0 assert self._num_layers % self._num_pipeline_stages == 0 From bb89f23f6feab0b9c4262ccbfe523ef84782aff6 Mon Sep 17 00:00:00 2001 From: Anmol Agarwal Date: Thu, 25 Jul 2024 03:59:41 -0400 Subject: [PATCH 04/24] model config using model names --- 
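This change drops the MODEL_NAME_MAPPING table and instead resolves a model config by matching the HuggingFace-style model name against get_name() on each BaseModelConfig subclass. A rough usage sketch of that lookup follows; the helper name resolve_model_config is illustrative only, and the imports assume the module layout introduced earlier in this series.

from vidur.config.model_config import BaseModelConfig
from vidur.config.utils import get_all_subclasses

def resolve_model_config(model_name: str) -> BaseModelConfig:
    # Scan the registered BaseModelConfig subclasses and match on the model name.
    for model_config_cls in get_all_subclasses(BaseModelConfig):
        if model_config_cls.get_name() == model_name:
            return model_config_cls()  # instantiate with the dataclass defaults
    raise ValueError(f"Model config not found for model name: {model_name}")

# e.g. resolve_model_config("meta-llama/Llama-2-7b-hf") yields a Llama2_7BModelConfig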
vidur/config/config.py | 13 +++++-------- vidur/config/constants.py | 13 +------------ vidur/config/model_config.py | 33 ++++++++++++++++++++++++++++++++- 3 files changed, 38 insertions(+), 21 deletions(-) diff --git a/vidur/config/config.py b/vidur/config/config.py index 04cf0b1..16363c3 100644 --- a/vidur/config/config.py +++ b/vidur/config/config.py @@ -4,7 +4,7 @@ from typing import Optional, List from vidur.config.base_poly_config import BasePolyConfig -from vidur.config.constants import MODEL_NAME_MAPPING, NETWORK_DEVICE_MAPPING +from vidur.config.constants import NETWORK_DEVICE_MAPPING from vidur.config.device_sku_config import BaseDeviceSKUConfig from vidur.config.flat_dataclass import create_flat_dataclass from vidur.config.model_config import BaseModelConfig @@ -463,12 +463,9 @@ class ReplicaConfig: def get_model_config(model_name: str) -> BaseModelConfig: model_configs = get_all_subclasses(BaseModelConfig) -        if model_name not in MODEL_NAME_MAPPING: -            raise ValueError(f"Model name not found: {model_name}") -        model_type = MODEL_NAME_MAPPING[model_name] for model_config in model_configs: -            if model_config.get_type() == model_type: -                return model_config +            if model_config.get_name() == model_name: +                return model_config() raise ValueError(f"Model config not found for model name: {model_name}") @@ -476,7 +473,7 @@ def get_device_config(device_name: str) -> BaseDeviceSKUConfig: device_configs = get_all_subclasses(BaseDeviceSKUConfig) for device_config in device_configs: if str(device_config.get_type()) == device_name: -            return device_config +            return device_config() raise ValueError(f"Device config not found for device name: {device_name}") @@ -487,7 +484,7 @@ def get_node_config(network_device: str) -> BaseNodeSKUConfig: network_type = NETWORK_DEVICE_MAPPING[network_device] for node_config in node_configs: if node_config.get_type() == network_type: -            return node_config +            return node_config() raise ValueError(f"Node config not found for network device: {network_device}") def __post_init__(self): diff --git a/vidur/config/constants.py b/vidur/config/constants.py index f7ea806..6e5bd45 100644 --- a/vidur/config/constants.py +++ b/vidur/config/constants.py @@ -1,15 +1,4 @@ -from vidur.types import ModelType, NodeSKUType - -MODEL_NAME_MAPPING = { -    "Qwen/Qwen-72B": ModelType.QWEN_72B, -    "codellama/CodeLlama-34b-Instruct-hf": ModelType.CODE_LLAMA_34B, -    "internlm/internlm2-20b": ModelType.INTERNLM_2_20B, -    "meta-llama/Llama-2-7b-hf": ModelType.LLAMA_2_7B, -    "meta-llama/Llama-2-70b-hf": ModelType.LLAMA_2_70B, -    "meta-llama/Meta-Llama-3-8b": ModelType.LLAMA_3_8B, -    "meta-llama/Meta-Llama-3-70B": ModelType.LLAMA_3_70B, -    "microsoft/phi-2": ModelType.PHI2, -} +from vidur.types import NodeSKUType NETWORK_DEVICE_MAPPING = { "a40_pair_nvlink": NodeSKUType.A40_PAIRWISE_NVLINK, diff --git a/vidur/config/model_config.py b/vidur/config/model_config.py index 6a0bf10..16da454 100644 --- a/vidur/config/model_config.py +++ b/vidur/config/model_config.py @@ -153,6 +153,10 @@ class CodeLlama34BModelConfig(Llama2ModelConfig): @staticmethod def get_type(): return ModelType.CODE_LLAMA_34B + +    @staticmethod +    def get_name(): +        return "codellama/CodeLlama-34b-Instruct-hf" @dataclass @@ -181,6 +185,10 @@ class Llama2_7BModelConfig(Llama2ModelConfig): @staticmethod def get_type(): return ModelType.LLAMA_2_7B + +    @staticmethod +    def get_name(): +        return "meta-llama/Llama-2-7b-hf" @dataclass @@ -209,6 +217,10 @@ class Llama2_70BModelConfig(Llama2ModelConfig): @staticmethod def get_type(): return ModelType.LLAMA_2_70B
+ + @staticmethod + def get_name(): + return "meta-llama/Llama-2-70b-hf" @dataclass @@ -250,6 +262,10 @@ class Llama3_8BModelConfig(Llama2ModelConfig): @staticmethod def get_type(): return ModelType.LLAMA_3_70B + + @staticmethod + def get_name(): + return "meta-llama/Meta-Llama-3-8b" @dataclass @@ -287,10 +303,13 @@ class Llama3_70BModelConfig(Llama2ModelConfig): metadata={"help": "The vocabulary size of the model"}, ) - @staticmethod def get_type(): return ModelType.LLAMA_3_70B + + @staticmethod + def get_name(): + return "meta-llama/Meta-Llama-3-70B" @dataclass class InternLM2ModelConfig(Llama2ModelConfig): @@ -330,6 +349,10 @@ class InternLM2_20BModelConfig(InternLM2ModelConfig): @staticmethod def get_type(): return ModelType.INTERNLM_2_20B + + @staticmethod + def get_name(): + return "internlm/internlm2-20b" @dataclass @@ -410,6 +433,10 @@ class Phi2ModelConfig(Llama2ModelConfig): @staticmethod def get_type(): return ModelType.PHI2 + + @staticmethod + def get_name(): + return "microsoft/phi-2" @dataclass @@ -454,3 +481,7 @@ class Qwen72BModelConfig(QwenModelConfig): @staticmethod def get_type(): return ModelType.QWEN_72B + + @staticmethod + def get_name(): + return "Qwen/Qwen-72B" From 8ee41e4cde13fae581a6b5e68ba4d0c14f93179e Mon Sep 17 00:00:00 2001 From: Anmol Agarwal Date: Thu, 25 Jul 2024 04:02:15 -0400 Subject: [PATCH 05/24] Remove model type --- vidur/config/model_config.py | 33 --------------------------------- vidur/types/model_type.py | 12 ------------ 2 files changed, 45 deletions(-) delete mode 100644 vidur/types/model_type.py diff --git a/vidur/config/model_config.py b/vidur/config/model_config.py index 16da454..7fe5987 100644 --- a/vidur/config/model_config.py +++ b/vidur/config/model_config.py @@ -149,10 +149,6 @@ class CodeLlama34BModelConfig(Llama2ModelConfig): default=22016, metadata={"help": "The hidden dimension of the MLP in the model"}, ) - - @staticmethod - def get_type(): - return ModelType.CODE_LLAMA_34B @staticmethod def get_name(): @@ -181,10 +177,6 @@ class Llama2_7BModelConfig(Llama2ModelConfig): default=11008, metadata={"help": "The hidden dimension of the MLP in the model"}, ) - - @staticmethod - def get_type(): - return ModelType.LLAMA_2_7B @staticmethod def get_name(): @@ -213,10 +205,6 @@ class Llama2_70BModelConfig(Llama2ModelConfig): default=28672, metadata={"help": "The hidden dimension of the MLP in the model"}, ) - - @staticmethod - def get_type(): - return ModelType.LLAMA_2_70B @staticmethod def get_name(): @@ -258,11 +246,6 @@ class Llama3_8BModelConfig(Llama2ModelConfig): metadata={"help": "The vocabulary size of the model"}, ) - - @staticmethod - def get_type(): - return ModelType.LLAMA_3_70B - @staticmethod def get_name(): return "meta-llama/Meta-Llama-3-8b" @@ -303,10 +286,6 @@ class Llama3_70BModelConfig(Llama2ModelConfig): metadata={"help": "The vocabulary size of the model"}, ) - @staticmethod - def get_type(): - return ModelType.LLAMA_3_70B - @staticmethod def get_name(): return "meta-llama/Meta-Llama-3-70B" @@ -346,10 +325,6 @@ class InternLM2_20BModelConfig(InternLM2ModelConfig): metadata={"help": "The hidden dimension of the MLP in the model"}, ) - @staticmethod - def get_type(): - return ModelType.INTERNLM_2_20B - @staticmethod def get_name(): return "internlm/internlm2-20b" @@ -430,10 +405,6 @@ class Phi2ModelConfig(Llama2ModelConfig): metadata={"help": "Whether to use the Neox style in the model"}, ) - @staticmethod - def get_type(): - return ModelType.PHI2 - @staticmethod def get_name(): return "microsoft/phi-2" @@ -478,10 +449,6 
@@ class Qwen72BModelConfig(QwenModelConfig): metadata={"help": "The hidden dimension of the MLP in the model"}, ) - @staticmethod - def get_type(): - return ModelType.QWEN_72B - @staticmethod def get_name(): return "Qwen/Qwen-72B" diff --git a/vidur/types/model_type.py b/vidur/types/model_type.py deleted file mode 100644 index aeaf364..0000000 --- a/vidur/types/model_type.py +++ /dev/null @@ -1,12 +0,0 @@ -from vidur.types.base_int_enum import BaseIntEnum - - -class ModelType(BaseIntEnum): - CODE_LLAMA_34B = 0 - LLAMA_2_7B = 1 - LLAMA_2_70B = 2 - LLAMA_3_8B = 3 - LLAMA_3_70B = 4 - INTERNLM_2_20B = 5 - PHI2 = 6 - QWEN_72B = 7 From d7ba88f5fe77a699e5111393951361a509070378 Mon Sep 17 00:00:00 2001 From: Anmol Agarwal Date: Thu, 25 Jul 2024 04:03:58 -0400 Subject: [PATCH 06/24] bug fix --- vidur/config/model_config.py | 2 +- vidur/types/__init__.py | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/vidur/config/model_config.py b/vidur/config/model_config.py index 7fe5987..9c1081a 100644 --- a/vidur/config/model_config.py +++ b/vidur/config/model_config.py @@ -3,7 +3,7 @@ from vidur.config.base_poly_config import BasePolyConfig from vidur.logger import init_logger -from vidur.types import NormType, ActivationType, ModelType +from vidur.types import NormType, ActivationType logger = init_logger(__name__) diff --git a/vidur/types/__init__.py b/vidur/types/__init__.py index a432aa9..f21104d 100644 --- a/vidur/types/__init__.py +++ b/vidur/types/__init__.py @@ -10,7 +10,6 @@ from vidur.types.node_sku_type import NodeSKUType from vidur.types.norm_type import NormType from vidur.types.activation_type import ActivationType -from vidur.types.model_type import ModelType __all__ = [ @@ -25,6 +24,5 @@ NodeSKUType, NormType, ActivationType, - ModelType, BaseIntEnum, ] From a9a63af5578ebfcc4552b36243186afdf46351e4 Mon Sep 17 00:00:00 2001 From: Anmol Agarwal Date: Thu, 25 Jul 2024 04:11:41 -0400 Subject: [PATCH 07/24] Lightllm --- vidur/config/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vidur/config/config.py b/vidur/config/config.py index 16363c3..30b76e5 100644 --- a/vidur/config/config.py +++ b/vidur/config/config.py @@ -301,7 +301,7 @@ def get_type(): @dataclass -class LightLLMSchedulerConfig(BaseReplicaSchedulerConfig): +class LightllmSchedulerConfig(BaseReplicaSchedulerConfig): max_batched_tokens: int = field( default=None, metadata={"help": "Maximum batched tokens for LightLLM."}, From 49fe24ef37765be5bec66e39f843ee7049b7716c Mon Sep 17 00:00:00 2001 From: Anmol Agarwal Date: Thu, 25 Jul 2024 04:21:14 -0400 Subject: [PATCH 08/24] Merge branch 'main' into users/amey/new_config --- .../meta-llama/Meta-Llama-3-70B.yml | 16 ---------------- .../model_configs/meta-llama/Meta-Llama-3-8B.yml | 16 ---------------- vidur/config/config.py | 11 ++++++----- 3 files changed, 6 insertions(+), 37 deletions(-) delete mode 100644 data/model_configs/meta-llama/Meta-Llama-3-70B.yml delete mode 100644 data/model_configs/meta-llama/Meta-Llama-3-8B.yml diff --git a/data/model_configs/meta-llama/Meta-Llama-3-70B.yml b/data/model_configs/meta-llama/Meta-Llama-3-70B.yml deleted file mode 100644 index eff626a..0000000 --- a/data/model_configs/meta-llama/Meta-Llama-3-70B.yml +++ /dev/null @@ -1,16 +0,0 @@ -num_layers: 80 -num_q_heads: 64 -num_kv_heads: 8 -embedding_dim: 8192 -mlp_hidden_dim: 28672 -max_position_embeddings: 8192 -use_gated_mlp: true -use_bias: false -use_qkv_bias: false -activation: silu -norm: rms_norm -post_attn_norm: true -rope_theta: 500000.0 -rope_scaling: 
null -vocab_size: 128256 -is_neox_style: true diff --git a/data/model_configs/meta-llama/Meta-Llama-3-8B.yml b/data/model_configs/meta-llama/Meta-Llama-3-8B.yml deleted file mode 100644 index e4bba4c..0000000 --- a/data/model_configs/meta-llama/Meta-Llama-3-8B.yml +++ /dev/null @@ -1,16 +0,0 @@ -num_layers: 32 -num_q_heads: 32 -num_kv_heads: 8 -embedding_dim: 4096 -mlp_hidden_dim: 14336 -max_position_embeddings: 4096 -use_gated_mlp: true -use_bias: false -use_qkv_bias: false -activation: silu -norm: rms_norm -post_attn_norm: true -rope_theta: 500000.0 -rope_scaling: null -vocab_size: 128256 -is_neox_style: true diff --git a/vidur/config/config.py b/vidur/config/config.py index 30b76e5..4bcbeae 100644 --- a/vidur/config/config.py +++ b/vidur/config/config.py @@ -461,6 +461,7 @@ class ReplicaConfig: metadata={"help": "Network device."}, ) + @staticmethod def get_model_config(model_name: str) -> BaseModelConfig: model_configs = get_all_subclasses(BaseModelConfig) for model_config in model_configs: @@ -468,7 +469,7 @@ def get_model_config(model_name: str) -> BaseModelConfig: return model_config() return ValueError(f"Model config not found for model name: {model_name}") - + @staticmethod def get_device_config(device_name: str) -> BaseDeviceSKUConfig: device_configs = get_all_subclasses(BaseDeviceSKUConfig) for device_config in device_configs: @@ -476,7 +477,7 @@ def get_device_config(device_name: str) -> BaseDeviceSKUConfig: return device_config() raise ValueError(f"Device config not found for device name: {device_name}") - + @staticmethod def get_node_config(network_device: str) -> BaseNodeSKUConfig: node_configs = get_all_subclasses(BaseNodeSKUConfig) if network_device not in NETWORK_DEVICE_MAPPING: @@ -489,9 +490,9 @@ def get_node_config(network_device: str) -> BaseNodeSKUConfig: def __post_init__(self): self.world_size = self.num_pipeline_stages * self.tensor_parallel_size - self.model_config: BaseModelConfig = self.get_model_config(self.model_name) - self.device_config: BaseDeviceSKUConfig = self.get_device_config(self.device) - self.node_config: BaseNodeSKUConfig = self.get_node_config(self.network_device) + self.model_config: BaseModelConfig = ReplicaConfig.get_model_config(self.model_name) + self.device_config: BaseDeviceSKUConfig = ReplicaConfig.get_device_config(self.device) + self.node_config: BaseNodeSKUConfig = ReplicaConfig.get_node_config(self.network_device) @dataclass From b3289596f9491330428f10ec65d3f1a1f45fb8ad Mon Sep 17 00:00:00 2001 From: Anmol Agarwal Date: Fri, 26 Jul 2024 20:48:12 -0400 Subject: [PATCH 09/24] Bug fixes and improvements --- vidur/config/config.py | 18 +++++++++++------- vidur/config/flat_dataclass.py | 10 +++++++--- vidur/config/model_config.py | 4 ++++ vidur/config/utils.py | 9 +++++++++ .../sklearn_execution_time_predictor.py | 4 ++-- vidur/metrics/metrics_store.py | 5 +++-- 6 files changed, 36 insertions(+), 14 deletions(-) diff --git a/vidur/config/config.py b/vidur/config/config.py index 4bcbeae..5ad0d0c 100644 --- a/vidur/config/config.py +++ b/vidur/config/config.py @@ -1,6 +1,8 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, field from datetime import datetime +import json +import os from typing import Optional, List from vidur.config.base_poly_config import BasePolyConfig @@ -9,7 +11,7 @@ from vidur.config.flat_dataclass import create_flat_dataclass from vidur.config.model_config import BaseModelConfig from vidur.config.node_sku_config import BaseNodeSKUConfig -from vidur.config.utils import get_all_subclasses +from 
vidur.config.utils import get_all_subclasses, dataclass_to_dict from vidur.logger import init_logger from vidur.types import ReplicaSchedulerType, GlobalSchedulerType, ExecutionTimePredictorType, RequestGeneratorType, RequestIntervalGeneratorType, RequestLengthGeneratorType @@ -214,11 +216,6 @@ class SyntheticRequestGeneratorConfig(BaseRequestGeneratorConfig): default=None, metadata={"help": "Duration of the synthetic request generator."}, ) - max_tokens: int = field( - init=False, - default=4096, - metadata={"help": "Maximum tokens for the synthetic request generator."}, - ) def __post_init__(self): self.max_tokens = self.length_generator_config.max_tokens @@ -564,7 +561,7 @@ class BaseExecutionTimePredictorConfig(BasePolyConfig): metadata={"help": "Max batch size for prediction."}, ) prediction_max_tokens_per_request: int = field( - default=4096, + default=8192, metadata={"help": "Max tokens per request for prediction."}, ) attention_decode_batching_overhead_fraction: float = field( @@ -734,6 +731,8 @@ def __post_init__(self): self.output_dir = ( f"{self.output_dir}/{datetime.now().strftime('%Y-%m-%d_%H-%M-%S-%f')}" ) + os.makedirs(self.output_dir, exist_ok=True) + self.write_config_to_file() @classmethod def create_from_cli_args(cls): @@ -748,3 +747,8 @@ def to_dict(self): return self.__dict__ return self.__flat_config__.__dict__ + + def write_config_to_file(self): + config_dict = dataclass_to_dict(self) + with open(f"{self.output_dir}/config.json", "w") as f: + json.dump(config_dict, f, indent=4) diff --git a/vidur/config/flat_dataclass.py b/vidur/config/flat_dataclass.py index 771a262..c86a6d7 100644 --- a/vidur/config/flat_dataclass.py +++ b/vidur/config/flat_dataclass.py @@ -57,13 +57,17 @@ def reconstruct_original_dataclass(self) -> Any: config_type = getattr(self, f"{original_field_name}_type") # find all subclasses of field_type and check which one matches the config_type for subclass in get_all_subclasses(field_type): - if subclass.get_type() == config_type: + if str(subclass.get_type()) == config_type: args[original_field_name] = instances[subclass] break elif hasattr(field_type, "__dataclass_fields__"): args[original_field_name] = instances[field_type] else: - args[original_field_name] = getattr(self, prefixed_field_name) + value = getattr(self, prefixed_field_name) + if callable(value): + # to handle default factory values + value = value() + args[original_field_name] = value instances[_cls] = _cls(**args) @@ -148,7 +152,7 @@ def process_dataclass(_input_dataclass, prefix=""): ) type_field_name = f"{field.name}_type" - default_value = field.default_factory().get_type() + default_value = str(field.default_factory().get_type()) meta_fields_with_defaults.append( (type_field_name, type(default_value), default_value) ) diff --git a/vidur/config/model_config.py b/vidur/config/model_config.py index 9c1081a..6b1c5ec 100644 --- a/vidur/config/model_config.py +++ b/vidur/config/model_config.py @@ -126,6 +126,10 @@ class Llama2ModelConfig(BaseModelConfig): metadata={"help": "Whether to use tensor parallelism in the model"}, ) + @staticmethod + def get_name(): + return "meta-llama/Llama-2-Config" + @dataclass class CodeLlama34BModelConfig(Llama2ModelConfig): diff --git a/vidur/config/utils.py b/vidur/config/utils.py index e0ed107..81e358d 100644 --- a/vidur/config/utils.py +++ b/vidur/config/utils.py @@ -1,3 +1,4 @@ +from dataclasses import asdict from typing import Union, get_args, get_origin primitive_types = {int, str, float, bool, type(None)} @@ -58,3 +59,11 @@ def 
get_inner_type(field_type: type) -> type: def is_subclass(cls, parent: type) -> bool: return hasattr(cls, "__bases__") and parent in cls.__bases__ + +def dataclass_to_dict(dataclass_instance): + if isinstance(dataclass_instance, list): + return [dataclass_to_dict(item) for item in dataclass_instance] + elif hasattr(dataclass_instance, '__dataclass_fields__'): + return {k: dataclass_to_dict(v) for k, v in asdict(dataclass_instance).items()} + else: + return dataclass_instance diff --git a/vidur/execution_time_predictor/sklearn_execution_time_predictor.py b/vidur/execution_time_predictor/sklearn_execution_time_predictor.py index 0c8bb86..0d19c24 100644 --- a/vidur/execution_time_predictor/sklearn_execution_time_predictor.py +++ b/vidur/execution_time_predictor/sklearn_execution_time_predictor.py @@ -37,7 +37,7 @@ def __init__(self, config: SimulationConfig) -> None: self._k_fold_cv_splits = ( predictor_config.k_fold_cv_splits ) - self._model_name = str(model_config.get_type()) + self._model_name = model_config.get_name() self._num_q_heads = model_config.num_q_heads self._num_kv_heads = model_config.num_kv_heads self._embedding_dim = model_config.embedding_dim @@ -48,7 +48,7 @@ def __init__(self, config: SimulationConfig) -> None: self._norm = model_config.norm self._post_attn_norm = model_config.post_attn_norm - self._model_provider = str(model_config.get_type()) + self._model_provider = str(config.cluster_config.execution_time_predictor_config.get_type()) # These overheads are only for GQA models self._attention_prefill_batching_overhead_fraction = ( diff --git a/vidur/metrics/metrics_store.py b/vidur/metrics/metrics_store.py index 4cd7c72..0add146 100644 --- a/vidur/metrics/metrics_store.py +++ b/vidur/metrics/metrics_store.py @@ -4,6 +4,7 @@ import pandas as pd import plotly_express as px +import wandb from vidur.config import SimulationConfig, MetricsConfig from vidur.entities import Batch, BatchStage, ExecutionTime, Request @@ -50,10 +51,10 @@ class MetricsStore: def __init__(self, config: SimulationConfig): self._config: SimulationConfig = config - metrics_config: MetricsConfig = metrics_config + metrics_config: MetricsConfig = config.cluster_config.metrics_config self._num_replicas = config.cluster_config.num_replicas - self._num_stages = config.cluster_config.replica_scheduler_config.num_pipeline_stages + self._num_stages = config.cluster_config.replica_config.num_pipeline_stages self._should_write_metrics = metrics_config.write_metrics self._subsamples = metrics_config.subsamples self._save_table_to_wandb = metrics_config.save_table_to_wandb From 4317ae9a179a3c76325551e76b83919a6ab7968a Mon Sep 17 00:00:00 2001 From: Anmol Agarwal Date: Fri, 26 Jul 2024 22:00:16 -0400 Subject: [PATCH 10/24] config --- vidur/config/config.py | 20 ++++++++------------ vidur/config/constants.py | 9 --------- 2 files changed, 8 insertions(+), 21 deletions(-) delete mode 100644 vidur/config/constants.py diff --git a/vidur/config/config.py b/vidur/config/config.py index 5ad0d0c..a7dd144 100644 --- a/vidur/config/config.py +++ b/vidur/config/config.py @@ -6,7 +6,6 @@ from typing import Optional, List from vidur.config.base_poly_config import BasePolyConfig -from vidur.config.constants import NETWORK_DEVICE_MAPPING from vidur.config.device_sku_config import BaseDeviceSKUConfig from vidur.config.flat_dataclass import create_flat_dataclass from vidur.config.model_config import BaseModelConfig @@ -65,7 +64,7 @@ def get_type(): @dataclass class 
PoissonRequestIntervalGeneratorConfig(BaseRequestIntervalGeneratorConfig): qps: float = field( - default=1.0, + default=0.5, metadata={"help": "Queries per second for Poisson Request Interval Generator."}, ) @@ -77,7 +76,7 @@ def get_type(): @dataclass class GammaRequestIntervalGeneratorConfig(BaseRequestIntervalGeneratorConfig): qps: float = field( - default=1.0, + default=0.2, metadata={"help": "Queries per second for Gamma Request Interval Generator."}, ) cv: float = field( @@ -172,7 +171,7 @@ def get_type(): @dataclass class FixedRequestLengthGeneratorConfig(BaseRequestLengthGeneratorConfig): prefill_tokens: int = field( - default=4096, + default=2048, metadata={"help": "Prefill tokens for Fixed Request Length Generator."}, ) decode_tokens: int = field( @@ -209,7 +208,7 @@ class SyntheticRequestGeneratorConfig(BaseRequestGeneratorConfig): metadata={"help": "Interval generator config for Synthetic Request Generator."}, ) num_requests: int = field( - default=64, + default=128, metadata={"help": "Number of requests for Synthetic Request Generator."}, ) duration: float = field( @@ -442,7 +441,7 @@ class ReplicaConfig: metadata={"help": "Memory margin fraction."}, ) num_pipeline_stages: int = field( - default=1, + default=4, metadata={"help": "Number of pipeline stages."}, ) tensor_parallel_size: int = field( @@ -454,7 +453,7 @@ class ReplicaConfig: metadata={"help": "Device."}, ) network_device: str = field( - default="a100_pair_nvlink", + default="a100_pairwise_nvlink", metadata={"help": "Network device."}, ) @@ -477,11 +476,8 @@ def get_device_config(device_name: str) -> BaseDeviceSKUConfig: @staticmethod def get_node_config(network_device: str) -> BaseNodeSKUConfig: node_configs = get_all_subclasses(BaseNodeSKUConfig) - if network_device not in NETWORK_DEVICE_MAPPING: - raise ValueError(f"Network device not found: {network_device}") - network_type = NETWORK_DEVICE_MAPPING[network_device] for node_config in node_configs: - if node_config.get_type() == network_type: + if str(node_config.get_type()) == network_device: return node_config() raise ValueError(f"Node config not found for network device: {network_device}") @@ -561,7 +557,7 @@ class BaseExecutionTimePredictorConfig(BasePolyConfig): metadata={"help": "Max batch size for prediction."}, ) prediction_max_tokens_per_request: int = field( - default=8192, + default=4096, metadata={"help": "Max tokens per request for prediction."}, ) attention_decode_batching_overhead_fraction: float = field( diff --git a/vidur/config/constants.py b/vidur/config/constants.py deleted file mode 100644 index 6e5bd45..0000000 --- a/vidur/config/constants.py +++ /dev/null @@ -1,9 +0,0 @@ -from vidur.types import NodeSKUType - -NETWORK_DEVICE_MAPPING = { - "a40_pair_nvlink": NodeSKUType.A40_PAIRWISE_NVLINK, - "a100_pair_nvlink": NodeSKUType.A100_PAIRWISE_NVLINK, - "h100_pair_nvlink": NodeSKUType.H100_PAIRWISE_NVLINK, - "a100_dgx": NodeSKUType.A100_DGX, - "h100_dgx": NodeSKUType.H100_DGX, -} From 34a0f007d972787f36a6639abb3e2be4be5f6f41 Mon Sep 17 00:00:00 2001 From: Anmol Agarwal Date: Fri, 26 Jul 2024 22:34:49 -0400 Subject: [PATCH 11/24] storing config to file --- vidur/config/config.py | 2 +- vidur/config/utils.py | 27 ++++++++++++++++++++------- vidur/simulator.py | 2 +- 3 files changed, 22 insertions(+), 9 deletions(-) diff --git a/vidur/config/config.py b/vidur/config/config.py index a7dd144..66eabc8 100644 --- a/vidur/config/config.py +++ b/vidur/config/config.py @@ -1,4 +1,4 @@ -from abc import ABC, abstractmethod +from abc import ABC from dataclasses 
import dataclass, field from datetime import datetime import json diff --git a/vidur/config/utils.py b/vidur/config/utils.py index 81e358d..e206893 100644 --- a/vidur/config/utils.py +++ b/vidur/config/utils.py @@ -1,4 +1,4 @@ -from dataclasses import asdict +from dataclasses import is_dataclass, fields from typing import Union, get_args, get_origin primitive_types = {int, str, float, bool, type(None)} @@ -60,10 +60,23 @@ def get_inner_type(field_type: type) -> type: def is_subclass(cls, parent: type) -> bool: return hasattr(cls, "__bases__") and parent in cls.__bases__ -def dataclass_to_dict(dataclass_instance): - if isinstance(dataclass_instance, list): - return [dataclass_to_dict(item) for item in dataclass_instance] - elif hasattr(dataclass_instance, '__dataclass_fields__'): - return {k: dataclass_to_dict(v) for k, v in asdict(dataclass_instance).items()} +def dataclass_to_dict(obj): + if isinstance(obj, list): + return [dataclass_to_dict(item) for item in obj] + elif is_dataclass(obj): + data = {} + for field in fields(obj): + value = getattr(obj, field.name) + data[field.name] = dataclass_to_dict(value) + # Include members created in __post_init__ + for key, value in obj.__dict__.items(): + if key not in data: + data[key] = dataclass_to_dict(value) + # Include the name of the class + if hasattr(obj, 'get_type') and callable(getattr(obj, 'get_type')): + data['name'] = str(obj.get_type()) + elif hasattr(obj, 'get_name') and callable(getattr(obj, 'get_name')): + data['name'] = obj.get_name() + return data else: - return dataclass_instance + return obj diff --git a/vidur/simulator.py b/vidur/simulator.py index 311dd30..a1f718c 100644 --- a/vidur/simulator.py +++ b/vidur/simulator.py @@ -56,7 +56,7 @@ def metric_store(self) -> MetricsStore: def run(self) -> None: logger.info( - f"Starting simulation with cluster: {self._cluster} and {len(self._event_queue) - 1} requests" + f"Starting simulation with cluster: {self._cluster} and {len(self._event_queue)} requests" ) while self._event_queue and not self._terminate: From 894e8624d5464ee16a72e0504a4048160bb09d30 Mon Sep 17 00:00:00 2001 From: Anmol Agarwal Date: Sat, 27 Jul 2024 15:19:11 -0400 Subject: [PATCH 12/24] Narrowing configs, removing class properties --- vidur/config/base_fixed_config.py | 30 ++ vidur/config/config.py | 131 ++++----- vidur/config/device_sku_config.py | 4 +- vidur/config/model_config.py | 4 +- vidur/config/node_sku_config.py | 4 +- .../config_explorer/config/config.py | 2 +- vidur/entities/cluster.py | 12 +- vidur/entities/replica.py | 90 +++--- .../base_execution_time_predictor.py | 23 +- ...ear_regression_execution_time_predictor.py | 29 +- ...random_forrest_execution_time_predictor.py | 20 +- .../sklearn_execution_time_predictor.py | 274 ++++++++---------- vidur/metrics/metrics_store.py | 179 +++++------- .../global_scheduler/base_global_scheduler.py | 9 +- .../base_replica_scheduler.py | 31 +- .../lightllm_replica_scheduler.py | 15 +- .../sarathi_replica_scheduler.py | 24 +- .../vllm_replica_scheduler.py | 23 +- vidur/scheduler/utils/memory_planner.py | 6 +- vidur/simulator.py | 18 +- vidur/types/norm_type.py | 1 - vidur/utils/mfu_calculator.py | 11 +- vidur/utils/param_counter.py | 52 ++-- 23 files changed, 437 insertions(+), 555 deletions(-) create mode 100644 vidur/config/base_fixed_config.py diff --git a/vidur/config/base_fixed_config.py b/vidur/config/base_fixed_config.py new file mode 100644 index 0000000..1ab355f --- /dev/null +++ b/vidur/config/base_fixed_config.py @@ -0,0 +1,30 @@ +from abc import ABC 
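# ---------------------------------------------------------------------------
# Editor's sketch (not part of the patch): the dataclass_to_dict rewrite in
# PATCH 11 above walks dataclass fields explicitly so that attributes created
# in __post_init__ (e.g. output_dir) also land in the dumped config.json, and
# records a "name" from get_type()/get_name() when available. A minimal,
# self-contained illustration; DemoConfig is a stand-in, not a real vidur
# class, and only the get_name() branch is shown here.
import json
from dataclasses import dataclass, fields, is_dataclass


def dataclass_to_dict(obj):
    if isinstance(obj, list):
        return [dataclass_to_dict(item) for item in obj]
    if is_dataclass(obj):
        data = {f.name: dataclass_to_dict(getattr(obj, f.name)) for f in fields(obj)}
        # pick up members that only exist after __post_init__
        for key, value in obj.__dict__.items():
            if key not in data:
                data[key] = dataclass_to_dict(value)
        if hasattr(obj, "get_name") and callable(obj.get_name):
            data["name"] = obj.get_name()
        return data
    return obj


@dataclass
class DemoConfig:
    qps: float = 0.5

    def __post_init__(self):
        self.output_dir = "simulator_output"  # not a declared field

    @staticmethod
    def get_name() -> str:
        return "demo"


print(json.dumps(dataclass_to_dict(DemoConfig()), indent=4))
# -> {"qps": 0.5, "output_dir": "simulator_output", "name": "demo"}
# ---------------------------------------------------------------------------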
+from dataclasses import dataclass +from typing import Any + +from vidur.config.utils import get_all_subclasses + + +@dataclass +class BaseFixedConfig(ABC): + + @classmethod + def create_from_type(cls, type_: Any) -> Any: + for subclass in get_all_subclasses(cls): + if subclass.get_type() == type_: + return subclass() + raise ValueError(f"[{cls.__name__}] Invalid type: {type_}") + + @classmethod + def create_from_name(cls, name: str) -> Any: + for subclass in get_all_subclasses(cls): + if subclass.get_name() == name: + return subclass() + raise ValueError(f"[{cls.__name__}] Invalid name: {name}") + + @classmethod + def create_from_type_string(cls, type_str: str) -> Any: + for subclass in get_all_subclasses(cls): + if str(subclass.get_type()) == type_str: + return subclass() + raise ValueError(f"[{cls.__name__}] Invalid type string: {type_str}") diff --git a/vidur/config/config.py b/vidur/config/config.py index 66eabc8..6936bc0 100644 --- a/vidur/config/config.py +++ b/vidur/config/config.py @@ -10,7 +10,7 @@ from vidur.config.flat_dataclass import create_flat_dataclass from vidur.config.model_config import BaseModelConfig from vidur.config.node_sku_config import BaseNodeSKUConfig -from vidur.config.utils import get_all_subclasses, dataclass_to_dict +from vidur.config.utils import dataclass_to_dict from vidur.logger import init_logger from vidur.types import ReplicaSchedulerType, GlobalSchedulerType, ExecutionTimePredictorType, RequestGeneratorType, RequestIntervalGeneratorType, RequestLengthGeneratorType @@ -425,6 +425,11 @@ class MetricsConfig: metadata={"help": "Maximum batch index."}, ) + def __post_init__(self): + self.output_dir = None + self.num_replicas = None + self.num_pipeline_stages = None + @dataclass class ReplicaConfig: @@ -457,35 +462,13 @@ class ReplicaConfig: metadata={"help": "Network device."}, ) - @staticmethod - def get_model_config(model_name: str) -> BaseModelConfig: - model_configs = get_all_subclasses(BaseModelConfig) - for model_config in model_configs: - if model_config.get_name() == model_name: - return model_config() - return ValueError(f"Model config not found for model name: {model_name}") - - @staticmethod - def get_device_config(device_name: str) -> BaseDeviceSKUConfig: - device_configs = get_all_subclasses(BaseDeviceSKUConfig) - for device_config in device_configs: - if str(device_config.get_type()) == device_name: - return device_config() - raise ValueError(f"Device config not found for device name: {device_name}") - - @staticmethod - def get_node_config(network_device: str) -> BaseNodeSKUConfig: - node_configs = get_all_subclasses(BaseNodeSKUConfig) - for node_config in node_configs: - if str(node_config.get_type()) == network_device: - return node_config() - raise ValueError(f"Node config not found for network device: {network_device}") - def __post_init__(self): self.world_size = self.num_pipeline_stages * self.tensor_parallel_size - self.model_config: BaseModelConfig = ReplicaConfig.get_model_config(self.model_name) - self.device_config: BaseDeviceSKUConfig = ReplicaConfig.get_device_config(self.device) - self.node_config: BaseNodeSKUConfig = ReplicaConfig.get_node_config(self.network_device) + self.model_config: BaseModelConfig = BaseModelConfig.create_from_name(self.model_name) + self.device_config: BaseDeviceSKUConfig = BaseDeviceSKUConfig.create_from_type_string(self.device) + self.node_config: BaseNodeSKUConfig = BaseNodeSKUConfig.create_from_type_string(self.network_device) + + self.max_tokens = None @dataclass @@ -585,6 +568,18 @@ class 
BaseExecutionTimePredictorConfig(BasePolyConfig): metadata={"help": "Whether to skip CPU overhead modeling."}, ) + def __post_init__(self): + self.num_tensor_parallel_workers = None + self.num_pipeline_stages = None + self.num_layers_per_pipeline_stage = None + self.replica_scheduler_provider = None + self.cache_dir = None + self.block_size = None + self.total_memory_gb = None + self.devices_per_node = None + self.device = None + self.network_device = None + @dataclass class LinearRegressionExecutionTimePredictorConfig(BaseExecutionTimePredictorConfig): @@ -637,10 +632,6 @@ class ClusterConfig: metadata={"help": "Number of replicas."}, ) replica_config: ReplicaConfig = field(default_factory=ReplicaConfig) - execution_time_predictor_config: BaseExecutionTimePredictorConfig = field( - default_factory=RandomForrestExecutionTimePredictorConfig, - metadata={"help": "Execution time predictor config."}, - ) global_scheduler_config: BaseGlobalSchedulerConfig = field( default_factory=RoundRobinGlobalSchedulerConfig, metadata={"help": "Global scheduler config."}, @@ -649,47 +640,10 @@ class ClusterConfig: default_factory=SarathiSchedulerConfig, metadata={"help": "Replica scheduler config."}, ) - metrics_config: MetricsConfig = field( - default_factory=MetricsConfig, - metadata={"help": "Metrics config."}, - ) - - def update_predictor_config(self): - if "{DEVICE}" in self.execution_time_predictor_config.compute_input_file: - self.execution_time_predictor_config.compute_input_file = self.execution_time_predictor_config.compute_input_file.replace( - "{DEVICE}", self.replica_config.device - ) - if "{MODEL}" in self.execution_time_predictor_config.compute_input_file: - self.execution_time_predictor_config.compute_input_file = self.execution_time_predictor_config.compute_input_file.replace( - "{MODEL}", self.replica_config.model_name - ) - if "{DEVICE}" in self.execution_time_predictor_config.attention_input_file: - self.execution_time_predictor_config.attention_input_file = self.execution_time_predictor_config.attention_input_file.replace( - "{DEVICE}", self.replica_config.device - ) - if "{MODEL}" in self.execution_time_predictor_config.attention_input_file: - self.execution_time_predictor_config.attention_input_file = self.execution_time_predictor_config.attention_input_file.replace( - "{MODEL}", self.replica_config.model_name - ) - if "{NETWORK_DEVICE}" in self.execution_time_predictor_config.all_reduce_input_file: - self.execution_time_predictor_config.all_reduce_input_file = self.execution_time_predictor_config.all_reduce_input_file.replace( - "{NETWORK_DEVICE}", self.replica_config.network_device - ) - if "{NETWORK_DEVICE}" in self.execution_time_predictor_config.send_recv_input_file: - self.execution_time_predictor_config.send_recv_input_file = self.execution_time_predictor_config.send_recv_input_file.replace( - "{NETWORK_DEVICE}", self.replica_config.network_device - ) - if "{NETWORK_DEVICE}" in self.execution_time_predictor_config.cpu_overhead_input_file: - self.execution_time_predictor_config.cpu_overhead_input_file = self.execution_time_predictor_config.cpu_overhead_input_file.replace( - "{NETWORK_DEVICE}", self.replica_config.network_device - ) - if "{MODEL}" in self.execution_time_predictor_config.cpu_overhead_input_file: - self.execution_time_predictor_config.cpu_overhead_input_file = self.execution_time_predictor_config.cpu_overhead_input_file.replace( - "{MODEL}", self.replica_config.model_name - ) def __post_init__(self): - self.update_predictor_config() + self.output_dir = None + 
self.write_json_trace = None @dataclass @@ -722,6 +676,14 @@ class SimulationConfig(ABC): default_factory=SyntheticRequestGeneratorConfig, metadata={"help": "Request generator config."}, ) + execution_time_predictor_config: BaseExecutionTimePredictorConfig = field( + default_factory=RandomForrestExecutionTimePredictorConfig, + metadata={"help": "Execution time predictor config."}, + ) + metrics_config: MetricsConfig = field( + default_factory=MetricsConfig, + metadata={"help": "Metrics config."}, + ) def __post_init__(self): self.output_dir = ( @@ -730,6 +692,11 @@ def __post_init__(self): os.makedirs(self.output_dir, exist_ok=True) self.write_config_to_file() + # Update the config + self.update_cluster_config() + self.update_metrics_config() + self.update_predictor_config() + @classmethod def create_from_cli_args(cls): flat_config = create_flat_dataclass(cls).create_from_cli_args() @@ -748,3 +715,25 @@ def write_config_to_file(self): config_dict = dataclass_to_dict(self) with open(f"{self.output_dir}/config.json", "w") as f: json.dump(config_dict, f, indent=4) + + def update_cluster_config(self): + self.cluster_config.output_dir = self.output_dir + self.cluster_config.replica_config.max_tokens = self.request_generator_config.max_tokens + self.cluster_config.write_json_trace = self.metrics_config.write_json_trace + + def update_metrics_config(self): + self.metrics_config.output_dir = self.output_dir + self.metrics_config.num_replicas = self.cluster_config.num_replicas + self.metrics_config.num_pipeline_stages = self.cluster_config.replica_config.num_pipeline_stages + + def update_predictor_config(self): + self.execution_time_predictor_config.num_tensor_parallel_workers = self.cluster_config.replica_config.tensor_parallel_size + self.execution_time_predictor_config.num_pipeline_stages = self.cluster_config.replica_config.num_pipeline_stages + self.execution_time_predictor_config.num_layers_per_pipeline_stage = self.cluster_config.replica_config.model_config.num_layers // self.cluster_config.replica_config.num_pipeline_stages + self.execution_time_predictor_config.replica_scheduler_provider = str(self.cluster_config.replica_scheduler_config.get_type()) + self.execution_time_predictor_config.cache_dir = f"{self.cache_dir}/execution_time_predictor" + self.execution_time_predictor_config.block_size = self.cluster_config.replica_scheduler_config.block_size + self.execution_time_predictor_config.total_memory_gb = self.cluster_config.replica_config.device_config.total_memory_gb + self.execution_time_predictor_config.devices_per_node = self.cluster_config.replica_config.node_config.num_devices_per_node + self.execution_time_predictor_config.device = self.cluster_config.replica_config.device + self.execution_time_predictor_config.network_device = self.cluster_config.replica_config.network_device diff --git a/vidur/config/device_sku_config.py b/vidur/config/device_sku_config.py index 83e1bb1..415e49d 100644 --- a/vidur/config/device_sku_config.py +++ b/vidur/config/device_sku_config.py @@ -1,6 +1,6 @@ from dataclasses import dataclass, field -from vidur.config.base_poly_config import BasePolyConfig +from vidur.config.base_fixed_config import BaseFixedConfig from vidur.logger import init_logger from vidur.types import DeviceSKUType @@ -8,7 +8,7 @@ @dataclass -class BaseDeviceSKUConfig(BasePolyConfig): +class BaseDeviceSKUConfig(BaseFixedConfig): fp16_tflops: int = field( metadata={"help": "The number of TFLOPS the device can achieve in FP16"}, ) diff --git a/vidur/config/model_config.py 
b/vidur/config/model_config.py index 6b1c5ec..c866a47 100644 --- a/vidur/config/model_config.py +++ b/vidur/config/model_config.py @@ -1,7 +1,7 @@ from dataclasses import dataclass, field from typing import Any, Dict, Optional -from vidur.config.base_poly_config import BasePolyConfig +from vidur.config.base_fixed_config import BaseFixedConfig from vidur.logger import init_logger from vidur.types import NormType, ActivationType @@ -9,7 +9,7 @@ @dataclass -class BaseModelConfig(BasePolyConfig): +class BaseModelConfig(BaseFixedConfig): num_layers: int = field( metadata={"help": "The number of layers in the model"}, ) diff --git a/vidur/config/node_sku_config.py b/vidur/config/node_sku_config.py index ababc1a..a3324a4 100644 --- a/vidur/config/node_sku_config.py +++ b/vidur/config/node_sku_config.py @@ -1,6 +1,6 @@ from dataclasses import dataclass, field -from vidur.config.base_poly_config import BasePolyConfig +from vidur.config.base_fixed_config import BaseFixedConfig from vidur.logger import init_logger from vidur.types import NodeSKUType, DeviceSKUType @@ -8,7 +8,7 @@ @dataclass -class BaseNodeSKUConfig(BasePolyConfig): +class BaseNodeSKUConfig(BaseFixedConfig): num_devices_per_node: int = field( metadata={"help": "The number of devices per node"}, ) diff --git a/vidur/config_optimizer/config_explorer/config/config.py b/vidur/config_optimizer/config_explorer/config/config.py index f92296b..a9dc263 100644 --- a/vidur/config_optimizer/config_explorer/config/config.py +++ b/vidur/config_optimizer/config_explorer/config/config.py @@ -151,7 +151,7 @@ def to_config_dict(self): "replica_config_tensor_parallel_size": self.num_tensor_parallel_workers, "replica_config_num_pipeline_stages": self.num_pipeline_stages, "vllm_scheduler_config_batch_size_cap": self.batch_size, - "light_l_l_m_scheduler_config_batch_size_cap": self.batch_size, + "lightllm_scheduler_config_batch_size_cap": self.batch_size, "orca_scheduler_config_batch_size_cap": self.batch_size, "faster_transformer_scheduler_config_batch_size_cap": self.batch_size, "sarathi_scheduler_config_batch_size_cap": self.batch_size, diff --git a/vidur/entities/cluster.py b/vidur/entities/cluster.py index 0250c5e..9d4f836 100644 --- a/vidur/entities/cluster.py +++ b/vidur/entities/cluster.py @@ -1,6 +1,6 @@ import json -from vidur.config import SimulationConfig +from vidur.config import ClusterConfig from vidur.entities.base_entity import BaseEntity from vidur.entities.replica import Replica from vidur.logger import init_logger @@ -9,18 +9,18 @@ class Cluster(BaseEntity): - def __init__(self, config: SimulationConfig): + def __init__(self, cluster_config: ClusterConfig) -> None: self._id = Cluster.generate_id() - self._config: SimulationConfig = config + self._config = cluster_config # Init replica object handles self._replicas = {} - for _ in range(self._config.cluster_config.num_replicas): - replica = Replica(config) + for _ in range(self._config.num_replicas): + replica = Replica(self._config.replica_config) self._replicas[replica.id] = replica - if self._config.cluster_config.metrics_config.write_json_trace: + if self._config.write_json_trace: self._write_cluster_info_to_file() @property diff --git a/vidur/entities/replica.py b/vidur/entities/replica.py index 2484574..66b224a 100644 --- a/vidur/entities/replica.py +++ b/vidur/entities/replica.py @@ -1,6 +1,6 @@ from math import ceil -from vidur.config import SimulationConfig +from vidur.config import ReplicaConfig from vidur.entities.base_entity import BaseEntity from vidur.logger import 
init_logger @@ -8,110 +8,100 @@ class Replica(BaseEntity): - def __init__(self, config: SimulationConfig) -> None: - assert config.cluster_config.replica_config.model_config.num_layers % config.cluster_config.replica_config.num_pipeline_stages == 0 + def __init__(self, replica_config: ReplicaConfig) -> None: + self._id = Replica.generate_id() + + self._replica_config = replica_config + self._model_config = replica_config.model_config + self._device_config = replica_config.device_config + + assert self._model_config.num_layers % self._replica_config.num_pipeline_stages == 0 assert ( - config.cluster_config.replica_config.model_config.embedding_dim % config.cluster_config.replica_config.tensor_parallel_size - == 0 + self._model_config.embedding_dim % self._replica_config.tensor_parallel_size == 0 ) - self._id = Replica.generate_id() - - replica_config = config.cluster_config.replica_config - model_config = replica_config.model_config - - self._num_pipeline_stages = replica_config.num_pipeline_stages - self._num_tensor_parallel_workers = config.cluster_config.replica_config.tensor_parallel_size - self._num_layers = model_config.num_layers - self._num_q_heads = model_config.num_q_heads - self._num_kv_heads = model_config.num_kv_heads - self._embedding_dim = model_config.embedding_dim - self._mlp_hidden_dim = model_config.mlp_hidden_dim - self._use_gated_mlp = model_config.use_gated_mlp - self._vocab_size = model_config.vocab_size - self._total_memory_gb = replica_config.device_config.total_memory_gb - self._memory_margin_fraction = replica_config.memory_margin_fraction - self._max_request_tokens = config.request_generator_config.max_tokens - self._per_device_flops = replica_config.device_config.fp16_tflops * 2**40 + @property + def id(self) -> int: + return self._id @property def num_layers(self) -> int: - return self._num_layers + return self._model_config.num_layers @property def num_q_heads(self) -> int: - return self._num_q_heads + return self._model_config.num_q_heads @property def num_kv_heads(self) -> int: - return self._num_kv_heads + return self._model_config.num_kv_heads @property def embedding_dim(self) -> int: - return self._embedding_dim + return self._model_config.embedding_dim @property def mlp_hidden_dim(self) -> int: - return self._mlp_hidden_dim + return self._model_config.mlp_hidden_dim @property def use_gated_mlp(self) -> int: - return self._use_gated_mlp + return self._model_config.use_gated_mlp @property def vocab_size(self) -> int: - return self._vocab_size + return self._model_config.vocab_size @property def num_pipeline_stages(self) -> int: - return self._num_pipeline_stages + return self._replica_config.num_pipeline_stages @property def num_layers_per_pipeline_stage(self) -> int: - return self._num_layers // self._num_pipeline_stages + return self._model_config.num_layers // self._replica_config.num_pipeline_stages @property def attention_head_dim(self) -> int: - return self._embedding_dim // self._num_q_heads + return self._model_config.embedding_dim // self._model_config.num_q_heads @property def q_heads_per_tensor_parallel_worker(self) -> int: - return self._num_q_heads // self._num_tensor_parallel_workers + return self._model_config.num_q_heads // self._replica_config.tensor_parallel_size @property def kv_heads_per_tensor_parallel_worker(self) -> int: - return ceil(self._num_kv_heads / self._num_tensor_parallel_workers) + return ceil(self._model_config.num_kv_heads / self._replica_config.tensor_parallel_size) @property def num_tensor_parallel_workers(self) -> int: 
- return self._num_tensor_parallel_workers + return self._replica_config.tensor_parallel_size @property def total_memory_gb(self) -> int: - return self._total_memory_gb + return self._device_config.total_memory_gb @property def memory_margin_fraction(self) -> float: - return self._memory_margin_fraction + return self._replica_config.memory_margin_fraction @property def max_request_tokens(self) -> int: - return self._max_request_tokens + return self._replica_config.max_tokens @property def per_device_flops(self) -> float: - return self._per_device_flops + return self._device_config.fp16_tflops * 2**40 def to_dict(self) -> dict: return { - "id": self._id, - "num_layers": self._num_layers, - "num_q_heads": self._num_q_heads, - "num_kv_heads": self._num_kv_heads, - "embedding_dim": self._embedding_dim, - "mlp_hidden_dim": self._mlp_hidden_dim, - "use_gated_mlp": self._use_gated_mlp, - "vocab_size": self._vocab_size, - "num_pipeline_stages": self._num_pipeline_stages, - "num_tensor_parallel_workers": self._num_tensor_parallel_workers, + "id": self.id, + "num_layers": self.num_layers, + "num_q_heads": self.num_q_heads, + "num_kv_heads": self.num_kv_heads, + "embedding_dim": self.embedding_dim, + "mlp_hidden_dim": self.mlp_hidden_dim, + "use_gated_mlp": self.use_gated_mlp, + "vocab_size": self.vocab_size, + "num_pipeline_stages": self.num_pipeline_stages, + "num_tensor_parallel_workers": self.num_tensor_parallel_workers, } diff --git a/vidur/execution_time_predictor/base_execution_time_predictor.py b/vidur/execution_time_predictor/base_execution_time_predictor.py index 0dfd7c6..7f1744e 100644 --- a/vidur/execution_time_predictor/base_execution_time_predictor.py +++ b/vidur/execution_time_predictor/base_execution_time_predictor.py @@ -1,31 +1,24 @@ from abc import ABC, abstractmethod -from vidur.config import SimulationConfig +from vidur.config import BaseExecutionTimePredictorConfig +from vidur.config.model_config import BaseModelConfig from vidur.entities import Batch, ExecutionTime class BaseExecutionTimePredictor(ABC): - def __init__(self, config: SimulationConfig) -> None: - replica_config = config.cluster_config.replica_config - model_config = replica_config.model_config - - self._num_tensor_parallel_workers = replica_config.tensor_parallel_size - self._num_pipeline_stages = replica_config.num_pipeline_stages - self._num_layers = model_config.num_layers - self._num_layers_per_pipeline_stage = ( - model_config.num_layers // replica_config.num_pipeline_stages - ) - self._replica_scheduler_provider = str(config.cluster_config.replica_scheduler_config.get_type()) + def __init__(self, config: BaseExecutionTimePredictorConfig, model_config: BaseModelConfig) -> None: + self._config = config + self._model_config = model_config def get_execution_time(self, batch: Batch, pipeline_stage: int) -> ExecutionTime: - if pipeline_stage == self._num_pipeline_stages - 1: + if pipeline_stage == self._config.num_pipeline_stages - 1: pipeline_parallel_communication_time = 0 else: pipeline_parallel_communication_time = ( self._get_pipeline_parallel_communication_time(batch) ) - if self._num_tensor_parallel_workers == 1: + if self._config.num_tensor_parallel_workers == 1: tensor_parallel_communication_time = 0 else: tensor_parallel_communication_time = ( @@ -33,7 +26,7 @@ def get_execution_time(self, batch: Batch, pipeline_stage: int) -> ExecutionTime ) return ExecutionTime( - self._num_layers_per_pipeline_stage, + self._config.num_layers_per_pipeline_stage, self._get_attention_rope_execution_time(batch), 
self._get_attention_kv_cache_save_execution_time(batch), self._get_attention_decode_execution_time(batch), diff --git a/vidur/execution_time_predictor/linear_regression_execution_time_predictor.py b/vidur/execution_time_predictor/linear_regression_execution_time_predictor.py index 6e0c6f7..18cc719 100644 --- a/vidur/execution_time_predictor/linear_regression_execution_time_predictor.py +++ b/vidur/execution_time_predictor/linear_regression_execution_time_predictor.py @@ -2,37 +2,24 @@ from sklearn.pipeline import make_pipeline from sklearn.preprocessing import PolynomialFeatures -from vidur.config import SimulationConfig +from vidur.config import LinearRegressionExecutionTimePredictorConfig +from vidur.config.model_config import BaseModelConfig from vidur.execution_time_predictor.sklearn_execution_time_predictor import ( SklearnExecutionTimePredictor, ) class LinearRegressionExecutionTimePredictor(SklearnExecutionTimePredictor): - def __init__(self, config: SimulationConfig): - predictor_config = config.cluster_config.execution_time_predictor_config - self._polynomial_degree = ( - predictor_config.polynomial_degree - ) - self._polynomial_include_bias = ( - predictor_config.polynomial_include_bias - ) - self._polynomial_interaction_only = ( - predictor_config.polynomial_interaction_only - ) - self._fit_intercept = ( - predictor_config.fit_intercept - ) - + def __init__(self, config: LinearRegressionExecutionTimePredictorConfig, model_config: BaseModelConfig): # will trigger model training - super().__init__(config) + super().__init__(config, model_config) def _get_grid_search_params(self): return { - "polynomialfeatures__degree": self._polynomial_degree, - "polynomialfeatures__include_bias": self._polynomial_include_bias, - "polynomialfeatures__interaction_only": self._polynomial_interaction_only, - "linearregression__fit_intercept": self._fit_intercept, + "polynomialfeatures__degree": self._config.polynomial_degree, + "polynomialfeatures__include_bias": self._config.polynomial_include_bias, + "polynomialfeatures__interaction_only": self._config.polynomial_interaction_only, + "linearregression__fit_intercept": self._config.fit_intercept, } def _get_estimator(self): diff --git a/vidur/execution_time_predictor/random_forrest_execution_time_predictor.py b/vidur/execution_time_predictor/random_forrest_execution_time_predictor.py index 4f8736e..1e7d664 100644 --- a/vidur/execution_time_predictor/random_forrest_execution_time_predictor.py +++ b/vidur/execution_time_predictor/random_forrest_execution_time_predictor.py @@ -1,28 +1,22 @@ from sklearn.ensemble import RandomForestRegressor -from vidur.config import SimulationConfig +from vidur.config import RandomForrestExecutionTimePredictorConfig +from vidur.config.model_config import BaseModelConfig from vidur.execution_time_predictor.sklearn_execution_time_predictor import ( SklearnExecutionTimePredictor, ) class RandomForrestExecutionTimePredictor(SklearnExecutionTimePredictor): - def __init__(self, config: SimulationConfig): - predictor_config = config.cluster_config.execution_time_predictor_config - self._num_estimators = ( - predictor_config.num_estimators - ) - self._max_depth = predictor_config.max_depth - self._min_samples_split = predictor_config.min_samples_split - + def __init__(self, config: RandomForrestExecutionTimePredictorConfig, model_config: BaseModelConfig): # will trigger model training - super().__init__(config) + super().__init__(config, model_config) def _get_grid_search_params(self): return { - "n_estimators": 
self._num_estimators, - "max_depth": self._max_depth, - "min_samples_split": self._min_samples_split, + "n_estimators": self._config.num_estimators, + "max_depth": self._config.max_depth, + "min_samples_split": self._config.min_samples_split, } def _get_estimator(self): diff --git a/vidur/execution_time_predictor/sklearn_execution_time_predictor.py b/vidur/execution_time_predictor/sklearn_execution_time_predictor.py index 0d19c24..b358140 100644 --- a/vidur/execution_time_predictor/sklearn_execution_time_predictor.py +++ b/vidur/execution_time_predictor/sklearn_execution_time_predictor.py @@ -12,7 +12,8 @@ from sklearn.metrics import make_scorer from sklearn.model_selection import GridSearchCV -from vidur.config import SimulationConfig +from vidur.config import BaseExecutionTimePredictorConfig +from vidur.config.model_config import BaseModelConfig from vidur.entities import Batch from vidur.execution_time_predictor.base_execution_time_predictor import ( BaseExecutionTimePredictor, @@ -23,131 +24,86 @@ class SklearnExecutionTimePredictor(BaseExecutionTimePredictor): - def __init__(self, config: SimulationConfig) -> None: - super().__init__(config) - - self._cache_dir = f"{config.cache_dir}/execution_time_predictor" - os.makedirs(self._cache_dir, exist_ok=True) - - predictor_config = config.cluster_config.execution_time_predictor_config - model_config = config.cluster_config.replica_config.model_config - - self._no_cache = predictor_config.no_cache - - self._k_fold_cv_splits = ( - predictor_config.k_fold_cv_splits - ) - self._model_name = model_config.get_name() - self._num_q_heads = model_config.num_q_heads - self._num_kv_heads = model_config.num_kv_heads - self._embedding_dim = model_config.embedding_dim - self._mlp_hidden_dim = model_config.mlp_hidden_dim - self._use_gated_mlp = model_config.use_gated_mlp - self._vocab_size = model_config.vocab_size - self._block_size = config.cluster_config.replica_scheduler_config.block_size - self._norm = model_config.norm - self._post_attn_norm = model_config.post_attn_norm - - self._model_provider = str(config.cluster_config.execution_time_predictor_config.get_type()) + def __init__(self, config: BaseExecutionTimePredictorConfig, model_config: BaseModelConfig) -> None: + super().__init__(config, model_config) + os.makedirs(self._config.cache_dir, exist_ok=True) # These overheads are only for GQA models self._attention_prefill_batching_overhead_fraction = ( ( - predictor_config.attention_prefill_batching_overhead_fraction + self._config.attention_prefill_batching_overhead_fraction ) - if self._num_q_heads > self._num_kv_heads + if self._model_config.num_q_heads > self._model_config.num_kv_heads else 0 ) self._attention_decode_batching_overhead_fraction = ( ( - predictor_config.attention_decode_batching_overhead_fraction + self._config.attention_decode_batching_overhead_fraction ) - if self._num_q_heads > self._num_kv_heads + if self._model_config.num_q_heads > self._model_config.num_kv_heads else 0 ) - self._nccl_cpu_launch_overhead_ms = ( - predictor_config.nccl_cpu_launch_overhead_ms - ) - self._nccl_cpu_skew_overhead_per_device_ms = ( - predictor_config.nccl_cpu_skew_overhead_per_device_ms - ) - - self._max_batch_size = ( - predictor_config.prediction_max_batch_size - ) - self._max_tokens_per_request = ( - predictor_config.prediction_max_tokens_per_request - ) - - if self._replica_scheduler_provider == "orca": - self._max_tokens = self._max_tokens_per_request * self._max_batch_size + if self._config.replica_scheduler_provider == "orca": + 
self._max_tokens = self._config.prediction_max_tokens_per_request * self._config.prediction_max_batch_size else: - self._max_tokens = self._max_tokens_per_request - - self._compute_input_file = ( - predictor_config.compute_input_file - ) - self._attention_input_file = ( - predictor_config.attention_input_file - ) - self._all_reduce_input_file = ( - predictor_config.all_reduce_input_file - ) - self._send_recv_input_file = ( - predictor_config.send_recv_input_file - ) - self._cpu_overhead_input_file = ( - predictor_config.cpu_overhead_input_file - ) - self._kv_cache_prediction_granularity = ( - predictor_config.kv_cache_prediction_granularity - ) - self._prediction_max_prefill_chunk_size = ( - predictor_config.prediction_max_prefill_chunk_size - ) + self._max_tokens = self._config.prediction_max_tokens_per_request - self._device_memory = config.cluster_config.replica_config.device_config.total_memory_gb - self._num_training_job_threads = ( - predictor_config.num_training_job_threads - ) - - devices_per_node = config.cluster_config.replica_config.node_config.num_devices_per_node - num_workers = self._num_pipeline_stages * self._num_tensor_parallel_workers + num_workers = self._config.num_pipeline_stages * self._config.num_tensor_parallel_workers assert ( - num_workers < devices_per_node or num_workers % devices_per_node == 0 + num_workers < self._config.devices_per_node or num_workers % self._config.devices_per_node == 0 ), "Number of workers should be less than devices per node or a multiple of devices per node" - self._is_multi_node = num_workers > devices_per_node + self._is_multi_node = num_workers > self._config.devices_per_node - self._skip_cpu_overhead_modeling = ( - predictor_config.skip_cpu_overhead_modeling - ) + ( + self._compute_input_file, + self._attention_input_file, + self._all_reduce_input_file, + self._send_recv_input_file, + self._cpu_overhead_input_file + ) = self._get_input_files() self._models = self._train_models() self._predictions = self._predict_from_models() + def _get_input_files(self) -> Tuple[str, str, str, str, str]: + input_files = [ + self._config.compute_input_file, + self._config.attention_input_file, + self._config.all_reduce_input_file, + self._config.send_recv_input_file, + self._config.cpu_overhead_input_file, + ] + for i in range(len(input_files)): + input_files[i] = input_files[i].replace( + "{DEVICE}", self._config.device).replace( + "{MODEL}", self._model_config.get_name()).replace( + "{NETWORK_DEVICE}", self._config.network_device) + + return tuple(input_files) + def _load_compute_df(self, file_path: str) -> pd.DataFrame: df = self._read_input_file(file_path) df = df.drop_duplicates() logger.debug(f"Length of complete compute df: {len(df)} {file_path}") - logger.debug(f"self._num_q_heads: {self._num_q_heads}") - logger.debug(f"self._embedding_dim: {self._embedding_dim}") - logger.debug(f"self._mlp_hidden_dim: {self._mlp_hidden_dim}") - logger.debug(f"self._use_gated_mlp: {self._use_gated_mlp}") - logger.debug(f"self._vocab_size: {self._vocab_size}") + logger.debug(f"self._num_q_heads: {self._model_config.num_q_heads}") + logger.debug(f"self._embedding_dim: {self._model_config.embedding_dim}") + logger.debug(f"self._mlp_hidden_dim: {self._model_config.mlp_hidden_dim}") + logger.debug(f"self._use_gated_mlp: {self._model_config.use_gated_mlp}") + logger.debug(f"self._vocab_size: {self._model_config.vocab_size}") logger.debug( - f"self._num_tensor_parallel_workers: {self._num_tensor_parallel_workers}" + f"self._num_tensor_parallel_workers: 
{self._config.num_tensor_parallel_workers}" ) df = df[ - (df["n_head"] == self._num_q_heads) - & (df["n_kv_head"] == self._num_kv_heads) - & (df["n_embd"] == self._embedding_dim) - & (df["n_expanded_embd"] == self._mlp_hidden_dim) - & (df["use_gated_mlp"] == self._use_gated_mlp) - & (df["vocab_size"] == self._vocab_size) - & (df["num_tensor_parallel_workers"] == self._num_tensor_parallel_workers) + (df["n_head"] == self._model_config.num_q_heads) + & (df["n_kv_head"] == self._model_config.num_kv_heads) + & (df["n_embd"] == self._model_config.embedding_dim) + & (df["n_expanded_embd"] == self._model_config.mlp_hidden_dim) + & (df["use_gated_mlp"] == self._model_config.use_gated_mlp) + & (df["vocab_size"] == self._model_config.vocab_size) + & (df["num_tensor_parallel_workers"] == self._config.num_tensor_parallel_workers) ] for column in [ @@ -174,18 +130,18 @@ def _load_attention_df(self, file_path: str) -> pd.DataFrame: df.fillna({column: 0}, inplace=True) return df[ - (df["n_embd"] == self._embedding_dim) - & (df["n_q_head"] == self._num_q_heads) - & (df["n_kv_head"] == self._num_kv_heads) - & (df["block_size"] == self._block_size) - & (df["num_tensor_parallel_workers"] == self._num_tensor_parallel_workers) + (df["n_embd"] == self._model_config.embedding_dim) + & (df["n_q_head"] == self._model_config.num_q_heads) + & (df["n_kv_head"] == self._model_config.num_kv_heads) + & (df["block_size"] == self._config.block_size) + & (df["num_tensor_parallel_workers"] == self._config.num_tensor_parallel_workers) ] def _load_all_reduce_df(self, file_path: str) -> pd.DataFrame: df = self._read_input_file(file_path) return df[ - (df["num_workers"] == self._num_tensor_parallel_workers) - & (df["devices_per_node"] == self._num_tensor_parallel_workers) + (df["num_workers"] == self._config.num_tensor_parallel_workers) + & (df["devices_per_node"] == self._config.num_tensor_parallel_workers) & (df["collective"] == "all_reduce") ] @@ -205,8 +161,8 @@ def _load_send_recv_df(self, file_path: str) -> pd.DataFrame: def _load_cpu_overhead_df(self, file_path: str) -> pd.DataFrame: df = self._read_input_file(file_path) filtered_df = df[ - (df["model_name"] == self._model_name) - & (df["tensor_parallel_degree"] == self._num_tensor_parallel_workers) + (df["model_name"] == self._model_config.get_name()) + & (df["tensor_parallel_degree"] == self._config.num_tensor_parallel_workers) ] return filtered_df @@ -239,14 +195,14 @@ def _get_all_reduce_df_with_derived_features( # convert bytes to num tokens # each token is of size 2 * h bytes df_with_derived_features["num_tokens"] = ( - df_with_derived_features["size"] / self._embedding_dim / 2 + df_with_derived_features["size"] / self._model_config.embedding_dim / 2 ) return df_with_derived_features def _get_send_recv_df_with_derived_features(self, df: pd.DataFrame) -> pd.DataFrame: df_with_derived_features = df.copy() df_with_derived_features["num_tokens"] = ( - df_with_derived_features["size"] / self._embedding_dim / 2 + df_with_derived_features["size"] / self._model_config.embedding_dim / 2 ) return df_with_derived_features @@ -306,12 +262,12 @@ def _get_model_hash(self, model_name: str, df: pd.DataFrame = None) -> str: def _load_model_from_cache(self, model_name: str, model_hash: str) -> BaseEstimator: with InterProcessReaderWriterLock( - f"{self._cache_dir}/{model_hash}_model_lock.file" + f"{self._config.cache_dir}/{model_hash}_model_lock.file" ).read_lock(): - if self._no_cache: + if self._config.no_cache: return # check if model is in cache - cache_file = 
f"{self._cache_dir}/{model_name}_{model_hash}.pkl" + cache_file = f"{self._config.cache_dir}/{model_name}_{model_hash}.pkl" if not os.path.exists(cache_file): return @@ -323,10 +279,10 @@ def _store_model_in_cache( self, model_name: str, model_hash: str, model: BaseEstimator ) -> None: with InterProcessReaderWriterLock( - f"{self._cache_dir}/{model_hash}_model_lock.file" + f"{self._config.cache_dir}/{model_hash}_model_lock.file" ).write_lock(): # store model in cache - cache_file = f"{self._cache_dir}/{model_name}_{model_hash}.pkl" + cache_file = f"{self._config.cache_dir}/{model_name}_{model_hash}.pkl" pickle.dump(model, open(cache_file, "wb"), protocol=pickle.HIGHEST_PROTOCOL) def _store_training_prediction_data( @@ -345,7 +301,7 @@ def _store_training_prediction_data( # store the prediction data df[feature_cols + [target_col, "prediction"]].to_csv( - f"{self._cache_dir}/{model_name}_{model_hash}_training_predictions.csv", + f"{self._config.cache_dir}/{model_name}_{model_hash}_training_predictions.csv", index=False, ) @@ -368,17 +324,17 @@ def _train_model( model = self._get_estimator() grid_search_params = self._get_grid_search_params() - if len(df) < self._k_fold_cv_splits: + if len(df) < self._config.k_fold_cv_splits: cv = 2 else: - cv = self._k_fold_cv_splits + cv = self._config.k_fold_cv_splits grid_search = GridSearchCV( estimator=model, param_grid=grid_search_params, scoring=self._get_scorer(), cv=cv, - n_jobs=self._num_training_job_threads, + n_jobs=self._config.num_training_job_threads, ) # we don't create a train/test split, because we want to use all data for training @@ -409,9 +365,9 @@ def _store_model_predication_cache( self, model_name: str, model_hash: str, predictions: Dict[Tuple, float] ) -> None: with InterProcessReaderWriterLock( - f"{self._cache_dir}/{model_hash}_prediction_lock.file" + f"{self._config.cache_dir}/{model_hash}_prediction_lock.file" ).write_lock(): - cache_file = f"{self._cache_dir}/{model_name}_{model_hash}_predictions.pkl" + cache_file = f"{self._config.cache_dir}/{model_name}_{model_hash}_predictions.pkl" pickle.dump( predictions, open(cache_file, "wb"), protocol=pickle.HIGHEST_PROTOCOL ) @@ -420,11 +376,11 @@ def _load_model_predication_cache( self, model_name: str, model_hash: str ) -> Dict[Tuple, float]: with InterProcessReaderWriterLock( - f"{self._cache_dir}/{model_hash}_prediction_lock.file" + f"{self._config.cache_dir}/{model_hash}_prediction_lock.file" ).read_lock(): - if self._no_cache: + if self._config.no_cache: return - cache_file = f"{self._cache_dir}/{model_name}_{model_hash}_predictions.pkl" + cache_file = f"{self._config.cache_dir}/{model_name}_{model_hash}_predictions.pkl" if not os.path.exists(cache_file): return @@ -457,7 +413,7 @@ def _get_model_prediction( X["prediction"] = predictions_array X.to_csv( - f"{self._cache_dir}/{model_name}_{model_hash}_predictions.csv", + f"{self._config.cache_dir}/{model_name}_{model_hash}_predictions.csv", index=False, ) @@ -506,7 +462,7 @@ def _train_compute_models(self) -> Dict[str, BaseEstimator]: target_col=f"time_stats.{model_name}.median", ) - if self._num_pipeline_stages > 1: + if self._config.num_pipeline_stages > 1: send_recv_df = self._load_send_recv_df(self._send_recv_input_file) send_recv_df = self._get_send_recv_df_with_derived_features(send_recv_df) @@ -517,7 +473,7 @@ def _train_compute_models(self) -> Dict[str, BaseEstimator]: target_col="time_stats.send_recv.median", ) - if self._num_tensor_parallel_workers > 1: + if self._config.num_tensor_parallel_workers > 1: all_reduce_df = 
self._load_all_reduce_df(self._all_reduce_input_file) all_reduce_df = self._get_all_reduce_df_with_derived_features(all_reduce_df) @@ -531,7 +487,7 @@ def _train_compute_models(self) -> Dict[str, BaseEstimator]: return models def _train_cpu_overhead_models(self) -> Dict[str, BaseEstimator]: - if self._skip_cpu_overhead_modeling: + if self._config.skip_cpu_overhead_modeling: return {} models = {} @@ -616,10 +572,10 @@ def _predict_for_compute_models(self) -> Dict[str, Any]: "add", ] - if self._num_pipeline_stages > 1: + if self._config.num_pipeline_stages > 1: model_names.append("send_recv") - if self._num_tensor_parallel_workers > 1: + if self._config.num_tensor_parallel_workers > 1: model_names.append("all_reduce") num_token_range = np.arange(1, self._max_tokens + 1) @@ -632,7 +588,7 @@ def _predict_for_compute_models(self) -> Dict[str, Any]: return predictions def _predict_for_cpu_overhead_models(self) -> Dict[str, Any]: - if self._skip_cpu_overhead_modeling: + if self._config.skip_cpu_overhead_modeling: return {} predictions = {} @@ -645,7 +601,7 @@ def _predict_for_cpu_overhead_models(self) -> Dict[str, Any]: "ray_comm_time", ] - batch_size_range = np.arange(1, self._max_batch_size + 1) + batch_size_range = np.arange(1, self._config.prediction_max_batch_size + 1) X = pd.DataFrame({"batch_size": batch_size_range}) for model_name in model_names: @@ -657,9 +613,9 @@ def _predict_for_cpu_overhead_models(self) -> Dict[str, Any]: def _predict_for_attention_layer_models(self) -> Dict[str, Any]: predictions = {} - decode_batch_size_range = np.arange(1, self._max_batch_size + 1) + decode_batch_size_range = np.arange(1, self._config.prediction_max_batch_size + 1) decode_kv_cache_size_range = np.arange( - 0, self._max_tokens_per_request + 1, self._kv_cache_prediction_granularity + 0, self._config.prediction_max_tokens_per_request + 1, self._config.kv_cache_prediction_granularity ) decode_prefill_chunk_size_range = [0] decode_batch_size, decode_kv_cache_size, decode_prefill_chunk_size = zip( @@ -672,10 +628,10 @@ def _predict_for_attention_layer_models(self) -> Dict[str, Any]: prefill_batch_size_range = [1] prefill_kv_cache_size_range = np.arange( - 0, self._max_tokens_per_request + 1, self._kv_cache_prediction_granularity + 0, self._config.prediction_max_tokens_per_request + 1, self._config.kv_cache_prediction_granularity ) prefill_prefill_chunk_size_range = np.arange( - 1, self._prediction_max_prefill_chunk_size + 1 + 1, self._config.prediction_max_prefill_chunk_size + 1 ) prefill_batch_size, prefill_kv_cache_size, prefill_prefill_chunk_size = zip( *product( @@ -748,9 +704,9 @@ def _get_batch_decode_attention_params(self, batch: Batch) -> Tuple[int, int]: decode_batch_size = len(decode_kv_cache_sizes) decode_avg_kv_cache_size = int(np.mean(decode_kv_cache_sizes)) decode_avg_kv_cache_size = ( - (decode_avg_kv_cache_size + self._kv_cache_prediction_granularity - 1) - // self._kv_cache_prediction_granularity - ) * self._kv_cache_prediction_granularity + (decode_avg_kv_cache_size + self._config.kv_cache_prediction_granularity - 1) + // self._config.kv_cache_prediction_granularity + ) * self._config.kv_cache_prediction_granularity batch._decode_params = (decode_batch_size, decode_avg_kv_cache_size) @@ -772,11 +728,11 @@ def _get_batch_prefill_attention_params( kv_cache_size = ( ( request.num_processed_tokens - + self._kv_cache_prediction_granularity + + self._config.kv_cache_prediction_granularity - 1 ) - // self._kv_cache_prediction_granularity - ) * self._kv_cache_prediction_granularity + // 
self._config.kv_cache_prediction_granularity + ) * self._config.kv_cache_prediction_granularity prefill_params.append((kv_cache_size, prefill_chunk_size)) @@ -803,7 +759,7 @@ def _get_attn_norm_layer_act_execution_time(self, batch: Batch) -> float: return self._predictions["input_layernorm"][(batch._total_num_tokens_rounded,)] def _get_mlp_norm_layer_act_execution_time(self, batch: Batch) -> float: - if not self._post_attn_norm: + if not self._model_config.post_attn_norm: return 0 return self._predictions["post_attention_layernorm"][ @@ -816,9 +772,9 @@ def _get_add_layer_act_execution_time(self, batch: Batch) -> float: def _get_tensor_parallel_communication_time(self, batch: Batch) -> float: return ( self._predictions["all_reduce"][(batch._total_num_tokens_rounded,)] - + self._nccl_cpu_launch_overhead_ms - + self._nccl_cpu_skew_overhead_per_device_ms - * self._num_tensor_parallel_workers**1.25 + + self._config.nccl_cpu_launch_overhead_ms + + self._config.nccl_cpu_skew_overhead_per_device_ms + * self._config.num_tensor_parallel_workers**1.25 ) def _get_pipeline_parallel_communication_time(self, batch: Batch) -> float: @@ -874,52 +830,52 @@ def _get_attention_prefill_execution_time(self, batch: Batch) -> float: ) def _get_schedule_time(self, batch: Batch) -> float: - if self._skip_cpu_overhead_modeling: + if self._config.skip_cpu_overhead_modeling: return 0 return self._predictions["schedule"][(batch.size,)] def _get_sampler_e2e_time(self, batch: Batch) -> float: - if self._skip_cpu_overhead_modeling: + if self._config.skip_cpu_overhead_modeling: return 0 return self._predictions["sampler_e2e"][(batch.size,)] def _get_prepare_inputs_e2e_time(self, batch: Batch) -> float: - if self._skip_cpu_overhead_modeling: + if self._config.skip_cpu_overhead_modeling: return 0 return self._predictions["prepare_inputs_e2e"][(batch.size,)] def _get_process_model_outputs_time(self, batch: Batch) -> float: - if self._skip_cpu_overhead_modeling: + if self._config.skip_cpu_overhead_modeling: return 0 return self._predictions["process_model_outputs"][(batch.size,)] def _get_ray_comm_time(self, batch: Batch) -> float: - if self._skip_cpu_overhead_modeling: + if self._config.skip_cpu_overhead_modeling: return 0 return self._predictions["ray_comm_time"][(batch.size,)] def to_dict(self) -> dict: return { - "model_provider": self._model_provider, - "num_tensor_parallel_workers": self._num_tensor_parallel_workers, - "k_fold_cv_splits": self._k_fold_cv_splits, - "num_q_heads": self._num_q_heads, - "num_kv_heads": self._num_kv_heads, - "embedding_dim": self._embedding_dim, - "mlp_hidden_dim": self._mlp_hidden_dim, - "use_gated_mlp": self._use_gated_mlp, - "vocab_size": self._vocab_size, - "block_size": self._block_size, + "model_provider": str(self._config.get_type()), + "num_tensor_parallel_workers": self._config.num_tensor_parallel_workers, + "k_fold_cv_splits": self._config.k_fold_cv_splits, + "num_q_heads": self._model_config.num_q_heads, + "num_kv_heads": self._model_config.num_kv_heads, + "embedding_dim": self._model_config.embedding_dim, + "mlp_hidden_dim": self._model_config.mlp_hidden_dim, + "use_gated_mlp": self._model_config.use_gated_mlp, + "vocab_size": self._model_config.vocab_size, + "block_size": self._config.block_size, "max_tokens": self._max_tokens, "compute_input_file": self._compute_input_file, "all_reduce_input_file": self._all_reduce_input_file, "send_recv_input_file": self._send_recv_input_file, "cpu_overhead_input_file": self._cpu_overhead_input_file, - "prediction_max_prefill_chunk_size": 
self._prediction_max_prefill_chunk_size, - "max_batch_size": self._max_batch_size, + "prediction_max_prefill_chunk_size": self._config.prediction_max_prefill_chunk_size, + "max_batch_size": self._config.prediction_max_batch_size, } diff --git a/vidur/metrics/metrics_store.py b/vidur/metrics/metrics_store.py index 0add146..c0b3860 100644 --- a/vidur/metrics/metrics_store.py +++ b/vidur/metrics/metrics_store.py @@ -6,7 +6,7 @@ import plotly_express as px import wandb -from vidur.config import SimulationConfig, MetricsConfig +from vidur.config import MetricsConfig, ReplicaConfig from vidur.entities import Batch, BatchStage, ExecutionTime, Request from vidur.logger import init_logger from vidur.metrics.cdf_sketch import CDFSketch @@ -30,7 +30,7 @@ def if_write_metrics(func): def wrapper(self, *args, **kwargs): - if self._should_write_metrics: + if self._config.write_metrics: return func(self, *args, **kwargs) return wrapper @@ -49,40 +49,9 @@ def wrapper(self, *args, **kwargs): class MetricsStore: - def __init__(self, config: SimulationConfig): - self._config: SimulationConfig = config - metrics_config: MetricsConfig = config.cluster_config.metrics_config - - self._num_replicas = config.cluster_config.num_replicas - self._num_stages = config.cluster_config.replica_config.num_pipeline_stages - self._should_write_metrics = metrics_config.write_metrics - self._subsamples = metrics_config.subsamples - self._save_table_to_wandb = metrics_config.save_table_to_wandb - self._save_plots = metrics_config.store_plots - self._keep_individual_batch_metrics = ( - metrics_config.keep_individual_batch_metrics - ) - - self._wandb_project = metrics_config.wandb_project - self._wandb_group = metrics_config.wandb_group - self._wandb_run_name = metrics_config.wandb_run_name - - self._min_batch_idx = metrics_config.min_batch_index - self._max_batch_idx = metrics_config.max_batch_index - + def __init__(self, config: MetricsConfig, replica_config: ReplicaConfig) -> None: + self._config = config self._last_request_arrived_at = None - self._should_store_token_completion_metrics = ( - metrics_config.store_token_completion_metrics - ) - self._should_store_utilization_metrics = ( - metrics_config.store_utilization_metrics - ) - self._should_store_batch_metrics = metrics_config.store_batch_metrics - self._should_store_operation_metrics = ( - metrics_config.store_operation_metrics - ) - self._should_store_request_metrics = metrics_config.store_request_metrics - # Initialise request metrics self._request_metrics_time_distributions: Dict[ RequestMetricsTimeDistributions, DataSeries @@ -91,9 +60,9 @@ def __init__(self, config: SimulationConfig): self._request_metrics_time_distributions[metric_name] = DataSeries( REQUEST_ID_STR, metric_name.value, - self._subsamples, - self._save_table_to_wandb, - self._save_plots, + self._config.subsamples, + self._config.save_table_to_wandb, + self._config.store_plots, ) self._token_metrics_time_distribution: Dict[ @@ -102,8 +71,8 @@ def __init__(self, config: SimulationConfig): for metric_name in TokenMetricsTimeDistribution: self._token_metrics_time_distribution[metric_name] = CDFSketch( metric_name.value, - self._save_table_to_wandb, - self._save_plots, + self._config.save_table_to_wandb, + self._config.store_plots, ) self._request_metrics_histogram: Dict[RequestMetricsHistogram, DataSeries] = {} @@ -111,9 +80,9 @@ def __init__(self, config: SimulationConfig): self._request_metrics_histogram[metric_name] = DataSeries( REQUEST_ID_STR, metric_name.value, - self._subsamples, - 
self._save_table_to_wandb, - self._save_plots, + self._config.subsamples, + self._config.save_table_to_wandb, + self._config.store_plots, ) # Initialise batch metrics @@ -126,15 +95,15 @@ def __init__(self, config: SimulationConfig): for metric_name in BatchMetricsCountDistribution: self._batch_metrics_count_distribution[metric_name] = CDFSketch( metric_name.value, - self._save_table_to_wandb, - self._save_plots, + self._config.save_table_to_wandb, + self._config.store_plots, ) self._batch_metrics_count_distribution_per_batch[metric_name] = DataSeries( BATCH_ID_STR, metric_name.value, - self._subsamples, - self._save_table_to_wandb, - self._save_plots, + self._config.subsamples, + self._config.save_table_to_wandb, + self._config.store_plots, ) self._batch_metrics_time_distribution: Dict[ @@ -146,15 +115,15 @@ def __init__(self, config: SimulationConfig): for metric_name in BatchMetricsTimeDistribution: self._batch_metrics_time_distribution[metric_name] = CDFSketch( metric_name.value, - self._save_table_to_wandb, - self._save_plots, + self._config.save_table_to_wandb, + self._config.store_plots, ) self._batch_metrics_time_distribution_per_batch[metric_name] = DataSeries( BATCH_ID_STR, metric_name.value, - self._subsamples, - self._save_table_to_wandb, - self._save_plots, + self._config.subsamples, + self._config.save_table_to_wandb, + self._config.store_plots, ) # Initialise completion metrics @@ -165,9 +134,9 @@ def __init__(self, config: SimulationConfig): self._request_completion_metrics_time_series[metric_name] = DataSeries( TIME_STR, metric_name.value, - self._subsamples, - self._save_table_to_wandb, - self._save_plots, + self._config.subsamples, + self._config.save_table_to_wandb, + self._config.store_plots, ) self._token_completion_metrics_time_series: Dict[ TokenCompletionMetricsTimeSeries, DataSeries @@ -176,9 +145,9 @@ def __init__(self, config: SimulationConfig): self._token_completion_metrics_time_series[metric_name] = DataSeries( TIME_STR, metric_name.value, - self._subsamples, - self._save_table_to_wandb, - self._save_plots, + self._config.subsamples, + self._config.save_table_to_wandb, + self._config.store_plots, ) # Initialise operation metrics @@ -187,15 +156,15 @@ def __init__(self, config: SimulationConfig): for metric_name in OperationMetrics: self._operation_metrics[metric_name] = CDFSketch( metric_name.value, - self._save_table_to_wandb, - self._save_plots, + self._config.save_table_to_wandb, + self._config.store_plots, ) self._operation_metrics_per_batch[metric_name] = DataSeries( BATCH_ID_STR, metric_name.value, - self._subsamples, - self._save_table_to_wandb, - self._save_plots, + self._config.subsamples, + self._config.save_table_to_wandb, + self._config.store_plots, ) self._cpu_operation_metrics: Dict[CpuOperationMetrics, CDFSketch] = {} @@ -205,15 +174,15 @@ def __init__(self, config: SimulationConfig): for metric_name in CpuOperationMetrics: self._cpu_operation_metrics[metric_name] = CDFSketch( metric_name.value, - self._save_table_to_wandb, - self._save_plots, + self._config.save_table_to_wandb, + self._config.store_plots, ) self._cpu_operation_metrics_per_batch[metric_name] = DataSeries( BATCH_ID_STR, metric_name.value, - self._subsamples, - self._save_table_to_wandb, - self._save_plots, + self._config.subsamples, + self._config.save_table_to_wandb, + self._config.store_plots, ) # per replica metrics @@ -221,14 +190,14 @@ def __init__(self, config: SimulationConfig): # per replica stage metrics self._replica_busy_time = [] self._replica_mfu = [] - 
self._mfu_calculator = MFUCalculator(config) + self._mfu_calculator = MFUCalculator(replica_config) - for replica_idx in range(self._num_replicas): + for replica_idx in range(self._config.num_replicas): self._replica_memory_usage.append( SeriesAverageMeter( TIME_STR, MEMORY_USAGE_STR, - self._save_table_to_wandb, + self._config.save_table_to_wandb, ) ) self._replica_memory_usage[replica_idx].put(0, 0) @@ -236,12 +205,12 @@ def __init__(self, config: SimulationConfig): self._replica_busy_time.append([]) self._replica_mfu.append([]) - for stage_idx in range(self._num_stages): + for stage_idx in range(self._config.num_pipeline_stages): self._replica_busy_time[replica_idx].append( SeriesAverageMeter( TIME_STR, BUSY_TIME_PERCENT, - save_table_to_wandb=self._save_table_to_wandb, + save_table_to_wandb=self._config.save_table_to_wandb, ) ) self._replica_busy_time[replica_idx][stage_idx].put(0, 0) @@ -250,7 +219,7 @@ def __init__(self, config: SimulationConfig): SeriesAverageMeter( TIME_STR, UTILIZATION_STR, - save_table_to_wandb=self._save_table_to_wandb, + save_table_to_wandb=self._config.save_table_to_wandb, ) ) self._replica_mfu[replica_idx][stage_idx].put(0, 0) @@ -259,16 +228,16 @@ def __init__(self, config: SimulationConfig): def _init_wandb(self): if ( - not self._should_write_metrics - or not self._wandb_project - or not self._wandb_group + not self._config.write_metrics + or not self._config.wandb_project + or not self._config.wandb_group ): return wandb.init( - project=self._wandb_project, - group=self._wandb_group, - name=self._wandb_run_name, + project=self._config.wandb_project, + group=self._config.wandb_group, + name=self._config.wandb_run_name, config=self._config.to_dict(), ) @@ -286,7 +255,7 @@ def _save_as_csv( [dataseries._to_df() for dataseries in dataseries_list], ) merged_df.to_csv(f"{base_path}/{file_name}.csv", index=False) - if wandb.run and self._save_table_to_wandb: + if wandb.run and self._config.save_table_to_wandb: wand_table = wandb.Table(dataframe=merged_df) wandb.log({f"{file_name}_table": wand_table}, step=0) @@ -314,7 +283,7 @@ def _store_bar_plot( }, step=0, ) - if self._save_plots: + if self._config.store_plots: fig = px.bar( x=list(data.keys()), y=list(data.values()), @@ -323,7 +292,7 @@ def _store_bar_plot( fig.write_image(f"{base_path}/{plot_name}.png") def _store_operation_metrics(self, base_plot_path: str): - if not self._should_store_operation_metrics: + if not self._config.store_operation_metrics: return total_operation_runtimes: Dict[str, float] = {} @@ -350,7 +319,7 @@ def _store_operation_metrics(self, base_plot_path: str): total_operation_runtimes, ) - if not self._keep_individual_batch_metrics: + if not self._config.keep_individual_batch_metrics: return for dataseries in self._operation_metrics_per_batch.values(): @@ -388,7 +357,7 @@ def _store_operation_metrics(self, base_plot_path: str): ) def _store_request_metrics(self, base_plot_path: str): - if not self._should_store_request_metrics: + if not self._config.store_request_metrics: return all_request_metrics = list( @@ -409,7 +378,7 @@ def _store_request_metrics(self, base_plot_path: str): dataseries.plot_cdf(base_plot_path, dataseries._y_name, TIME_STR) def _store_batch_metrics(self, base_plot_path: str): - if not self._should_store_batch_metrics: + if not self._config.store_batch_metrics: return for dataseries in self._batch_metrics_time_distribution.values(): @@ -423,7 +392,7 @@ def _store_batch_metrics(self, base_plot_path: str): for dataseries in 
self._batch_metrics_count_distribution.values(): dataseries.plot_cdf(base_plot_path, dataseries._metric_name, COUNT_STR) - if not self._keep_individual_batch_metrics: + if not self._config.keep_individual_batch_metrics: return for dataseries in self._batch_metrics_time_distribution_per_batch.values(): @@ -459,13 +428,13 @@ def _store_batch_metrics(self, base_plot_path: str): ) def _store_completion_metrics(self, base_plot_path: str): - if self._should_store_request_metrics: + if self._config.store_request_metrics: for dataseries in self._request_completion_metrics_time_series.values(): dataseries.plot_step( base_plot_path, f"{dataseries._y_name}_time_series", COUNT_STR ) - if not self._should_store_token_completion_metrics: + if not self._config.store_token_completion_metrics: return for dataseries in self._token_metrics_time_distribution.values(): @@ -477,14 +446,14 @@ def _store_completion_metrics(self, base_plot_path: str): ) def _store_utilization_metrics(self, base_plot_path: str): - if not self._should_store_utilization_metrics: + if not self._config.store_utilization_metrics: return - for replica_idx in range(self._num_replicas): + for replica_idx in range(self._config.num_replicas): self._replica_memory_usage[replica_idx].print_stats( f"replica_{replica_idx + 1}_memory_usage", base_plot_path ) - for stage_idx in range(self._num_stages): + for stage_idx in range(self._config.num_pipeline_stages): self._replica_busy_time[replica_idx][stage_idx].print_stats( f"replica_{replica_idx + 1}_stage_{stage_idx + 1}_busy_time_percent", base_plot_path, @@ -507,7 +476,7 @@ def plot(self) -> None: @if_write_metrics def on_request_arrival(self, time: float, request: Request) -> None: - if not self._should_store_request_metrics: + if not self._config.store_request_metrics: return self._request_completion_metrics_time_series[ @@ -534,7 +503,7 @@ def on_request_arrival(self, time: float, request: Request) -> None: @if_write_metrics def _on_request_end(self, time: float, request: Request) -> None: - if not self._should_store_request_metrics: + if not self._config.store_request_metrics: return self._request_completion_metrics_time_series[ @@ -606,7 +575,7 @@ def _update_per_token_execution_times( # if prefill has just finished in this iteration, update the prefill completion time series if ( time == request.prefill_completed_at - and self._should_store_token_completion_metrics + and self._config.store_token_completion_metrics ): self._token_completion_metrics_time_series[ TokenCompletionMetricsTimeSeries.PREFILL_COMPLETIONS @@ -619,7 +588,7 @@ def _update_per_token_execution_times( if not request.has_started_decode: return - if not self._should_store_token_completion_metrics: + if not self._config.store_token_completion_metrics: return self._token_metrics_time_distribution[ @@ -658,21 +627,21 @@ def _push_metric( def on_batch_end( self, time: float, batch: Batch, replica_id: int, memory_usage_percent: int ) -> None: - if (self._min_batch_idx and batch.id < self._min_batch_idx) or ( - self._max_batch_idx and batch.id > self._max_batch_idx + if (self._config.min_batch_index and batch.id < self._config.min_batch_index) or ( + self._config.max_batch_index and batch.id > self._config.max_batch_index ): return for request in batch.completed_requests: self._on_request_end(time, request) - if self._should_store_utilization_metrics: + if self._config.store_utilization_metrics: self._replica_memory_usage[replica_id - 1].put(time, memory_usage_percent) for request in batch.requests: 
self._update_per_token_execution_times(time, request, batch) - if not self._should_store_batch_metrics: + if not self._config.store_batch_metrics: return self._push_metric( @@ -703,7 +672,7 @@ def on_batch_end( def on_replica_schedule( self, time: float, replica_id: int, memory_usage_percent: int ) -> None: - if not self._should_store_utilization_metrics: + if not self._config.store_utilization_metrics: return self._replica_memory_usage[replica_id - 1].put(time, memory_usage_percent) @@ -717,14 +686,14 @@ def on_replica_stage_schedule( batch_stage: BatchStage, execution_time: ExecutionTime, ) -> None: - if not self._should_store_utilization_metrics: + if not self._config.store_utilization_metrics: return self._replica_busy_time[replica_id - 1][stage_id - 1].put(time, 100) mfu = self._mfu_calculator.get_mfu(batch_stage) self._replica_mfu[replica_id - 1][stage_id - 1].put(time, mfu) - if not self._should_store_operation_metrics: + if not self._config.store_operation_metrics: return batch_id = batch_stage._batch_id @@ -836,7 +805,7 @@ def on_replica_stage_schedule( def on_batch_stage_end( self, batch_stage: BatchStage, time: float, replica_id: int, stage_id: int ) -> None: - if not self._should_store_utilization_metrics: + if not self._config.store_utilization_metrics: return self._replica_busy_time[replica_id - 1][stage_id - 1].put(time, 0) self._replica_mfu[replica_id - 1][stage_id - 1].put(time, 0) diff --git a/vidur/scheduler/global_scheduler/base_global_scheduler.py b/vidur/scheduler/global_scheduler/base_global_scheduler.py index 6e35e21..0f378ea 100644 --- a/vidur/scheduler/global_scheduler/base_global_scheduler.py +++ b/vidur/scheduler/global_scheduler/base_global_scheduler.py @@ -17,13 +17,16 @@ def __init__(self, config: SimulationConfig, replicas: Dict[int, Replica]): self._num_replicas = len(self._replicas) execution_time_predictor = ExecutionTimePredictorRegistry.get( - self._config.cluster_config.execution_time_predictor_config.get_type(), - self._config, + self._config.execution_time_predictor_config.get_type(), + self._config.execution_time_predictor_config, + self._config.cluster_config.replica_config.model_config, ) self._replica_schedulers = { replica_id: ReplicaSchedulerRegistry.get( config.cluster_config.replica_scheduler_config.get_type(), - config, + config.cluster_config.replica_config, + config.cluster_config.replica_scheduler_config, + config.request_generator_config, replica, replica.num_pipeline_stages, execution_time_predictor, diff --git a/vidur/scheduler/replica_scheduler/base_replica_scheduler.py b/vidur/scheduler/replica_scheduler/base_replica_scheduler.py index c1b40fd..642a856 100644 --- a/vidur/scheduler/replica_scheduler/base_replica_scheduler.py +++ b/vidur/scheduler/replica_scheduler/base_replica_scheduler.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from typing import List -from vidur.config import SimulationConfig +from vidur.config import BaseRequestGeneratorConfig, ReplicaConfig, BaseReplicaSchedulerConfig from vidur.entities import Batch, Replica, Request from vidur.execution_time_predictor import BaseExecutionTimePredictor from vidur.logger import init_logger @@ -14,33 +14,32 @@ class BaseReplicaScheduler(ABC): def __init__( self, - config: SimulationConfig, + replica_config: ReplicaConfig, + replica_scheduler_config: BaseReplicaSchedulerConfig, + request_generator_config: BaseRequestGeneratorConfig, replica: Replica, num_stages: int, execution_time_predictor: BaseExecutionTimePredictor, ) -> None: - self._config = config + 
self._replica_config = replica_config + self._request_generator_config = request_generator_config + self._replica_scheduler_config = replica_scheduler_config self._replica_id = replica.id self._num_stages = num_stages - # store config variables - self._block_size = self._config.cluster_config.replica_scheduler_config.block_size - self._max_blocks_per_sequence = ( - self._config.request_generator_config.max_tokens // self._block_size + self._request_generator_config.max_tokens // self._replica_scheduler_config.block_size ) - memory_planner = MemoryPlanner(config, replica) - - self._num_total_blocks = config.cluster_config.replica_scheduler_config.num_blocks + memory_planner = MemoryPlanner(self._replica_config, replica) - if not self._num_total_blocks: - self._num_total_blocks = ( + if not self._replica_scheduler_config.num_blocks: + self._replica_scheduler_config.num_blocks = ( self._max_blocks_per_sequence * memory_planner.get_max_request_slots() ) self._max_batch_size = min( memory_planner.get_max_batch_size(), - config.cluster_config.replica_scheduler_config.batch_size_cap, + self._replica_scheduler_config.batch_size_cap, ) logger.debug( @@ -75,7 +74,7 @@ def num_allocated_blocks(self) -> int: @property def memory_usage_percent(self) -> int: - return (self._num_allocated_blocks * 100) / self._num_total_blocks + return (self._num_allocated_blocks * 100) / self._replica_scheduler_config.num_blocks def is_empty(self) -> bool: return ( @@ -102,7 +101,7 @@ def get_replica_stage_scheduler(self, stage_id: int): return self._replica_stage_schedulers[stage_id] def can_allocate(self, num_blocks: int) -> bool: - return self._num_total_blocks - self._num_allocated_blocks >= num_blocks + return self._replica_scheduler_config.num_blocks - self._num_allocated_blocks >= num_blocks def allocate(self, request_id: int, num_blocks: int) -> None: self._num_allocated_blocks += num_blocks @@ -111,7 +110,7 @@ def allocate(self, request_id: int, num_blocks: int) -> None: else: self._allocation_map[request_id] += num_blocks - assert self._num_allocated_blocks <= self._num_total_blocks + assert self._num_allocated_blocks <= self._replica_scheduler_config.num_blocks def free(self, *request_ids: List[int]) -> None: for request_id in request_ids: diff --git a/vidur/scheduler/replica_scheduler/lightllm_replica_scheduler.py b/vidur/scheduler/replica_scheduler/lightllm_replica_scheduler.py index 87db3b8..3a6bfae 100644 --- a/vidur/scheduler/replica_scheduler/lightllm_replica_scheduler.py +++ b/vidur/scheduler/replica_scheduler/lightllm_replica_scheduler.py @@ -14,14 +14,11 @@ def __init__(self, *args, **kwargs): self._preempted_requests: List[Request] = [] self._num_running_batches = 0 - self._max_tokens_in_batch = self._config.cluster_config.replica_scheduler_config.max_tokens_in_batch - self._max_waiting_iters = self._config.cluster_config.replica_scheduler_config.max_waiting_iters - self._max_batch_size = self._config.cluster_config.replica_scheduler_config.batch_size_cap self._max_micro_batch_size = ( - self._max_batch_size // self._num_stages + self._replica_scheduler_config.batch_size_cap // self._num_stages ) assert ( - self._block_size == 1 + self._replica_scheduler_config.block_size == 1 ), "LightLLM scheduler only supports block size of 1." 
assert ( self._num_stages == 1 @@ -66,7 +63,7 @@ def _can_allocate_request(self, request: Request) -> bool: need_max_token_num = (left_out_len_array * size_array + cum_run_len_array).max() - return need_max_token_num < self._num_total_blocks + return need_max_token_num < self._replica_scheduler_config.num_blocks def _allocate_request(self, request: Request) -> None: if request.id not in self._allocation_map: @@ -89,10 +86,10 @@ def _get_prefill_batch(self) -> Batch: next_num_tokens = self._get_request_next_num_tokens(request) - if num_batch_tokens + next_num_tokens > self._max_tokens_in_batch: + if num_batch_tokens + next_num_tokens > self._replica_scheduler_config.max_tokens_in_batch: break - if len(self._allocation_map) == self._max_batch_size: + if len(self._allocation_map) == self._replica_scheduler_config.batch_size_cap: break if len(requests) == self._max_micro_batch_size: @@ -145,7 +142,7 @@ def _get_next_batch(self) -> Batch: self._num_waiting_iters = 0 return batch - if self._num_waiting_iters >= self._max_waiting_iters: + if self._num_waiting_iters >= self._replica_scheduler_config.max_waiting_iters: self._num_waiting_iters = 0 batch = self._get_prefill_batch() if batch: diff --git a/vidur/scheduler/replica_scheduler/sarathi_replica_scheduler.py b/vidur/scheduler/replica_scheduler/sarathi_replica_scheduler.py index 274bdb5..2bedc89 100644 --- a/vidur/scheduler/replica_scheduler/sarathi_replica_scheduler.py +++ b/vidur/scheduler/replica_scheduler/sarathi_replica_scheduler.py @@ -13,43 +13,37 @@ def __init__(self, *args, **kwargs): # sarathi config self._num_running_batches = 0 self._preempted_requests = [] - self._chunk_size = self._config.cluster_config.replica_scheduler_config.chunk_size - # vLLM config - self._watermark_blocks_fraction = ( - self._config.cluster_config.replica_scheduler_config.watermark_blocks_fraction - ) # For vLLM and its derivatives, we only need to set a loose max batch size # Memory requirements are handled explicitly by the scheduler - self._max_batch_size = self._config.cluster_config.replica_scheduler_config.batch_size_cap self._max_micro_batch_size = ( - self._max_batch_size // self._num_stages + self._replica_scheduler_config.batch_size_cap // self._num_stages ) self._watermark_blocks = int( - self._watermark_blocks_fraction * self._num_total_blocks + self._replica_scheduler_config.watermark_blocks_fraction * self._replica_scheduler_config.num_blocks ) def _can_allocate_request(self, request: Request) -> bool: if request.id not in self._allocation_map: # new request - num_required_blocks = ceil(request.num_prefill_tokens / self._block_size) + num_required_blocks = ceil(request.num_prefill_tokens / self._replica_scheduler_config.block_size) return ( - self._num_total_blocks + self._replica_scheduler_config.num_blocks - self._num_allocated_blocks - num_required_blocks >= self._watermark_blocks ) # vllm requires at least one block to be available - return self._num_total_blocks - self._num_allocated_blocks >= 1 + return self._replica_scheduler_config.num_blocks - self._num_allocated_blocks >= 1 def _allocate_request(self, request: Request) -> None: if request.id not in self._allocation_map: # new request - num_required_blocks = ceil(request.num_prefill_tokens / self._block_size) + num_required_blocks = ceil(request.num_prefill_tokens / self._replica_scheduler_config.block_size) self.allocate(request.id, num_required_blocks) return - num_tokens_reserved = self._allocation_map[request.id] * self._block_size + num_tokens_reserved = 
self._allocation_map[request.id] * self._replica_scheduler_config.block_size num_tokens_required = max(0, request.num_processed_tokens - num_tokens_reserved) assert ( @@ -80,7 +74,7 @@ def _get_request_next_num_tokens( next_num_tokens = min( request.num_prefill_tokens - request.num_processed_tokens, - self._chunk_size - num_batch_tokens, + self._replica_scheduler_config.chunk_size - num_batch_tokens, ) next_num_tokens = max(0, next_num_tokens) @@ -158,7 +152,7 @@ def _get_next_batch(self) -> Batch: skipped_requests = [] while self._request_queue: - if len(self._allocation_map) == self._max_batch_size: + if len(self._allocation_map) == self._replica_scheduler_config.batch_size_cap: break if len(requests) == self._max_micro_batch_size: diff --git a/vidur/scheduler/replica_scheduler/vllm_replica_scheduler.py b/vidur/scheduler/replica_scheduler/vllm_replica_scheduler.py index 5b6f6cc..4063c57 100644 --- a/vidur/scheduler/replica_scheduler/vllm_replica_scheduler.py +++ b/vidur/scheduler/replica_scheduler/vllm_replica_scheduler.py @@ -13,18 +13,13 @@ def __init__(self, *args, **kwargs): self._preempted_requests: List[Request] = [] self._num_running_batches = 0 - self._watermark_blocks_fraction = ( - self._config.cluster_config.replica_scheduler_config.watermark_blocks_fraction - ) - self._max_tokens_in_batch = self._config.cluster_config.replica_scheduler_config.max_tokens_in_batch # For vLLM and its derivatives, we only need to set a loose max batch size # Memory requirements are handled explicitly by the scheduler - self._max_batch_size = self._config.cluster_config.replica_scheduler_config.batch_size_cap self._max_micro_batch_size = ( - self._max_batch_size // self._num_stages + self._replica_scheduler_config.batch_size_cap // self._num_stages ) self._watermark_blocks = int( - self._watermark_blocks_fraction * self._num_total_blocks + self._replica_scheduler_config.watermark_blocks_fraction * self._replica_scheduler_config.num_blocks ) def on_batch_end(self, batch: Batch) -> None: @@ -39,25 +34,25 @@ def on_batch_end(self, batch: Batch) -> None: def _can_allocate_request(self, request: Request) -> bool: if request.id not in self._allocation_map: # new request - num_required_blocks = ceil((request.num_prefill_tokens) / self._block_size) + num_required_blocks = ceil((request.num_prefill_tokens) / self._replica_scheduler_config.block_size) return ( - self._num_total_blocks + self._replica_scheduler_config.num_blocks - self._num_allocated_blocks - num_required_blocks >= self._watermark_blocks ) # vllm requires at least one block to be available - return self._num_total_blocks - self._num_allocated_blocks >= 1 + return self._replica_scheduler_config.num_blocks - self._num_allocated_blocks >= 1 def _allocate_request(self, request: Request) -> None: if request.id not in self._allocation_map: # new request - num_required_blocks = ceil((request.num_prefill_tokens) / self._block_size) + num_required_blocks = ceil((request.num_prefill_tokens) / self._replica_scheduler_config.block_size) self.allocate(request.id, num_required_blocks) return - num_tokens_reserved = self._allocation_map[request.id] * self._block_size + num_tokens_reserved = self._allocation_map[request.id] * self._replica_scheduler_config.block_size num_tokens_required = max(0, request.num_processed_tokens - num_tokens_reserved) assert ( num_tokens_required == 0 or num_tokens_required == 1 @@ -83,10 +78,10 @@ def _get_next_batch(self) -> Batch: new_num_tokens = num_tokens + [next_num_tokens] new_num_batch_tokens = len(new_num_tokens) * 
max(new_num_tokens) - if new_num_batch_tokens > self._max_tokens_in_batch: + if new_num_batch_tokens > self._replica_scheduler_config.max_tokens_in_batch: break - if len(self._allocation_map) == self._max_batch_size: + if len(self._allocation_map) == self._replica_scheduler_config.batch_size_cap: break if len(requests) == self._max_micro_batch_size: diff --git a/vidur/scheduler/utils/memory_planner.py b/vidur/scheduler/utils/memory_planner.py index ae79b21..e769e7b 100644 --- a/vidur/scheduler/utils/memory_planner.py +++ b/vidur/scheduler/utils/memory_planner.py @@ -1,11 +1,11 @@ -from vidur.config import SimulationConfig +from vidur.config import ReplicaConfig from vidur.entities.replica import Replica from vidur.utils.param_counter import ParamCounter class MemoryPlanner: - def __init__(self, config: SimulationConfig, replica: Replica) -> None: - self._param_counter = ParamCounter(config) + def __init__(self, replica_config: ReplicaConfig, replica: Replica) -> None: + self._param_counter = ParamCounter(replica_config) self._replica = replica def _get_kv_cache_memory_per_layer_per_request(self) -> int: diff --git a/vidur/simulator.py b/vidur/simulator.py index a1f718c..e0e2725 100644 --- a/vidur/simulator.py +++ b/vidur/simulator.py @@ -26,21 +26,19 @@ def __init__(self, config: SimulationConfig) -> None: self._event_queue = [] - self._should_write_json_trace = self._config.cluster_config.metrics_config.write_json_trace - self._should_write_chrome_trace = self._config.cluster_config.metrics_config.enable_chrome_trace - self._event_trace = [] self._event_chrome_trace = [] - self._cluster = Cluster(self._config) - self._metric_store = MetricsStore(self._config) + self._cluster = Cluster(self._config.cluster_config) + self._metric_store = MetricsStore(self._config.metrics_config, self._config.cluster_config.replica_config) self._request_generator = RequestGeneratorRegistry.get( self._config.request_generator_config.get_type(), self._config.request_generator_config, ) self._scheduler = GlobalSchedulerRegistry.get( self._config.cluster_config.global_scheduler_config.get_type(), - self._config, self._cluster.replicas + self._config, + self._cluster.replicas ) self._init_event_queue() @@ -65,10 +63,10 @@ def run(self) -> None: new_events = event.handle_event(self._scheduler, self._metric_store) self._add_events(new_events) - if self._should_write_json_trace: + if self._config.metrics_config.write_json_trace: self._event_trace.append(event.to_dict()) - if self._should_write_chrome_trace: + if self._config.metrics_config.enable_chrome_trace: chrome_trace = event.to_chrome_trace() if chrome_trace: self._event_chrome_trace.append(chrome_trace) @@ -83,12 +81,12 @@ def _write_output(self) -> None: self._metric_store.plot() logger.info("Metrics written") - if self._should_write_json_trace: + if self._config.metrics_config.write_json_trace: self._write_event_trace() self._scheduler.write_batching_history() logger.info("Json event trace written") - if self._should_write_chrome_trace: + if self._config.metrics_config.enable_chrome_trace: self._write_chrome_trace() logger.info("Chrome event trace written") diff --git a/vidur/types/norm_type.py b/vidur/types/norm_type.py index b3191e4..b5a783d 100644 --- a/vidur/types/norm_type.py +++ b/vidur/types/norm_type.py @@ -4,4 +4,3 @@ class NormType(BaseIntEnum): LAYER_NORM = 0 RMS_NORM = 1 - diff --git a/vidur/utils/mfu_calculator.py b/vidur/utils/mfu_calculator.py index ea0c511..fecab53 100644 --- a/vidur/utils/mfu_calculator.py +++ 
b/vidur/utils/mfu_calculator.py @@ -1,24 +1,23 @@ -from vidur.config import SimulationConfig +from vidur.config import ReplicaConfig from vidur.entities import BatchStage from vidur.utils.param_counter import ParamCounter class MFUCalculator: - def __init__(self, config: SimulationConfig): - param_counter = ParamCounter(config) + + def __init__(self, replica_config: ReplicaConfig): + param_counter = ParamCounter(replica_config) self._num_params_per_device = param_counter.get_num_parameters_per_device() - replica_config = config.cluster_config.replica_config model_config = replica_config.model_config self._num_layers_per_device = ( model_config.num_layers // replica_config.num_pipeline_stages ) - self._embedding_dim = model_config.embedding_dim self._num_heads_per_device = ( model_config.num_q_heads // replica_config.tensor_parallel_size ) - self._head_dimension = self._embedding_dim // model_config.num_q_heads + self._head_dimension = model_config.embedding_dim // model_config.num_q_heads self._device_flops = replica_config.device_config.fp16_tflops * 2**40 def _get_mlp_flops(self, batch_stage: BatchStage) -> float: diff --git a/vidur/utils/param_counter.py b/vidur/utils/param_counter.py index fd7552c..fbc961a 100644 --- a/vidur/utils/param_counter.py +++ b/vidur/utils/param_counter.py @@ -1,44 +1,34 @@ from math import ceil -from vidur.config import SimulationConfig +from vidur.config import ReplicaConfig class ParamCounter: - def __init__(self, config: SimulationConfig) -> None: - replica_config = config.cluster_config.replica_config - model_config = replica_config.model_config + def __init__(self, replica_config: ReplicaConfig) -> None: + self._replica_config = replica_config + self._model_config = self._replica_config.model_config - self._embedding_dim = model_config.embedding_dim - self._num_pipeline_stages = replica_config.num_pipeline_stages - self._num_tensor_parallel_workers = replica_config.tensor_parallel_size - self._num_layers = model_config.num_layers - self._num_q_heads = model_config.num_q_heads - self._num_kv_heads = model_config.num_kv_heads - self._mlp_hidden_dim = model_config.mlp_hidden_dim - self._use_gated_mlp = model_config.use_gated_mlp - self._vocab_size = model_config.vocab_size - - assert self._num_q_heads % self._num_tensor_parallel_workers == 0 - assert self._num_layers % self._num_pipeline_stages == 0 - assert self._embedding_dim % self._num_tensor_parallel_workers == 0 - assert self._embedding_dim % self._num_q_heads == 0 + assert self._model_config.num_q_heads % self._replica_config.tensor_parallel_size == 0 + assert self._model_config.num_layers % self._replica_config.num_pipeline_stages == 0 + assert self._model_config.embedding_dim % self._replica_config.tensor_parallel_size == 0 + assert self._model_config.embedding_dim % self._model_config.num_q_heads == 0 self._num_layers_per_pipeline_stage = ( - self._num_layers // self._num_pipeline_stages + self._model_config.num_layers // self._replica_config.num_pipeline_stages ) - self._attention_head_dim = self._embedding_dim // self._num_q_heads + self._attention_head_dim = self._model_config.embedding_dim // self._model_config.num_q_heads self._q_heads_per_tensor_parallel_worker = ( - self._num_q_heads // self._num_tensor_parallel_workers + self._model_config.num_q_heads // self._replica_config.tensor_parallel_size ) self._kv_heads_per_tensor_parallel_worker = ceil( - self._num_kv_heads / self._num_tensor_parallel_workers + self._model_config.num_kv_heads / self._replica_config.tensor_parallel_size ) def 
get_num_parameters_per_layer(self) -> int: num_parameters = 0 # weights for attention metrics Wq, Wk, Wv num_parameters += ( - self._embedding_dim + self._model_config.embedding_dim * self._attention_head_dim * ( self._q_heads_per_tensor_parallel_worker @@ -47,24 +37,24 @@ def get_num_parameters_per_layer(self) -> int: ) # weights for attention metrics Wo num_parameters += ( - self._embedding_dim + self._model_config.embedding_dim * self._attention_head_dim * self._q_heads_per_tensor_parallel_worker ) # fc layer weights - if self._use_gated_mlp: + if self._model_config.use_gated_mlp: num_parameters += ( 3 - * self._embedding_dim - * self._mlp_hidden_dim - // self._num_tensor_parallel_workers + * self._model_config.embedding_dim + * self._model_config.mlp_hidden_dim + // self._replica_config.tensor_parallel_size ) else: num_parameters += ( 2 - * self._embedding_dim - * self._mlp_hidden_dim - // self._num_tensor_parallel_workers + * self._model_config.embedding_dim + * self._model_config.mlp_hidden_dim + // self._replica_config.tensor_parallel_size ) return num_parameters From 232296764845bc491cd19c2f4de3ef01a2f1531f Mon Sep 17 00:00:00 2001 From: Anmol Agarwal Date: Sat, 27 Jul 2024 17:10:10 -0400 Subject: [PATCH 13/24] Removing post_init updates to configs --- vidur/config/config.py | 74 +++------------- vidur/entities/cluster.py | 13 +-- vidur/entities/replica.py | 7 +- .../base_execution_time_predictor.py | 31 +++++-- ...ear_regression_execution_time_predictor.py | 21 ++++- ...random_forrest_execution_time_predictor.py | 21 ++++- .../sklearn_execution_time_predictor.py | 84 +++++++++++-------- vidur/metrics/metrics_store.py | 19 +++-- .../global_scheduler/base_global_scheduler.py | 18 ++-- vidur/simulator.py | 8 +- 10 files changed, 157 insertions(+), 139 deletions(-) diff --git a/vidur/config/config.py b/vidur/config/config.py index 6936bc0..5d16399 100644 --- a/vidur/config/config.py +++ b/vidur/config/config.py @@ -424,11 +424,20 @@ class MetricsConfig: default=None, metadata={"help": "Maximum batch index."}, ) + output_dir: str = field( + default="simulator_output", + metadata={"help": "Output directory."}, + ) + cache_dir: str = field( + default="cache", + metadata={"help": "Cache directory."}, + ) def __post_init__(self): - self.output_dir = None - self.num_replicas = None - self.num_pipeline_stages = None + self.output_dir = ( + f"{self.output_dir}/{datetime.now().strftime('%Y-%m-%d_%H-%M-%S-%f')}" + ) + os.makedirs(self.output_dir, exist_ok=True) @dataclass @@ -468,8 +477,6 @@ def __post_init__(self): self.device_config: BaseDeviceSKUConfig = BaseDeviceSKUConfig.create_from_type_string(self.device) self.node_config: BaseNodeSKUConfig = BaseNodeSKUConfig.create_from_type_string(self.network_device) - self.max_tokens = None - @dataclass class BaseGlobalSchedulerConfig(BasePolyConfig): @@ -568,18 +575,6 @@ class BaseExecutionTimePredictorConfig(BasePolyConfig): metadata={"help": "Whether to skip CPU overhead modeling."}, ) - def __post_init__(self): - self.num_tensor_parallel_workers = None - self.num_pipeline_stages = None - self.num_layers_per_pipeline_stage = None - self.replica_scheduler_provider = None - self.cache_dir = None - self.block_size = None - self.total_memory_gb = None - self.devices_per_node = None - self.device = None - self.network_device = None - @dataclass class LinearRegressionExecutionTimePredictorConfig(BaseExecutionTimePredictorConfig): @@ -641,10 +636,6 @@ class ClusterConfig: metadata={"help": "Replica scheduler config."}, ) - def __post_init__(self): - 
self.output_dir = None - self.write_json_trace = None - @dataclass class SimulationConfig(ABC): @@ -656,14 +647,6 @@ class SimulationConfig(ABC): default="info", metadata={"help": "Logging level."}, ) - output_dir: str = field( - default="simulator_output", - metadata={"help": "Output directory."}, - ) - cache_dir: str = field( - default="cache", - metadata={"help": "Cache directory."}, - ) time_limit: int = field( default=0, # in seconds, 0 is no limit metadata={"help": "Time limit for simulation in seconds. 0 means no limit."}, @@ -686,17 +669,8 @@ class SimulationConfig(ABC): ) def __post_init__(self): - self.output_dir = ( - f"{self.output_dir}/{datetime.now().strftime('%Y-%m-%d_%H-%M-%S-%f')}" - ) - os.makedirs(self.output_dir, exist_ok=True) self.write_config_to_file() - # Update the config - self.update_cluster_config() - self.update_metrics_config() - self.update_predictor_config() - @classmethod def create_from_cli_args(cls): flat_config = create_flat_dataclass(cls).create_from_cli_args() @@ -713,27 +687,5 @@ def to_dict(self): def write_config_to_file(self): config_dict = dataclass_to_dict(self) - with open(f"{self.output_dir}/config.json", "w") as f: + with open(f"{self.metrics_config.output_dir}/config.json", "w") as f: json.dump(config_dict, f, indent=4) - - def update_cluster_config(self): - self.cluster_config.output_dir = self.output_dir - self.cluster_config.replica_config.max_tokens = self.request_generator_config.max_tokens - self.cluster_config.write_json_trace = self.metrics_config.write_json_trace - - def update_metrics_config(self): - self.metrics_config.output_dir = self.output_dir - self.metrics_config.num_replicas = self.cluster_config.num_replicas - self.metrics_config.num_pipeline_stages = self.cluster_config.replica_config.num_pipeline_stages - - def update_predictor_config(self): - self.execution_time_predictor_config.num_tensor_parallel_workers = self.cluster_config.replica_config.tensor_parallel_size - self.execution_time_predictor_config.num_pipeline_stages = self.cluster_config.replica_config.num_pipeline_stages - self.execution_time_predictor_config.num_layers_per_pipeline_stage = self.cluster_config.replica_config.model_config.num_layers // self.cluster_config.replica_config.num_pipeline_stages - self.execution_time_predictor_config.replica_scheduler_provider = str(self.cluster_config.replica_scheduler_config.get_type()) - self.execution_time_predictor_config.cache_dir = f"{self.cache_dir}/execution_time_predictor" - self.execution_time_predictor_config.block_size = self.cluster_config.replica_scheduler_config.block_size - self.execution_time_predictor_config.total_memory_gb = self.cluster_config.replica_config.device_config.total_memory_gb - self.execution_time_predictor_config.devices_per_node = self.cluster_config.replica_config.node_config.num_devices_per_node - self.execution_time_predictor_config.device = self.cluster_config.replica_config.device - self.execution_time_predictor_config.network_device = self.cluster_config.replica_config.network_device diff --git a/vidur/entities/cluster.py b/vidur/entities/cluster.py index 9d4f836..fa23171 100644 --- a/vidur/entities/cluster.py +++ b/vidur/entities/cluster.py @@ -1,6 +1,6 @@ import json -from vidur.config import ClusterConfig +from vidur.config import ClusterConfig, MetricsConfig, BaseRequestGeneratorConfig from vidur.entities.base_entity import BaseEntity from vidur.entities.replica import Replica from vidur.logger import init_logger @@ -9,18 +9,21 @@ class Cluster(BaseEntity): - def __init__(self, 
cluster_config: ClusterConfig) -> None: + def __init__(self, cluster_config: ClusterConfig, metrics_config: MetricsConfig, generator_config: BaseRequestGeneratorConfig) -> None: self._id = Cluster.generate_id() self._config = cluster_config + # get metrics config + self._output_dir = metrics_config.output_dir + # Init replica object handles self._replicas = {} for _ in range(self._config.num_replicas): - replica = Replica(self._config.replica_config) + replica = Replica(self._config.replica_config, generator_config) self._replicas[replica.id] = replica - if self._config.write_json_trace: + if metrics_config.write_json_trace: self._write_cluster_info_to_file() @property @@ -37,6 +40,6 @@ def _write_cluster_info_to_file(self) -> None: replica_dicts = [replica.to_dict() for replica in self._replicas.values()] cluster_info = {"replicas": replica_dicts} - cluster_file = f"{self._config.output_dir}/cluster.json" + cluster_file = f"{self._output_dir}/cluster.json" with open(cluster_file, "w") as f: json.dump(cluster_info, f) diff --git a/vidur/entities/replica.py b/vidur/entities/replica.py index 66b224a..7b8425f 100644 --- a/vidur/entities/replica.py +++ b/vidur/entities/replica.py @@ -1,6 +1,6 @@ from math import ceil -from vidur.config import ReplicaConfig +from vidur.config import ReplicaConfig, BaseRequestGeneratorConfig from vidur.entities.base_entity import BaseEntity from vidur.logger import init_logger @@ -8,12 +8,13 @@ class Replica(BaseEntity): - def __init__(self, replica_config: ReplicaConfig) -> None: + def __init__(self, replica_config: ReplicaConfig, generator_config: BaseRequestGeneratorConfig) -> None: self._id = Replica.generate_id() self._replica_config = replica_config self._model_config = replica_config.model_config self._device_config = replica_config.device_config + self._generator_config = generator_config assert self._model_config.num_layers % self._replica_config.num_pipeline_stages == 0 assert ( @@ -86,7 +87,7 @@ def memory_margin_fraction(self) -> float: @property def max_request_tokens(self) -> int: - return self._replica_config.max_tokens + return self._generator_config.max_tokens @property def per_device_flops(self) -> float: diff --git a/vidur/execution_time_predictor/base_execution_time_predictor.py b/vidur/execution_time_predictor/base_execution_time_predictor.py index 7f1744e..0aafba3 100644 --- a/vidur/execution_time_predictor/base_execution_time_predictor.py +++ b/vidur/execution_time_predictor/base_execution_time_predictor.py @@ -1,24 +1,39 @@ from abc import ABC, abstractmethod -from vidur.config import BaseExecutionTimePredictorConfig -from vidur.config.model_config import BaseModelConfig +from vidur.config import ( + BaseExecutionTimePredictorConfig, + ReplicaConfig, + BaseReplicaSchedulerConfig, + MetricsConfig +) from vidur.entities import Batch, ExecutionTime class BaseExecutionTimePredictor(ABC): - def __init__(self, config: BaseExecutionTimePredictorConfig, model_config: BaseModelConfig) -> None: - self._config = config - self._model_config = model_config + def __init__(self, + predictor_config: BaseExecutionTimePredictorConfig, + replica_config: ReplicaConfig, + replica_scheduler_config: BaseReplicaSchedulerConfig, + metrics_config: MetricsConfig) -> None: + self._config = predictor_config + self._replica_config = replica_config + self._model_config = replica_config.model_config + + # get configs + self._replica_scheduler_provider = str(replica_scheduler_config.get_type()) + self._block_size = replica_scheduler_config.block_size + self._cache_dir = 
metrics_config.cache_dir + self._num_layers_per_pipeline_stage = self._model_config.num_layers // self._replica_config.num_pipeline_stages def get_execution_time(self, batch: Batch, pipeline_stage: int) -> ExecutionTime: - if pipeline_stage == self._config.num_pipeline_stages - 1: + if pipeline_stage == self._replica_config.num_pipeline_stages - 1: pipeline_parallel_communication_time = 0 else: pipeline_parallel_communication_time = ( self._get_pipeline_parallel_communication_time(batch) ) - if self._config.num_tensor_parallel_workers == 1: + if self._replica_config.tensor_parallel_size == 1: tensor_parallel_communication_time = 0 else: tensor_parallel_communication_time = ( @@ -26,7 +41,7 @@ def get_execution_time(self, batch: Batch, pipeline_stage: int) -> ExecutionTime ) return ExecutionTime( - self._config.num_layers_per_pipeline_stage, + self._num_layers_per_pipeline_stage, self._get_attention_rope_execution_time(batch), self._get_attention_kv_cache_save_execution_time(batch), self._get_attention_decode_execution_time(batch), diff --git a/vidur/execution_time_predictor/linear_regression_execution_time_predictor.py b/vidur/execution_time_predictor/linear_regression_execution_time_predictor.py index 18cc719..8f9f261 100644 --- a/vidur/execution_time_predictor/linear_regression_execution_time_predictor.py +++ b/vidur/execution_time_predictor/linear_regression_execution_time_predictor.py @@ -2,17 +2,30 @@ from sklearn.pipeline import make_pipeline from sklearn.preprocessing import PolynomialFeatures -from vidur.config import LinearRegressionExecutionTimePredictorConfig -from vidur.config.model_config import BaseModelConfig +from vidur.config import ( + LinearRegressionExecutionTimePredictorConfig, + ReplicaConfig, + BaseReplicaSchedulerConfig, + MetricsConfig +) from vidur.execution_time_predictor.sklearn_execution_time_predictor import ( SklearnExecutionTimePredictor, ) class LinearRegressionExecutionTimePredictor(SklearnExecutionTimePredictor): - def __init__(self, config: LinearRegressionExecutionTimePredictorConfig, model_config: BaseModelConfig): + def __init__(self, + predictor_config: LinearRegressionExecutionTimePredictorConfig, + replica_config: ReplicaConfig, + replica_scheduler_config: BaseReplicaSchedulerConfig, + metrics_config: MetricsConfig) -> None: # will trigger model training - super().__init__(config, model_config) + super().__init__( + predictor_config=predictor_config, + replica_config=replica_config, + replica_scheduler_config=replica_scheduler_config, + metrics_config=metrics_config + ) def _get_grid_search_params(self): return { diff --git a/vidur/execution_time_predictor/random_forrest_execution_time_predictor.py b/vidur/execution_time_predictor/random_forrest_execution_time_predictor.py index 1e7d664..7e8f232 100644 --- a/vidur/execution_time_predictor/random_forrest_execution_time_predictor.py +++ b/vidur/execution_time_predictor/random_forrest_execution_time_predictor.py @@ -1,16 +1,29 @@ from sklearn.ensemble import RandomForestRegressor -from vidur.config import RandomForrestExecutionTimePredictorConfig -from vidur.config.model_config import BaseModelConfig +from vidur.config import ( + RandomForrestExecutionTimePredictorConfig, + ReplicaConfig, + BaseReplicaSchedulerConfig, + MetricsConfig +) from vidur.execution_time_predictor.sklearn_execution_time_predictor import ( SklearnExecutionTimePredictor, ) class RandomForrestExecutionTimePredictor(SklearnExecutionTimePredictor): - def __init__(self, config: RandomForrestExecutionTimePredictorConfig, 
model_config: BaseModelConfig): + def __init__(self, + predictor_config: RandomForrestExecutionTimePredictorConfig, + replica_config: ReplicaConfig, + replica_scheduler_config: BaseReplicaSchedulerConfig, + metrics_config: MetricsConfig) -> None: # will trigger model training - super().__init__(config, model_config) + super().__init__( + predictor_config=predictor_config, + replica_config=replica_config, + replica_scheduler_config=replica_scheduler_config, + metrics_config=metrics_config + ) def _get_grid_search_params(self): return { diff --git a/vidur/execution_time_predictor/sklearn_execution_time_predictor.py b/vidur/execution_time_predictor/sklearn_execution_time_predictor.py index b358140..f936288 100644 --- a/vidur/execution_time_predictor/sklearn_execution_time_predictor.py +++ b/vidur/execution_time_predictor/sklearn_execution_time_predictor.py @@ -12,8 +12,12 @@ from sklearn.metrics import make_scorer from sklearn.model_selection import GridSearchCV -from vidur.config import BaseExecutionTimePredictorConfig -from vidur.config.model_config import BaseModelConfig +from vidur.config import ( + BaseExecutionTimePredictorConfig, + ReplicaConfig, + BaseReplicaSchedulerConfig, + MetricsConfig +) from vidur.entities import Batch from vidur.execution_time_predictor.base_execution_time_predictor import ( BaseExecutionTimePredictor, @@ -24,9 +28,18 @@ class SklearnExecutionTimePredictor(BaseExecutionTimePredictor): - def __init__(self, config: BaseExecutionTimePredictorConfig, model_config: BaseModelConfig) -> None: - super().__init__(config, model_config) - os.makedirs(self._config.cache_dir, exist_ok=True) + def __init__(self, + predictor_config: BaseExecutionTimePredictorConfig, + replica_config: ReplicaConfig, + replica_scheduler_config: BaseReplicaSchedulerConfig, + metrics_config: MetricsConfig) -> None: + super().__init__( + predictor_config=predictor_config, + replica_config=replica_config, + replica_scheduler_config=replica_scheduler_config, + metrics_config=metrics_config + ) + os.makedirs(self._cache_dir, exist_ok=True) # These overheads are only for GQA models self._attention_prefill_batching_overhead_fraction = ( @@ -43,17 +56,18 @@ def __init__(self, config: BaseExecutionTimePredictorConfig, model_config: BaseM if self._model_config.num_q_heads > self._model_config.num_kv_heads else 0 ) - if self._config.replica_scheduler_provider == "orca": + if self._replica_scheduler_provider == "orca": self._max_tokens = self._config.prediction_max_tokens_per_request * self._config.prediction_max_batch_size else: self._max_tokens = self._config.prediction_max_tokens_per_request - num_workers = self._config.num_pipeline_stages * self._config.num_tensor_parallel_workers + num_workers = self._replica_config.num_pipeline_stages * self._replica_config.tensor_parallel_size + devices_per_node = self._replica_config.node_config.num_devices_per_node assert ( - num_workers < self._config.devices_per_node or num_workers % self._config.devices_per_node == 0 + num_workers < devices_per_node or num_workers % devices_per_node == 0 ), "Number of workers should be less than devices per node or a multiple of devices per node" - self._is_multi_node = num_workers > self._config.devices_per_node + self._is_multi_node = num_workers > devices_per_node ( self._compute_input_file, @@ -76,9 +90,9 @@ def _get_input_files(self) -> Tuple[str, str, str, str, str]: ] for i in range(len(input_files)): input_files[i] = input_files[i].replace( - "{DEVICE}", self._config.device).replace( + "{DEVICE}", 
self._replica_config.device).replace( "{MODEL}", self._model_config.get_name()).replace( - "{NETWORK_DEVICE}", self._config.network_device) + "{NETWORK_DEVICE}", self._replica_config.network_device) return tuple(input_files) @@ -93,7 +107,7 @@ def _load_compute_df(self, file_path: str) -> pd.DataFrame: logger.debug(f"self._use_gated_mlp: {self._model_config.use_gated_mlp}") logger.debug(f"self._vocab_size: {self._model_config.vocab_size}") logger.debug( - f"self._num_tensor_parallel_workers: {self._config.num_tensor_parallel_workers}" + f"self._num_tensor_parallel_workers: {self._replica_config.tensor_parallel_size}" ) df = df[ @@ -103,7 +117,7 @@ def _load_compute_df(self, file_path: str) -> pd.DataFrame: & (df["n_expanded_embd"] == self._model_config.mlp_hidden_dim) & (df["use_gated_mlp"] == self._model_config.use_gated_mlp) & (df["vocab_size"] == self._model_config.vocab_size) - & (df["num_tensor_parallel_workers"] == self._config.num_tensor_parallel_workers) + & (df["num_tensor_parallel_workers"] == self._replica_config.tensor_parallel_size) ] for column in [ @@ -133,15 +147,15 @@ def _load_attention_df(self, file_path: str) -> pd.DataFrame: (df["n_embd"] == self._model_config.embedding_dim) & (df["n_q_head"] == self._model_config.num_q_heads) & (df["n_kv_head"] == self._model_config.num_kv_heads) - & (df["block_size"] == self._config.block_size) - & (df["num_tensor_parallel_workers"] == self._config.num_tensor_parallel_workers) + & (df["block_size"] == self._block_size) + & (df["num_tensor_parallel_workers"] == self._replica_config.tensor_parallel_size) ] def _load_all_reduce_df(self, file_path: str) -> pd.DataFrame: df = self._read_input_file(file_path) return df[ - (df["num_workers"] == self._config.num_tensor_parallel_workers) - & (df["devices_per_node"] == self._config.num_tensor_parallel_workers) + (df["num_workers"] == self._replica_config.tensor_parallel_size) + & (df["devices_per_node"] == self._replica_config.tensor_parallel_size) & (df["collective"] == "all_reduce") ] @@ -162,7 +176,7 @@ def _load_cpu_overhead_df(self, file_path: str) -> pd.DataFrame: df = self._read_input_file(file_path) filtered_df = df[ (df["model_name"] == self._model_config.get_name()) - & (df["tensor_parallel_degree"] == self._config.num_tensor_parallel_workers) + & (df["tensor_parallel_degree"] == self._replica_config.tensor_parallel_size) ] return filtered_df @@ -262,12 +276,12 @@ def _get_model_hash(self, model_name: str, df: pd.DataFrame = None) -> str: def _load_model_from_cache(self, model_name: str, model_hash: str) -> BaseEstimator: with InterProcessReaderWriterLock( - f"{self._config.cache_dir}/{model_hash}_model_lock.file" + f"{self._cache_dir}/{model_hash}_model_lock.file" ).read_lock(): if self._config.no_cache: return # check if model is in cache - cache_file = f"{self._config.cache_dir}/{model_name}_{model_hash}.pkl" + cache_file = f"{self._cache_dir}/{model_name}_{model_hash}.pkl" if not os.path.exists(cache_file): return @@ -279,10 +293,10 @@ def _store_model_in_cache( self, model_name: str, model_hash: str, model: BaseEstimator ) -> None: with InterProcessReaderWriterLock( - f"{self._config.cache_dir}/{model_hash}_model_lock.file" + f"{self._cache_dir}/{model_hash}_model_lock.file" ).write_lock(): # store model in cache - cache_file = f"{self._config.cache_dir}/{model_name}_{model_hash}.pkl" + cache_file = f"{self._cache_dir}/{model_name}_{model_hash}.pkl" pickle.dump(model, open(cache_file, "wb"), protocol=pickle.HIGHEST_PROTOCOL) def _store_training_prediction_data( @@ -301,7 +315,7 
@@ def _store_training_prediction_data( # store the prediction data df[feature_cols + [target_col, "prediction"]].to_csv( - f"{self._config.cache_dir}/{model_name}_{model_hash}_training_predictions.csv", + f"{self._cache_dir}/{model_name}_{model_hash}_training_predictions.csv", index=False, ) @@ -365,9 +379,9 @@ def _store_model_predication_cache( self, model_name: str, model_hash: str, predictions: Dict[Tuple, float] ) -> None: with InterProcessReaderWriterLock( - f"{self._config.cache_dir}/{model_hash}_prediction_lock.file" + f"{self._cache_dir}/{model_hash}_prediction_lock.file" ).write_lock(): - cache_file = f"{self._config.cache_dir}/{model_name}_{model_hash}_predictions.pkl" + cache_file = f"{self._cache_dir}/{model_name}_{model_hash}_predictions.pkl" pickle.dump( predictions, open(cache_file, "wb"), protocol=pickle.HIGHEST_PROTOCOL ) @@ -376,11 +390,11 @@ def _load_model_predication_cache( self, model_name: str, model_hash: str ) -> Dict[Tuple, float]: with InterProcessReaderWriterLock( - f"{self._config.cache_dir}/{model_hash}_prediction_lock.file" + f"{self._cache_dir}/{model_hash}_prediction_lock.file" ).read_lock(): if self._config.no_cache: return - cache_file = f"{self._config.cache_dir}/{model_name}_{model_hash}_predictions.pkl" + cache_file = f"{self._cache_dir}/{model_name}_{model_hash}_predictions.pkl" if not os.path.exists(cache_file): return @@ -413,7 +427,7 @@ def _get_model_prediction( X["prediction"] = predictions_array X.to_csv( - f"{self._config.cache_dir}/{model_name}_{model_hash}_predictions.csv", + f"{self._cache_dir}/{model_name}_{model_hash}_predictions.csv", index=False, ) @@ -462,7 +476,7 @@ def _train_compute_models(self) -> Dict[str, BaseEstimator]: target_col=f"time_stats.{model_name}.median", ) - if self._config.num_pipeline_stages > 1: + if self._replica_config.num_pipeline_stages > 1: send_recv_df = self._load_send_recv_df(self._send_recv_input_file) send_recv_df = self._get_send_recv_df_with_derived_features(send_recv_df) @@ -473,7 +487,7 @@ def _train_compute_models(self) -> Dict[str, BaseEstimator]: target_col="time_stats.send_recv.median", ) - if self._config.num_tensor_parallel_workers > 1: + if self._replica_config.tensor_parallel_size > 1: all_reduce_df = self._load_all_reduce_df(self._all_reduce_input_file) all_reduce_df = self._get_all_reduce_df_with_derived_features(all_reduce_df) @@ -572,10 +586,10 @@ def _predict_for_compute_models(self) -> Dict[str, Any]: "add", ] - if self._config.num_pipeline_stages > 1: + if self._replica_config.num_pipeline_stages > 1: model_names.append("send_recv") - if self._config.num_tensor_parallel_workers > 1: + if self._replica_config.tensor_parallel_size > 1: model_names.append("all_reduce") num_token_range = np.arange(1, self._max_tokens + 1) @@ -774,7 +788,7 @@ def _get_tensor_parallel_communication_time(self, batch: Batch) -> float: self._predictions["all_reduce"][(batch._total_num_tokens_rounded,)] + self._config.nccl_cpu_launch_overhead_ms + self._config.nccl_cpu_skew_overhead_per_device_ms - * self._config.num_tensor_parallel_workers**1.25 + * self._replica_config.tensor_parallel_size**1.25 ) def _get_pipeline_parallel_communication_time(self, batch: Batch) -> float: @@ -862,7 +876,7 @@ def _get_ray_comm_time(self, batch: Batch) -> float: def to_dict(self) -> dict: return { "model_provider": str(self._config.get_type()), - "num_tensor_parallel_workers": self._config.num_tensor_parallel_workers, + "num_tensor_parallel_workers": self._replica_config.tensor_parallel_size, "k_fold_cv_splits": 
self._config.k_fold_cv_splits, "num_q_heads": self._model_config.num_q_heads, "num_kv_heads": self._model_config.num_kv_heads, @@ -870,7 +884,7 @@ def to_dict(self) -> dict: "mlp_hidden_dim": self._model_config.mlp_hidden_dim, "use_gated_mlp": self._model_config.use_gated_mlp, "vocab_size": self._model_config.vocab_size, - "block_size": self._config.block_size, + "block_size": self._block_size, "max_tokens": self._max_tokens, "compute_input_file": self._compute_input_file, "all_reduce_input_file": self._all_reduce_input_file, diff --git a/vidur/metrics/metrics_store.py b/vidur/metrics/metrics_store.py index c0b3860..c41a954 100644 --- a/vidur/metrics/metrics_store.py +++ b/vidur/metrics/metrics_store.py @@ -6,7 +6,7 @@ import plotly_express as px import wandb -from vidur.config import MetricsConfig, ReplicaConfig +from vidur.config import MetricsConfig, ClusterConfig from vidur.entities import Batch, BatchStage, ExecutionTime, Request from vidur.logger import init_logger from vidur.metrics.cdf_sketch import CDFSketch @@ -49,9 +49,14 @@ def wrapper(self, *args, **kwargs): class MetricsStore: - def __init__(self, config: MetricsConfig, replica_config: ReplicaConfig) -> None: + def __init__(self, config: MetricsConfig, cluster_config: ClusterConfig) -> None: self._config = config self._last_request_arrived_at = None + + # copy config + self._num_replicas = cluster_config.num_replicas + self._num_pipeline_stages = cluster_config.replica_config.num_pipeline_stages + # Initialise request metrics self._request_metrics_time_distributions: Dict[ RequestMetricsTimeDistributions, DataSeries @@ -190,9 +195,9 @@ def __init__(self, config: MetricsConfig, replica_config: ReplicaConfig) -> None # per replica stage metrics self._replica_busy_time = [] self._replica_mfu = [] - self._mfu_calculator = MFUCalculator(replica_config) + self._mfu_calculator = MFUCalculator(cluster_config.replica_config) - for replica_idx in range(self._config.num_replicas): + for replica_idx in range(self._num_replicas): self._replica_memory_usage.append( SeriesAverageMeter( TIME_STR, @@ -205,7 +210,7 @@ def __init__(self, config: MetricsConfig, replica_config: ReplicaConfig) -> None self._replica_busy_time.append([]) self._replica_mfu.append([]) - for stage_idx in range(self._config.num_pipeline_stages): + for stage_idx in range(self._num_pipeline_stages): self._replica_busy_time[replica_idx].append( SeriesAverageMeter( TIME_STR, @@ -449,11 +454,11 @@ def _store_utilization_metrics(self, base_plot_path: str): if not self._config.store_utilization_metrics: return - for replica_idx in range(self._config.num_replicas): + for replica_idx in range(self._num_replicas): self._replica_memory_usage[replica_idx].print_stats( f"replica_{replica_idx + 1}_memory_usage", base_plot_path ) - for stage_idx in range(self._config.num_pipeline_stages): + for stage_idx in range(self._num_pipeline_stages): self._replica_busy_time[replica_idx][stage_idx].print_stats( f"replica_{replica_idx + 1}_stage_{stage_idx + 1}_busy_time_percent", base_plot_path, diff --git a/vidur/scheduler/global_scheduler/base_global_scheduler.py b/vidur/scheduler/global_scheduler/base_global_scheduler.py index 0f378ea..cae2d10 100644 --- a/vidur/scheduler/global_scheduler/base_global_scheduler.py +++ b/vidur/scheduler/global_scheduler/base_global_scheduler.py @@ -18,18 +18,20 @@ def __init__(self, config: SimulationConfig, replicas: Dict[int, Replica]): execution_time_predictor = ExecutionTimePredictorRegistry.get( self._config.execution_time_predictor_config.get_type(), - 
self._config.execution_time_predictor_config, - self._config.cluster_config.replica_config.model_config, + predictor_config=self._config.execution_time_predictor_config, + replica_config=self._config.cluster_config.replica_config, + replica_scheduler_config=self._config.cluster_config.replica_scheduler_config, + metrics_config=self._config.metrics_config, ) self._replica_schedulers = { replica_id: ReplicaSchedulerRegistry.get( config.cluster_config.replica_scheduler_config.get_type(), - config.cluster_config.replica_config, - config.cluster_config.replica_scheduler_config, - config.request_generator_config, - replica, - replica.num_pipeline_stages, - execution_time_predictor, + replica_config=config.cluster_config.replica_config, + replica_scheduler_config=config.cluster_config.replica_scheduler_config, + request_generator_config=config.request_generator_config, + replica=replica, + num_stages=replica.num_pipeline_stages, + execution_time_predictor=execution_time_predictor, ) for replica_id, replica in replicas.items() } diff --git a/vidur/simulator.py b/vidur/simulator.py index e0e2725..4628662 100644 --- a/vidur/simulator.py +++ b/vidur/simulator.py @@ -29,8 +29,8 @@ def __init__(self, config: SimulationConfig) -> None: self._event_trace = [] self._event_chrome_trace = [] - self._cluster = Cluster(self._config.cluster_config) - self._metric_store = MetricsStore(self._config.metrics_config, self._config.cluster_config.replica_config) + self._cluster = Cluster(self._config.cluster_config, self._config.metrics_config, self._config.request_generator_config) + self._metric_store = MetricsStore(self._config.metrics_config, self._config.cluster_config) self._request_generator = RequestGeneratorRegistry.get( self._config.request_generator_config.get_type(), self._config.request_generator_config, @@ -112,12 +112,12 @@ def _set_time(self, time: float) -> None: self._terminate = True def _write_event_trace(self) -> None: - trace_file = f"{self._config.output_dir}/event_trace.json" + trace_file = f"{self._config.metrics_config.output_dir}/event_trace.json" with open(trace_file, "w") as f: json.dump(self._event_trace, f) def _write_chrome_trace(self) -> None: - trace_file = f"{self._config.output_dir}/chrome_trace.json" + trace_file = f"{self._config.metrics_config.output_dir}/chrome_trace.json" chrome_trace = {"traceEvents": self._event_chrome_trace} From d94c0361e54f9d2ba830b972060389ad8d7e3224 Mon Sep 17 00:00:00 2001 From: Anmol Agarwal Date: Sat, 27 Jul 2024 17:52:04 -0400 Subject: [PATCH 14/24] Remove helper strings --- vidur/config/device_sku_config.py | 38 +-- vidur/config/model_config.py | 466 +++++++----------------------- vidur/config/node_sku_config.py | 54 +--- 3 files changed, 118 insertions(+), 440 deletions(-) diff --git a/vidur/config/device_sku_config.py b/vidur/config/device_sku_config.py index 415e49d..a220f5e 100644 --- a/vidur/config/device_sku_config.py +++ b/vidur/config/device_sku_config.py @@ -9,24 +9,14 @@ @dataclass class BaseDeviceSKUConfig(BaseFixedConfig): - fp16_tflops: int = field( - metadata={"help": "The number of TFLOPS the device can achieve in FP16"}, - ) - total_memory_gb: int = field( - metadata={"help": "The total memory of the device in GB"}, - ) + fp16_tflops: int + total_memory_gb: int @dataclass class A100DeviceSKUConfig(BaseDeviceSKUConfig): - fp16_tflops: int = field( - default=312, - metadata={"help": "The number of TFLOPS the device can achieve in FP16"}, - ) - total_memory_gb: int = field( - default=80, - metadata={"help": "The total memory of the 
device in GB"}, - ) + fp16_tflops: int = 312 + total_memory_gb: int = 80 @staticmethod def get_type(): @@ -35,14 +25,8 @@ def get_type(): @dataclass class A40DeviceSKUConfig(BaseDeviceSKUConfig): - fp16_tflops: int = field( - default=150, - metadata={"help": "The number of TFLOPS the device can achieve in FP16"}, - ) - total_memory_gb: int = field( - default=45, - metadata={"help": "The total memory of the device in GB"}, - ) + fp16_tflops: int = 150 + total_memory_gb: int = 45 @staticmethod def get_type(): @@ -51,14 +35,8 @@ def get_type(): @dataclass class H100DeviceSKUConfig(BaseDeviceSKUConfig): - fp16_tflops: int = field( - default=1000, - metadata={"help": "The number of TFLOPS the device can achieve in FP16"}, - ) - total_memory_gb: int = field( - default=80, - metadata={"help": "The total memory of the device in GB"}, - ) + fp16_tflops: int = 1000 + total_memory_gb: int = 80 @staticmethod def get_type(): diff --git a/vidur/config/model_config.py b/vidur/config/model_config.py index c866a47..462534e 100644 --- a/vidur/config/model_config.py +++ b/vidur/config/model_config.py @@ -10,121 +10,41 @@ @dataclass class BaseModelConfig(BaseFixedConfig): - num_layers: int = field( - metadata={"help": "The number of layers in the model"}, - ) - num_q_heads: int = field( - metadata={"help": "The number of query heads in the model"}, - ) - num_kv_heads: int = field( - metadata={"help": "The number of key-value heads in the model"}, - ) - embedding_dim: int = field( - metadata={"help": "The embedding dimension of the model"}, - ) - mlp_hidden_dim: int = field( - metadata={"help": "The hidden dimension of the MLP in the model"}, - ) - max_position_embeddings: int = field( - metadata={"help": "The maximum position embeddings in the model"}, - ) - use_gated_mlp: bool = field( - metadata={"help": "Whether to use gated MLP in the model"}, - ) - use_bias: bool = field( - metadata={"help": "Whether to use bias in the model"}, - ) - use_qkv_bias: bool = field( - metadata={"help": "Whether to use bias in the QKV in the model"}, - ) - activation: ActivationType = field( - metadata={"help": "The activation function in the model"}, - ) - norm: NormType = field( - metadata={"help": "The normalization function in the model"}, - ) - post_attn_norm: bool = field( - metadata={"help": "Whether to use post-attention normalization in the model"}, - ) - vocab_size: int = field( - metadata={"help": "The vocabulary size of the model"}, - ) - is_neox_style: Optional[bool] = field( - default=True, - metadata={"help": "Whether to use the Neox style in the model"}, - ) - rope_theta: Optional[int] = field( - default=None, - metadata={"help": "The rope theta in the model"}, - ) - rope_scaling: Optional[Dict[str, Any]] = field( - default=None, - metadata={"help": "The rope scaling config for the model"}, - ) - partial_rotary_factor: float = field( - default=1.0, - metadata={"help": "The partial rotary factor in the model"}, - ) - no_tensor_parallel: bool = field( - default=False, - metadata={"help": "Whether to use tensor parallelism in the model"}, - ) + num_layers: int + num_q_heads: int + num_kv_heads: int + embedding_dim: int + mlp_hidden_dim: int + max_position_embeddings: int + use_gated_mlp: bool + use_bias: bool + use_qkv_bias: bool + activation: ActivationType + norm: NormType + post_attn_norm: bool + vocab_size: int + is_neox_style: Optional[bool] = True + rope_theta: Optional[int] = None + rope_scaling: Optional[Dict[str, Any]] = None + partial_rotary_factor: float = 1.0 + no_tensor_parallel: bool = False 
@dataclass class Llama2ModelConfig(BaseModelConfig): - max_position_embeddings: int = field( - default=16384, - metadata={"help": "The maximum position embeddings in the model"}, - ) - use_gated_mlp: bool = field( - default=True, - metadata={"help": "Whether to use gated MLP in the model"}, - ) - use_bias: bool = field( - default=False, - metadata={"help": "Whether to use bias in the model"}, - ) - use_qkv_bias: bool = field( - default=False, - metadata={"help": "Whether to use bias in the QKV in the model"}, - ) - activation: ActivationType = field( - default=ActivationType.SILU, - metadata={"help": "The activation function in the model"}, - ) - norm: NormType = field( - default=NormType.RMS_NORM, - metadata={"help": "The normalization function in the model"}, - ) - post_attn_norm: bool = field( - default=True, - metadata={"help": "Whether to use post-attention normalization in the model"}, - ) - vocab_size: int = field( - default=32768, - metadata={"help": "The vocabulary size of the model"}, - ) - is_neox_style: Optional[bool] = field( - default=True, - metadata={"help": "Whether to use the Neox style in the model"}, - ) - rope_theta: Optional[int] = field( - default=10000.0, - metadata={"help": "The rope theta in the model"}, - ) - rope_scaling: Optional[Dict[str, Any]] = field( - default=None, - metadata={"help": "The rope scaling config for the model"}, - ) - partial_rotary_factor: float = field( - default=1.0, - metadata={"help": "The partial rotary factor in the model"}, - ) - no_tensor_parallel: bool = field( - default=False, - metadata={"help": "Whether to use tensor parallelism in the model"}, - ) + max_position_embeddings: int = 16384 + use_gated_mlp: bool = True + use_bias: bool = False + use_qkv_bias: bool = False + activation: ActivationType = ActivationType.SILU + norm: NormType = NormType.RMS_NORM + post_attn_norm: bool = True + vocab_size: int = 32768 + is_neox_style: Optional[bool] = True + rope_theta: Optional[int] = 10000.0 + rope_scaling: Optional[Dict[str, Any]] = None + partial_rotary_factor: float = 1.0 + no_tensor_parallel: bool = False @staticmethod def get_name(): @@ -133,26 +53,11 @@ def get_name(): @dataclass class CodeLlama34BModelConfig(Llama2ModelConfig): - num_layers: int = field( - default=48, - metadata={"help": "The number of layers in the model"}, - ) - num_q_heads: int = field( - default=64, - metadata={"help": "The number of query heads in the model"}, - ) - num_kv_heads: int = field( - default=8, - metadata={"help": "The number of key-value heads in the model"}, - ) - embedding_dim: int = field( - default=8192, - metadata={"help": "The embedding dimension of the model"}, - ) - mlp_hidden_dim: int = field( - default=22016, - metadata={"help": "The hidden dimension of the MLP in the model"}, - ) + num_layers: int = 48 + num_q_heads: int = 64 + num_kv_heads: int = 8 + embedding_dim: int = 8192 + mlp_hidden_dim: int = 22016 @staticmethod def get_name(): @@ -161,26 +66,11 @@ def get_name(): @dataclass class Llama2_7BModelConfig(Llama2ModelConfig): - num_layers: int = field( - default=32, - metadata={"help": "The number of layers in the model"}, - ) - num_q_heads: int = field( - default=32, - metadata={"help": "The number of query heads in the model"}, - ) - num_kv_heads: int = field( - default=32, - metadata={"help": "The number of key-value heads in the model"}, - ) - embedding_dim: int = field( - default=4096, - metadata={"help": "The embedding dimension of the model"}, - ) - mlp_hidden_dim: int = field( - default=11008, - metadata={"help": "The hidden 
dimension of the MLP in the model"}, - ) + num_layers: int = 32 + num_q_heads: int = 32 + num_kv_heads: int = 32 + embedding_dim: int = 4096 + mlp_hidden_dim: int = 11008 @staticmethod def get_name(): @@ -189,26 +79,11 @@ def get_name(): @dataclass class Llama2_70BModelConfig(Llama2ModelConfig): - num_layers: int = field( - default=80, - metadata={"help": "The number of layers in the model"}, - ) - num_q_heads: int = field( - default=64, - metadata={"help": "The number of query heads in the model"}, - ) - num_kv_heads: int = field( - default=8, - metadata={"help": "The number of key-value heads in the model"}, - ) - embedding_dim: int = field( - default=8192, - metadata={"help": "The embedding dimension of the model"}, - ) - mlp_hidden_dim: int = field( - default=28672, - metadata={"help": "The hidden dimension of the MLP in the model"}, - ) + num_layers: int = 80 + num_q_heads: int = 64 + num_kv_heads: int = 8 + embedding_dim: int = 8192 + mlp_hidden_dim: int = 28672 @staticmethod def get_name(): @@ -217,38 +92,14 @@ def get_name(): @dataclass class Llama3_8BModelConfig(Llama2ModelConfig): - num_layers: int = field( - default=32, - metadata={"help": "The number of layers in the model"}, - ) - num_q_heads: int = field( - default=32, - metadata={"help": "The number of query heads in the model"}, - ) - num_kv_heads: int = field( - default=8, - metadata={"help": "The number of key-value heads in the model"}, - ) - embedding_dim: int = field( - default=4096, - metadata={"help": "The embedding dimension of the model"}, - ) - mlp_hidden_dim: int = field( - default=14336, - metadata={"help": "The hidden dimension of the MLP in the model"}, - ) - max_position_embeddings: int = field( - default=4096, - metadata={"help": "The maximum position embeddings in the model"}, - ) - rope_theta: Optional[int] = field( - default=500000.0, - metadata={"help": "The rope theta in the model"}, - ) - vocab_size: int = field( - default=128256, - metadata={"help": "The vocabulary size of the model"}, - ) + num_layers: int = 32 + num_q_heads: int = 32 + num_kv_heads: int = 8 + embedding_dim: int = 4096 + mlp_hidden_dim: int = 14336 + max_position_embeddings: int = 4096 + rope_theta: Optional[int] = 500000.0 + vocab_size: int = 128256 @staticmethod def get_name(): @@ -257,38 +108,14 @@ def get_name(): @dataclass class Llama3_70BModelConfig(Llama2ModelConfig): - num_layers: int = field( - default=80, - metadata={"help": "The number of layers in the model"}, - ) - num_q_heads: int = field( - default=64, - metadata={"help": "The number of query heads in the model"}, - ) - num_kv_heads: int = field( - default=8, - metadata={"help": "The number of key-value heads in the model"}, - ) - embedding_dim: int = field( - default=8192, - metadata={"help": "The embedding dimension of the model"}, - ) - mlp_hidden_dim: int = field( - default=28672, - metadata={"help": "The hidden dimension of the MLP in the model"}, - ) - max_position_embeddings: int = field( - default=8192, - metadata={"help": "The maximum position embeddings in the model"}, - ) - rope_theta: Optional[int] = field( - default=500000.0, - metadata={"help": "The rope theta in the model"}, - ) - vocab_size: int = field( - default=128256, - metadata={"help": "The vocabulary size of the model"}, - ) + num_layers: int = 80 + num_q_heads: int = 64 + num_kv_heads: int = 8 + embedding_dim: int = 8192 + mlp_hidden_dim: int = 28672 + max_position_embeddings: int = 8192 + rope_theta: Optional[int] = 500000.0 + vocab_size: int = 128256 @staticmethod def get_name(): @@ -296,38 
+123,17 @@ def get_name(): @dataclass class InternLM2ModelConfig(Llama2ModelConfig): - max_position_embeddings: int = field( - default=32768, - metadata={"help": "The maximum position embeddings in the model"}, - ) - vocab_size: int = field( - default=92544, - metadata={"help": "The vocabulary size of the model"}, - ) + max_position_embeddings: int = 32768 + vocab_size: int = 92544 @dataclass class InternLM2_20BModelConfig(InternLM2ModelConfig): - num_layers: int = field( - default=48, - metadata={"help": "The number of layers in the model"}, - ) - num_q_heads: int = field( - default=48, - metadata={"help": "The number of query heads in the model"}, - ) - num_kv_heads: int = field( - default=8, - metadata={"help": "The number of key-value heads in the model"}, - ) - embedding_dim: int = field( - default=6144, - metadata={"help": "The embedding dimension of the model"}, - ) - mlp_hidden_dim: int = field( - default=16384, - metadata={"help": "The hidden dimension of the MLP in the model"}, - ) + num_layers: int = 48 + num_q_heads: int = 48 + num_kv_heads: int = 8 + embedding_dim: int = 6144 + mlp_hidden_dim: int = 16384 @staticmethod def get_name(): @@ -336,78 +142,24 @@ def get_name(): @dataclass class Phi2ModelConfig(Llama2ModelConfig): - num_layers: int = field( - default=32, - metadata={"help": "The number of layers in the model"}, - ) - num_q_heads: int = field( - default=32, - metadata={"help": "The number of query heads in the model"}, - ) - num_kv_heads: int = field( - default=32, - metadata={"help": "The number of key-value heads in the model"}, - ) - embedding_dim: int = field( - default=2560, - metadata={"help": "The embedding dimension of the model"}, - ) - mlp_hidden_dim: int = field( - default=10240, - metadata={"help": "The hidden dimension of the MLP in the model"}, - ) - max_position_embeddings: int = field( - default=2048, - metadata={"help": "The maximum position embeddings in the model"}, - ) - use_gated_mlp: bool = field( - default=False, - metadata={"help": "Whether to use gated MLP in the model"}, - ) - use_bias: bool = field( - default=True, - metadata={"help": "Whether to use bias in the model"}, - ) - use_qkv_bias: bool = field( - default=True, - metadata={"help": "Whether to use bias in the QKV in the model"}, - ) - activation: ActivationType = field( - default=ActivationType.GELU, - metadata={"help": "The activation function in the model"}, - ) - norm: NormType = field( - default=NormType.LAYER_NORM, - metadata={"help": "The normalization function in the model"}, - ) - post_attn_norm: bool = field( - default=False, - metadata={"help": "Whether to use post-attention normalization in the model"}, - ) - vocab_size: int = field( - default=51200, - metadata={"help": "The vocabulary size of the model"}, - ) - rope_scaling: Optional[Dict[str, Any]] = field( - default=None, - metadata={"help": "The rope scaling config for the model"}, - ) - rope_theta: Optional[int] = field( - default=10000.0, - metadata={"help": "The rope theta in the model"}, - ) - partial_rotary_factor: float = field( - default=0.4, - metadata={"help": "The partial rotary factor in the model"}, - ) - no_tensor_parallel: bool = field( - default=True, - metadata={"help": "Whether to use tensor parallelism in the model"}, - ) - is_neox_style: bool = field( - default=True, - metadata={"help": "Whether to use the Neox style in the model"}, - ) + num_layers: int = 32 + num_q_heads: int = 32 + num_kv_heads: int = 32 + embedding_dim: int = 2560 + mlp_hidden_dim: int = 10240 + max_position_embeddings: int = 
2048 + use_gated_mlp: bool = False + use_bias: bool = True + use_qkv_bias: bool = True + activation: ActivationType = ActivationType.GELU + norm: NormType = NormType.LAYER_NORM + post_attn_norm: bool = False + vocab_size: int = 51200 + rope_scaling: Optional[Dict[str, Any]] = None + rope_theta: Optional[int] = 10000.0 + partial_rotary_factor: float = 0.4 + no_tensor_parallel: bool = True + is_neox_style: bool = True @staticmethod def get_name(): @@ -416,42 +168,22 @@ def get_name(): @dataclass class QwenModelConfig(Llama2ModelConfig): - use_qkv_bias: bool = field( - default=True, - metadata={"help": "Whether to use bias in the QKV in the model"}, - ) - max_position_embeddings: int = field( - default=32768, - metadata={"help": "The maximum position embeddings in the model"}, - ) - vocab_size: int = field( - default=152064, - metadata={"help": "The vocabulary size of the model"}, - ) + use_qkv_bias: bool = True + max_position_embeddings: int = 32768 + vocab_size: int = 152064 + + @staticmethod + def get_name(): + return "Qwen/Qwen-Config" @dataclass class Qwen72BModelConfig(QwenModelConfig): - num_layers: int = field( - default=80, - metadata={"help": "The number of layers in the model"}, - ) - num_q_heads: int = field( - default=64, - metadata={"help": "The number of query heads in the model"}, - ) - num_kv_heads: int = field( - default=64, - metadata={"help": "The number of key-value heads in the model"}, - ) - embedding_dim: int = field( - default=8192, - metadata={"help": "The embedding dimension of the model"}, - ) - mlp_hidden_dim: int = field( - default=24576, - metadata={"help": "The hidden dimension of the MLP in the model"}, - ) + num_layers: int = 80 + num_q_heads: int = 64 + num_kv_heads: int = 64 + embedding_dim: int = 8192 + mlp_hidden_dim: int = 24576 @staticmethod def get_name(): diff --git a/vidur/config/node_sku_config.py b/vidur/config/node_sku_config.py index a3324a4..43dcdd9 100644 --- a/vidur/config/node_sku_config.py +++ b/vidur/config/node_sku_config.py @@ -9,21 +9,13 @@ @dataclass class BaseNodeSKUConfig(BaseFixedConfig): - num_devices_per_node: int = field( - metadata={"help": "The number of devices per node"}, - ) + num_devices_per_node: int @dataclass class A40PairwiseNvlinkNodeSKUConfig(BaseNodeSKUConfig): - device_sku_type: DeviceSKUType = field( - default=DeviceSKUType.A40, - metadata={"help": "The device SKU type"}, - ) - num_devices_per_node: int = field( - default=8, - metadata={"help": "The number of devices per node"}, - ) + device_sku_type: DeviceSKUType = DeviceSKUType.A40 + num_devices_per_node: int = 8 @staticmethod def get_type(): @@ -32,14 +24,8 @@ def get_type(): @dataclass class A100PairwiseNvlinkNodeSKUConfig(BaseNodeSKUConfig): - device_sku_type: DeviceSKUType = field( - default=DeviceSKUType.A100, - metadata={"help": "The device SKU type"}, - ) - num_devices_per_node: int = field( - default=8, - metadata={"help": "The number of devices per node"}, - ) + device_sku_type: DeviceSKUType = DeviceSKUType.A100 + num_devices_per_node: int = 8 @staticmethod def get_type(): @@ -48,14 +34,8 @@ def get_type(): @dataclass class H100PairwiseNvlinkNodeSKUConfig(BaseNodeSKUConfig): - device_sku_type: DeviceSKUType = field( - default=DeviceSKUType.H100, - metadata={"help": "The device SKU type"}, - ) - num_devices_per_node: int = field( - default=8, - metadata={"help": "The number of devices per node"}, - ) + device_sku_type: DeviceSKUType = DeviceSKUType.H100 + num_devices_per_node: int = 8 @staticmethod def get_type(): @@ -64,14 +44,8 @@ def get_type(): 
@dataclass class A100DgxNodeSKUConfig(BaseNodeSKUConfig): - device_sku_type: DeviceSKUType = field( - default=DeviceSKUType.A100, - metadata={"help": "The device SKU type"}, - ) - num_devices_per_node: int = field( - default=8, - metadata={"help": "The number of devices per node"}, - ) + device_sku_type: DeviceSKUType = DeviceSKUType.A100 + num_devices_per_node: int = 8 @staticmethod def get_type(): @@ -80,14 +54,8 @@ def get_type(): @dataclass class H100DgxNodeSKUConfig(BaseNodeSKUConfig): - device_sku_type: DeviceSKUType = field( - default=DeviceSKUType.H100, - metadata={"help": "The device SKU type"}, - ) - num_devices_per_node: int = field( - default=8, - metadata={"help": "The number of devices per node"}, - ) + device_sku_type: DeviceSKUType = DeviceSKUType.H100 + num_devices_per_node: int = 8 @staticmethod def get_type(): From 46176c3ab19b75b877ba90fb8e176ca025ce8855 Mon Sep 17 00:00:00 2001 From: Anmol Agarwal Date: Sat, 27 Jul 2024 19:33:44 -0400 Subject: [PATCH 15/24] profiler bug fix --- vidur/profiling/common/model_config.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/vidur/profiling/common/model_config.py b/vidur/profiling/common/model_config.py index a896727..057717e 100644 --- a/vidur/profiling/common/model_config.py +++ b/vidur/profiling/common/model_config.py @@ -5,7 +5,7 @@ from sarathi.config import ParallelConfig from vidur.config.model_config import BaseModelConfig -from vidur.config.config import ReplicaConfig +from vidur.types import ActivationType, NormType class ModelConfig: @@ -21,8 +21,8 @@ def __init__( use_gated_mlp: bool, use_bias: bool, use_qkv_bias: bool, - activation: str, - norm: str, + activation: ActivationType, + norm: NormType, post_attn_norm: bool, vocab_size: int, is_neox_style: Optional[bool] = True, @@ -42,8 +42,8 @@ def __init__( self.vocab_size = vocab_size self.use_bias = use_bias self.use_qkv_bias = use_qkv_bias - self.activation = activation - self.norm = norm + self.activation = str(activation) + self.norm = str(norm) self.post_attn_norm = post_attn_norm self.no_tensor_parallel = no_tensor_parallel self.partial_rotary_factor = partial_rotary_factor @@ -61,7 +61,7 @@ def __init__( @staticmethod def from_model_name(model_name: str): - model_config: BaseModelConfig = ReplicaConfig.get_model_config(model_name) + model_config: BaseModelConfig = BaseModelConfig.create_from_name(model_name) model_config_dict = asdict(model_config) return ModelConfig(model_name, **model_config_dict) From 1e4c737a61dd76f166116c6753626514b8dcec8a Mon Sep 17 00:00:00 2001 From: Anmol Agarwal Date: Sat, 27 Jul 2024 19:44:17 -0400 Subject: [PATCH 16/24] capacity config update --- vidur/config_optimizer/config_explorer/config/config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vidur/config_optimizer/config_explorer/config/config.py b/vidur/config_optimizer/config_explorer/config/config.py index a9dc263..4a073fe 100644 --- a/vidur/config_optimizer/config_explorer/config/config.py +++ b/vidur/config_optimizer/config_explorer/config/config.py @@ -239,8 +239,8 @@ class SimulationConfig: def to_config_dict(self): return { **self.job_config.to_config_dict(), - "output_dir": self.get_run_dir(), - "cache_dir": self.cache_dir, + "metrics_config_output_dir": self.get_run_dir(), + "metrics_config_cache_dir": self.cache_dir, "poisson_request_interval_generator_config_qps": self.qps, "gamma_request_interval_generator_config_qps": self.qps, "time_limit": self.time_limit * 60, # to seconds From 
8be5b143b4e81a801569aca777e77fa351af8d40 Mon Sep 17 00:00:00 2001 From: Anmol Agarwal Date: Sat, 27 Jul 2024 22:09:36 -0400 Subject: [PATCH 17/24] remove flush --- vidur/config_optimizer/config_explorer/capacity_search.py | 6 +----- vidur/config_optimizer/config_explorer/main.py | 8 ++++---- 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/vidur/config_optimizer/config_explorer/capacity_search.py b/vidur/config_optimizer/config_explorer/capacity_search.py index bd970f7..5befbde 100644 --- a/vidur/config_optimizer/config_explorer/capacity_search.py +++ b/vidur/config_optimizer/config_explorer/capacity_search.py @@ -53,7 +53,7 @@ def _generate_run_command( cpu_affinity_command = f"taskset --cpu-list {self.cpu_core_id}" command = f"nice -n 1 {cpu_affinity_command} python -m vidur.main {scheduler_config.to_args()}" - logger.debug(f"Running command: {command}", flush=True) + logger.debug(f"Running command: {command}") return command @@ -81,7 +81,6 @@ def _is_under_sla( logger.info( f"{simulator_config.to_human_readable_name()} - Scheduling delay (P{self.args.scheduling_delay_slo_quantile}): {scheduling_delay}", - flush=True, ) return is_under_scheduling_delay_sla, scheduling_delay @@ -120,7 +119,6 @@ def is_under_sla(self, qps: float) -> tuple[bool, float]: except Exception as e: logger.error( f"Error running: {self.job_config.get_human_readable_name()}, failed with error: {e}", - flush=True, ) return False, None @@ -130,7 +128,6 @@ def search(self): """ logger.info( f"Starting search for {self.job_config.get_human_readable_name()}", - flush=True, ) left = 0 @@ -175,7 +172,6 @@ def search(self): logger.info( f"Max QPS under SLO for {self.job_config.get_human_readable_name()}: {max_qps_under_sla}", - flush=True, ) self.release_cpu_core_id() diff --git a/vidur/config_optimizer/config_explorer/main.py b/vidur/config_optimizer/config_explorer/main.py index 26ce5af..7821c9a 100644 --- a/vidur/config_optimizer/config_explorer/main.py +++ b/vidur/config_optimizer/config_explorer/main.py @@ -64,9 +64,9 @@ def get_args(): os.makedirs(args.output_dir, exist_ok=True) - logger.info("Starting config optimizer", flush=True) - logger.info(f"Args: {args}", flush=True) - logger.info(f"Config: {config}", flush=True) + logger.info("Starting config optimizer") + logger.info(f"Args: {args}") + logger.info(f"Config: {config}") # store the config and args json.dump(vars(args), open(f"{args.output_dir}/args.json", "w")) @@ -80,4 +80,4 @@ def get_args(): end_time = time.time() - logger.info(f"Simulation took time: {end_time - start_time}", flush=True) + logger.info(f"Simulation took time: {end_time - start_time}") From 47dc291ae599e50bd1f44c75e6d380c947d77a0f Mon Sep 17 00:00:00 2001 From: Anmol Agarwal Date: Sat, 27 Jul 2024 22:34:05 -0400 Subject: [PATCH 18/24] BooleanOptionalAction --- vidur/config/flat_dataclass.py | 33 +++++++++++++++++++-------------- vidur/config/utils.py | 5 +++++ 2 files changed, 24 insertions(+), 14 deletions(-) diff --git a/vidur/config/flat_dataclass.py b/vidur/config/flat_dataclass.py index c86a6d7..525d200 100644 --- a/vidur/config/flat_dataclass.py +++ b/vidur/config/flat_dataclass.py @@ -1,5 +1,5 @@ import json -from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser +from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser, BooleanOptionalAction from collections import defaultdict, deque from dataclasses import MISSING, fields, make_dataclass from typing import Any, get_args @@ -8,6 +8,7 @@ from vidur.config.utils import ( get_all_subclasses, 
get_inner_type, + is_bool, is_composed_of_primitives, is_dict, is_list, @@ -83,6 +84,7 @@ def create_from_cli_args(cls) -> Any: for field in fields(cls): nargs = None + action = None field_type = field.type help_text = field.metadata.get("help", None) @@ -96,24 +98,27 @@ def create_from_cli_args(cls) -> Any: elif is_dict(field.type): assert is_composed_of_primitives(field.type) field_type = json.loads + elif is_bool(field.type): + action = BooleanOptionalAction + + arg_params = { + "type": field_type, + "action": action, + "help": help_text, + } # handle cases with default and default factory args if field.default is not MISSING: - parser.add_argument( - f"--{field.name}", type=field_type, default=field.default, nargs=nargs, help=help_text - ) + arg_params["default"] = field.default elif field.default_factory is not MISSING: - parser.add_argument( - f"--{field.name}", - type=field_type, - default=field.default_factory(), - nargs=nargs, - help=help_text, - ) + arg_params["default"] = field.default_factory() else: - parser.add_argument( - f"--{field.name}", type=field_type, required=True, nargs=nargs, help=help_text - ) + arg_params["required"] = True + + if nargs: + arg_params["nargs"] = nargs + + parser.add_argument(f"--{field.name}", **arg_params) args = parser.parse_args() diff --git a/vidur/config/utils.py b/vidur/config/utils.py index e206893..975caf5 100644 --- a/vidur/config/utils.py +++ b/vidur/config/utils.py @@ -53,6 +53,10 @@ def is_dict(field_type: type) -> bool: return get_origin(field_type) is dict +def is_bool(field_type: type) -> bool: + return field_type is bool + + def get_inner_type(field_type: type) -> type: return next(t for t in get_args(field_type) if t is not type(None)) @@ -60,6 +64,7 @@ def get_inner_type(field_type: type) -> type: def is_subclass(cls, parent: type) -> bool: return hasattr(cls, "__bases__") and parent in cls.__bases__ + def dataclass_to_dict(obj): if isinstance(obj, list): return [dataclass_to_dict(item) for item in obj] From 48475abf7685c39ce1cfc32c35d1918fdb4f073c Mon Sep 17 00:00:00 2001 From: Anmol Agarwal Date: Sun, 28 Jul 2024 00:45:54 -0400 Subject: [PATCH 19/24] ip address --- vidur/config_optimizer/config_explorer/ray_utils.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/vidur/config_optimizer/config_explorer/ray_utils.py b/vidur/config_optimizer/config_explorer/ray_utils.py index 1b4bbd4..884df8c 100644 --- a/vidur/config_optimizer/config_explorer/ray_utils.py +++ b/vidur/config_optimizer/config_explorer/ray_utils.py @@ -7,7 +7,16 @@ def get_ip() -> str: - return socket.gethostbyname(socket.gethostname()) + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + s.settimeout(0) + try: + s.connect(("10.254.254.254", 1)) + ip = s.getsockname()[0] + except Exception: + ip = "127.0.0.1" + finally: + s.close() + return ip def get_nodes() -> list[str]: From 0c1a57acf47237454c7dc7e04c6a0dbbb14a0e3a Mon Sep 17 00:00:00 2001 From: Anmol Agarwal Date: Sun, 28 Jul 2024 02:05:34 -0400 Subject: [PATCH 20/24] special handling --- .../config_explorer/capacity_search.py | 4 ++-- vidur/config_optimizer/config_explorer/ray_utils.py | 10 ++++++++++ 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/vidur/config_optimizer/config_explorer/capacity_search.py b/vidur/config_optimizer/config_explorer/capacity_search.py index 5befbde..74c7fd5 100644 --- a/vidur/config_optimizer/config_explorer/capacity_search.py +++ b/vidur/config_optimizer/config_explorer/capacity_search.py @@ -1,6 +1,7 @@ import argparse import glob 
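As a quick standalone sketch of the BooleanOptionalAction pattern adopted above (the flag name here is illustrative, not taken from the actual config fields): each bool dataclass field now registers a paired --<name>/--no-<name> switch instead of requiring an explicit true/false value on the command line.

    from argparse import ArgumentParser, BooleanOptionalAction

    parser = ArgumentParser()
    # Mirrors how create_from_cli_args() registers a bool field.
    parser.add_argument("--write_json_trace", action=BooleanOptionalAction, default=True)

    assert parser.parse_args([]).write_json_trace is True
    assert parser.parse_args(["--no-write_json_trace"]).write_json_trace is False
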
import os +import platform import shlex from subprocess import Popen @@ -48,8 +49,7 @@ def _generate_run_command( scheduler_config: SimulationConfig, ): cpu_affinity_command = "" - if self.cpu_core_id is not None: - self.cpu_core_id = self.cpu_core_id + if self.cpu_core_id is not None and platform.system() != "Darwin": cpu_affinity_command = f"taskset --cpu-list {self.cpu_core_id}" command = f"nice -n 1 {cpu_affinity_command} python -m vidur.main {scheduler_config.to_args()}" diff --git a/vidur/config_optimizer/config_explorer/ray_utils.py b/vidur/config_optimizer/config_explorer/ray_utils.py index 884df8c..35a5eb7 100644 --- a/vidur/config_optimizer/config_explorer/ray_utils.py +++ b/vidur/config_optimizer/config_explorer/ray_utils.py @@ -1,4 +1,5 @@ import os +import platform import socket import time from typing import Optional @@ -7,6 +8,10 @@ def get_ip() -> str: + # special handling for macos + if platform.system() == "Darwin": + return "127.0.0.1" + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) s.settimeout(0) try: @@ -26,6 +31,11 @@ def get_nodes() -> list[str]: for x in cluster_resources_keys if x.startswith("node:") and x != "node:__internal_head__" ] + + # special handling for macos, ensure that we only have one node + if platform.system() == "Darwin": + assert len(ip_addresses) == 1 + return ip_addresses From 6a766701f5570c5a4967f62f7f2320e8671cb630 Mon Sep 17 00:00:00 2001 From: Anmol Agarwal Date: Sun, 28 Jul 2024 02:10:09 -0400 Subject: [PATCH 21/24] replica_scheduler_config --- .../base_replica_scheduler.py | 16 ++++++++-------- .../lightllm_replica_scheduler.py | 12 ++++++------ .../sarathi_replica_scheduler.py | 18 +++++++++--------- .../vllm_replica_scheduler.py | 18 +++++++++--------- 4 files changed, 32 insertions(+), 32 deletions(-) diff --git a/vidur/scheduler/replica_scheduler/base_replica_scheduler.py b/vidur/scheduler/replica_scheduler/base_replica_scheduler.py index 642a856..7f9b321 100644 --- a/vidur/scheduler/replica_scheduler/base_replica_scheduler.py +++ b/vidur/scheduler/replica_scheduler/base_replica_scheduler.py @@ -21,25 +21,25 @@ def __init__( num_stages: int, execution_time_predictor: BaseExecutionTimePredictor, ) -> None: + self._config = replica_scheduler_config self._replica_config = replica_config self._request_generator_config = request_generator_config - self._replica_scheduler_config = replica_scheduler_config self._replica_id = replica.id self._num_stages = num_stages self._max_blocks_per_sequence = ( - self._request_generator_config.max_tokens // self._replica_scheduler_config.block_size + self._request_generator_config.max_tokens // self._config.block_size ) memory_planner = MemoryPlanner(self._replica_config, replica) - if not self._replica_scheduler_config.num_blocks: - self._replica_scheduler_config.num_blocks = ( + if not self._config.num_blocks: + self._config.num_blocks = ( self._max_blocks_per_sequence * memory_planner.get_max_request_slots() ) self._max_batch_size = min( memory_planner.get_max_batch_size(), - self._replica_scheduler_config.batch_size_cap, + self._config.batch_size_cap, ) logger.debug( @@ -74,7 +74,7 @@ def num_allocated_blocks(self) -> int: @property def memory_usage_percent(self) -> int: - return (self._num_allocated_blocks * 100) / self._replica_scheduler_config.num_blocks + return (self._num_allocated_blocks * 100) / self._config.num_blocks def is_empty(self) -> bool: return ( @@ -101,7 +101,7 @@ def get_replica_stage_scheduler(self, stage_id: int): return self._replica_stage_schedulers[stage_id] def 
can_allocate(self, num_blocks: int) -> bool: - return self._replica_scheduler_config.num_blocks - self._num_allocated_blocks >= num_blocks + return self._config.num_blocks - self._num_allocated_blocks >= num_blocks def allocate(self, request_id: int, num_blocks: int) -> None: self._num_allocated_blocks += num_blocks @@ -110,7 +110,7 @@ def allocate(self, request_id: int, num_blocks: int) -> None: else: self._allocation_map[request_id] += num_blocks - assert self._num_allocated_blocks <= self._replica_scheduler_config.num_blocks + assert self._num_allocated_blocks <= self._config.num_blocks def free(self, *request_ids: List[int]) -> None: for request_id in request_ids: diff --git a/vidur/scheduler/replica_scheduler/lightllm_replica_scheduler.py b/vidur/scheduler/replica_scheduler/lightllm_replica_scheduler.py index 3a6bfae..567524a 100644 --- a/vidur/scheduler/replica_scheduler/lightllm_replica_scheduler.py +++ b/vidur/scheduler/replica_scheduler/lightllm_replica_scheduler.py @@ -15,10 +15,10 @@ def __init__(self, *args, **kwargs): self._preempted_requests: List[Request] = [] self._num_running_batches = 0 self._max_micro_batch_size = ( - self._replica_scheduler_config.batch_size_cap // self._num_stages + self._config.batch_size_cap // self._num_stages ) assert ( - self._replica_scheduler_config.block_size == 1 + self._config.block_size == 1 ), "LightLLM scheduler only supports block size of 1." assert ( self._num_stages == 1 @@ -63,7 +63,7 @@ def _can_allocate_request(self, request: Request) -> bool: need_max_token_num = (left_out_len_array * size_array + cum_run_len_array).max() - return need_max_token_num < self._replica_scheduler_config.num_blocks + return need_max_token_num < self._config.num_blocks def _allocate_request(self, request: Request) -> None: if request.id not in self._allocation_map: @@ -86,10 +86,10 @@ def _get_prefill_batch(self) -> Batch: next_num_tokens = self._get_request_next_num_tokens(request) - if num_batch_tokens + next_num_tokens > self._replica_scheduler_config.max_tokens_in_batch: + if num_batch_tokens + next_num_tokens > self._config.max_tokens_in_batch: break - if len(self._allocation_map) == self._replica_scheduler_config.batch_size_cap: + if len(self._allocation_map) == self._config.batch_size_cap: break if len(requests) == self._max_micro_batch_size: @@ -142,7 +142,7 @@ def _get_next_batch(self) -> Batch: self._num_waiting_iters = 0 return batch - if self._num_waiting_iters >= self._replica_scheduler_config.max_waiting_iters: + if self._num_waiting_iters >= self._config.max_waiting_iters: self._num_waiting_iters = 0 batch = self._get_prefill_batch() if batch: diff --git a/vidur/scheduler/replica_scheduler/sarathi_replica_scheduler.py b/vidur/scheduler/replica_scheduler/sarathi_replica_scheduler.py index 2bedc89..68aaeb0 100644 --- a/vidur/scheduler/replica_scheduler/sarathi_replica_scheduler.py +++ b/vidur/scheduler/replica_scheduler/sarathi_replica_scheduler.py @@ -16,34 +16,34 @@ def __init__(self, *args, **kwargs): # For vLLM and its derivatives, we only need to set a loose max batch size # Memory requirements are handled explicitly by the scheduler self._max_micro_batch_size = ( - self._replica_scheduler_config.batch_size_cap // self._num_stages + self._config.batch_size_cap // self._num_stages ) self._watermark_blocks = int( - self._replica_scheduler_config.watermark_blocks_fraction * self._replica_scheduler_config.num_blocks + self._config.watermark_blocks_fraction * self._config.num_blocks ) def _can_allocate_request(self, request: Request) -> bool: 
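        # Admission check: a request not yet in the allocation map is new and must be
        # able to reserve ceil(num_prefill_tokens / block_size) KV-cache blocks while
        # still leaving at least self._watermark_blocks blocks free; a request that is
        # already running only needs one free block to keep growing its KV cache.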
if request.id not in self._allocation_map: # new request - num_required_blocks = ceil(request.num_prefill_tokens / self._replica_scheduler_config.block_size) + num_required_blocks = ceil(request.num_prefill_tokens / self._config.block_size) return ( - self._replica_scheduler_config.num_blocks + self._config.num_blocks - self._num_allocated_blocks - num_required_blocks >= self._watermark_blocks ) # vllm requires at least one block to be available - return self._replica_scheduler_config.num_blocks - self._num_allocated_blocks >= 1 + return self._config.num_blocks - self._num_allocated_blocks >= 1 def _allocate_request(self, request: Request) -> None: if request.id not in self._allocation_map: # new request - num_required_blocks = ceil(request.num_prefill_tokens / self._replica_scheduler_config.block_size) + num_required_blocks = ceil(request.num_prefill_tokens / self._config.block_size) self.allocate(request.id, num_required_blocks) return - num_tokens_reserved = self._allocation_map[request.id] * self._replica_scheduler_config.block_size + num_tokens_reserved = self._allocation_map[request.id] * self._config.block_size num_tokens_required = max(0, request.num_processed_tokens - num_tokens_reserved) assert ( @@ -74,7 +74,7 @@ def _get_request_next_num_tokens( next_num_tokens = min( request.num_prefill_tokens - request.num_processed_tokens, - self._replica_scheduler_config.chunk_size - num_batch_tokens, + self._config.chunk_size - num_batch_tokens, ) next_num_tokens = max(0, next_num_tokens) @@ -152,7 +152,7 @@ def _get_next_batch(self) -> Batch: skipped_requests = [] while self._request_queue: - if len(self._allocation_map) == self._replica_scheduler_config.batch_size_cap: + if len(self._allocation_map) == self._config.batch_size_cap: break if len(requests) == self._max_micro_batch_size: diff --git a/vidur/scheduler/replica_scheduler/vllm_replica_scheduler.py b/vidur/scheduler/replica_scheduler/vllm_replica_scheduler.py index 4063c57..a05043c 100644 --- a/vidur/scheduler/replica_scheduler/vllm_replica_scheduler.py +++ b/vidur/scheduler/replica_scheduler/vllm_replica_scheduler.py @@ -16,10 +16,10 @@ def __init__(self, *args, **kwargs): # For vLLM and its derivatives, we only need to set a loose max batch size # Memory requirements are handled explicitly by the scheduler self._max_micro_batch_size = ( - self._replica_scheduler_config.batch_size_cap // self._num_stages + self._config.batch_size_cap // self._num_stages ) self._watermark_blocks = int( - self._replica_scheduler_config.watermark_blocks_fraction * self._replica_scheduler_config.num_blocks + self._config.watermark_blocks_fraction * self._config.num_blocks ) def on_batch_end(self, batch: Batch) -> None: @@ -34,25 +34,25 @@ def on_batch_end(self, batch: Batch) -> None: def _can_allocate_request(self, request: Request) -> bool: if request.id not in self._allocation_map: # new request - num_required_blocks = ceil((request.num_prefill_tokens) / self._replica_scheduler_config.block_size) + num_required_blocks = ceil((request.num_prefill_tokens) / self._config.block_size) return ( - self._replica_scheduler_config.num_blocks + self._config.num_blocks - self._num_allocated_blocks - num_required_blocks >= self._watermark_blocks ) # vllm requires at least one block to be available - return self._replica_scheduler_config.num_blocks - self._num_allocated_blocks >= 1 + return self._config.num_blocks - self._num_allocated_blocks >= 1 def _allocate_request(self, request: Request) -> None: if request.id not in self._allocation_map: # new request - 
num_required_blocks = ceil((request.num_prefill_tokens) / self._replica_scheduler_config.block_size) + num_required_blocks = ceil((request.num_prefill_tokens) / self._config.block_size) self.allocate(request.id, num_required_blocks) return - num_tokens_reserved = self._allocation_map[request.id] * self._replica_scheduler_config.block_size + num_tokens_reserved = self._allocation_map[request.id] * self._config.block_size num_tokens_required = max(0, request.num_processed_tokens - num_tokens_reserved) assert ( num_tokens_required == 0 or num_tokens_required == 1 @@ -78,10 +78,10 @@ def _get_next_batch(self) -> Batch: new_num_tokens = num_tokens + [next_num_tokens] new_num_batch_tokens = len(new_num_tokens) * max(new_num_tokens) - if new_num_batch_tokens > self._replica_scheduler_config.max_tokens_in_batch: + if new_num_batch_tokens > self._config.max_tokens_in_batch: break - if len(self._allocation_map) == self._replica_scheduler_config.batch_size_cap: + if len(self._allocation_map) == self._config.batch_size_cap: break if len(requests) == self._max_micro_batch_size: From 4d1be0c6d85be96a5b5ceb234fb4135e0bcb289e Mon Sep 17 00:00:00 2001 From: Anmol Agarwal Date: Sun, 28 Jul 2024 02:11:48 -0400 Subject: [PATCH 22/24] shorten config --- .../global_scheduler/base_global_scheduler.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/vidur/scheduler/global_scheduler/base_global_scheduler.py b/vidur/scheduler/global_scheduler/base_global_scheduler.py index cae2d10..493a3e5 100644 --- a/vidur/scheduler/global_scheduler/base_global_scheduler.py +++ b/vidur/scheduler/global_scheduler/base_global_scheduler.py @@ -17,11 +17,11 @@ def __init__(self, config: SimulationConfig, replicas: Dict[int, Replica]): self._num_replicas = len(self._replicas) execution_time_predictor = ExecutionTimePredictorRegistry.get( - self._config.execution_time_predictor_config.get_type(), - predictor_config=self._config.execution_time_predictor_config, - replica_config=self._config.cluster_config.replica_config, - replica_scheduler_config=self._config.cluster_config.replica_scheduler_config, - metrics_config=self._config.metrics_config, + config.execution_time_predictor_config.get_type(), + predictor_config=config.execution_time_predictor_config, + replica_config=config.cluster_config.replica_config, + replica_scheduler_config=config.cluster_config.replica_scheduler_config, + metrics_config=config.metrics_config, ) self._replica_schedulers = { replica_id: ReplicaSchedulerRegistry.get( From e21a9353fe3f9998dea38470789add8b222da1d5 Mon Sep 17 00:00:00 2001 From: Anmol Agarwal Date: Sun, 28 Jul 2024 02:15:59 -0400 Subject: [PATCH 23/24] model name --- vidur/config/model_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vidur/config/model_config.py b/vidur/config/model_config.py index 462534e..5144c13 100644 --- a/vidur/config/model_config.py +++ b/vidur/config/model_config.py @@ -103,7 +103,7 @@ class Llama3_8BModelConfig(Llama2ModelConfig): @staticmethod def get_name(): - return "meta-llama/Meta-Llama-3-8b" + return "meta-llama/Meta-Llama-3-8B" @dataclass From 14db47bc0b8fc333ce729a33959096943737febb Mon Sep 17 00:00:00 2001 From: Anmol Agarwal Date: Sun, 28 Jul 2024 02:26:44 -0400 Subject: [PATCH 24/24] make format --- vidur/config/base_fixed_config.py | 2 +- vidur/config/config.py | 52 ++++++++---- vidur/config/device_sku_config.py | 1 - vidur/config/flat_dataclass.py | 8 +- vidur/config/model_config.py | 9 ++- vidur/config/node_sku_config.py | 2 +- vidur/config/utils.py | 
10 +-- vidur/entities/cluster.py | 9 ++- vidur/entities/replica.py | 24 ++++-- .../base_execution_time_predictor.py | 20 +++-- .../execution_time_predictor_registry.py | 1 + ...ear_regression_execution_time_predictor.py | 18 +++-- ...random_forrest_execution_time_predictor.py | 18 +++-- .../sklearn_execution_time_predictor.py | 81 ++++++++++++------- vidur/metrics/cdf_sketch.py | 2 +- vidur/metrics/data_series.py | 2 +- vidur/metrics/metrics_store.py | 8 +- vidur/metrics/series_average_meter.py | 1 + vidur/profiling/common/model_config.py | 2 +- .../synthetic_request_generator.py | 4 +- .../trace_replay_request_generator.py | 4 +- .../base_replica_scheduler.py | 6 +- .../lightllm_replica_scheduler.py | 4 +- .../sarathi_replica_scheduler.py | 12 +-- .../vllm_replica_scheduler.py | 12 +-- vidur/simulator.py | 12 ++- vidur/types/__init__.py | 9 +-- vidur/utils/param_counter.py | 19 ++++- 28 files changed, 224 insertions(+), 128 deletions(-) diff --git a/vidur/config/base_fixed_config.py b/vidur/config/base_fixed_config.py index 1ab355f..2d469b3 100644 --- a/vidur/config/base_fixed_config.py +++ b/vidur/config/base_fixed_config.py @@ -14,7 +14,7 @@ def create_from_type(cls, type_: Any) -> Any: if subclass.get_type() == type_: return subclass() raise ValueError(f"[{cls.__name__}] Invalid type: {type_}") - + @classmethod def create_from_name(cls, name: str) -> Any: for subclass in get_all_subclasses(cls): diff --git a/vidur/config/config.py b/vidur/config/config.py index 5d16399..70ac5da 100644 --- a/vidur/config/config.py +++ b/vidur/config/config.py @@ -1,9 +1,9 @@ +import json +import os from abc import ABC from dataclasses import dataclass, field from datetime import datetime -import json -import os -from typing import Optional, List +from typing import List, Optional from vidur.config.base_poly_config import BasePolyConfig from vidur.config.device_sku_config import BaseDeviceSKUConfig @@ -12,7 +12,14 @@ from vidur.config.node_sku_config import BaseNodeSKUConfig from vidur.config.utils import dataclass_to_dict from vidur.logger import init_logger -from vidur.types import ReplicaSchedulerType, GlobalSchedulerType, ExecutionTimePredictorType, RequestGeneratorType, RequestIntervalGeneratorType, RequestLengthGeneratorType +from vidur.types import ( + ExecutionTimePredictorType, + GlobalSchedulerType, + ReplicaSchedulerType, + RequestGeneratorType, + RequestIntervalGeneratorType, + RequestLengthGeneratorType, +) logger = init_logger(__name__) @@ -53,7 +60,9 @@ class TraceRequestIntervalGeneratorConfig(BaseRequestIntervalGeneratorConfig): ) time_scale_factor: float = field( default=0.3, - metadata={"help": "Time scale factor for the trace request interval generator."}, + metadata={ + "help": "Time scale factor for the trace request interval generator." + }, ) @staticmethod @@ -81,7 +90,9 @@ class GammaRequestIntervalGeneratorConfig(BaseRequestIntervalGeneratorConfig): ) cv: float = field( default=0.5, - metadata={"help": "Coefficient of variation for Gamma Request Interval Generator."}, + metadata={ + "help": "Coefficient of variation for Gamma Request Interval Generator." + }, ) @staticmethod @@ -104,11 +115,15 @@ class TraceRequestLengthGeneratorConfig(BaseRequestLengthGeneratorConfig): ) prefill_scale_factor: float = field( default=1, - metadata={"help": "Prefill scale factor for the trace request length generator."}, + metadata={ + "help": "Prefill scale factor for the trace request length generator." 
+ }, ) decode_scale_factor: float = field( default=1, - metadata={"help": "Decode scale factor for the trace request length generator."}, + metadata={ + "help": "Decode scale factor for the trace request length generator." + }, ) max_tokens: int = field( default=4096, @@ -160,7 +175,9 @@ class UniformRequestLengthGeneratorConfig(BaseRequestLengthGeneratorConfig): ) prefill_to_decode_ratio: float = field( default=20.0, - metadata={"help": "Prefill to decode ratio for Uniform Request Length Generator."}, + metadata={ + "help": "Prefill to decode ratio for Uniform Request Length Generator." + }, ) @staticmethod @@ -196,7 +213,6 @@ class BaseRequestGeneratorConfig(BasePolyConfig): ) - @dataclass class SyntheticRequestGeneratorConfig(BaseRequestGeneratorConfig): length_generator_config: BaseRequestLengthGeneratorConfig = field( @@ -473,9 +489,15 @@ class ReplicaConfig: def __post_init__(self): self.world_size = self.num_pipeline_stages * self.tensor_parallel_size - self.model_config: BaseModelConfig = BaseModelConfig.create_from_name(self.model_name) - self.device_config: BaseDeviceSKUConfig = BaseDeviceSKUConfig.create_from_type_string(self.device) - self.node_config: BaseNodeSKUConfig = BaseNodeSKUConfig.create_from_type_string(self.network_device) + self.model_config: BaseModelConfig = BaseModelConfig.create_from_name( + self.model_name + ) + self.device_config: BaseDeviceSKUConfig = ( + BaseDeviceSKUConfig.create_from_type_string(self.device) + ) + self.node_config: BaseNodeSKUConfig = BaseNodeSKUConfig.create_from_type_string( + self.network_device + ) @dataclass @@ -558,7 +580,7 @@ class BaseExecutionTimePredictorConfig(BasePolyConfig): default=0.1, metadata={"help": "Attention prefill batching overhead fraction."}, ) - nccl_cpu_launch_overhead_ms: float = field( + nccl_cpu_launch_overhead_ms: float = field( default=0.02, metadata={"help": "NCCL CPU launch overhead in ms."}, ) @@ -648,7 +670,7 @@ class SimulationConfig(ABC): metadata={"help": "Logging level."}, ) time_limit: int = field( - default=0, # in seconds, 0 is no limit + default=0, # in seconds, 0 is no limit metadata={"help": "Time limit for simulation in seconds. 
0 means no limit."}, ) cluster_config: ClusterConfig = field( diff --git a/vidur/config/device_sku_config.py b/vidur/config/device_sku_config.py index a220f5e..a92646f 100644 --- a/vidur/config/device_sku_config.py +++ b/vidur/config/device_sku_config.py @@ -41,4 +41,3 @@ class H100DeviceSKUConfig(BaseDeviceSKUConfig): @staticmethod def get_type(): return DeviceSKUType.H100 - diff --git a/vidur/config/flat_dataclass.py b/vidur/config/flat_dataclass.py index 525d200..fe23dac 100644 --- a/vidur/config/flat_dataclass.py +++ b/vidur/config/flat_dataclass.py @@ -1,5 +1,9 @@ import json -from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser, BooleanOptionalAction +from argparse import ( + ArgumentDefaultsHelpFormatter, + ArgumentParser, + BooleanOptionalAction, +) from collections import defaultdict, deque from dataclasses import MISSING, fields, make_dataclass from typing import Any, get_args @@ -114,7 +118,7 @@ def create_from_cli_args(cls) -> Any: arg_params["default"] = field.default_factory() else: arg_params["required"] = True - + if nargs: arg_params["nargs"] = nargs diff --git a/vidur/config/model_config.py b/vidur/config/model_config.py index 5144c13..0057d78 100644 --- a/vidur/config/model_config.py +++ b/vidur/config/model_config.py @@ -3,7 +3,7 @@ from vidur.config.base_fixed_config import BaseFixedConfig from vidur.logger import init_logger -from vidur.types import NormType, ActivationType +from vidur.types import ActivationType, NormType logger = init_logger(__name__) @@ -58,7 +58,7 @@ class CodeLlama34BModelConfig(Llama2ModelConfig): num_kv_heads: int = 8 embedding_dim: int = 8192 mlp_hidden_dim: int = 22016 - + @staticmethod def get_name(): return "codellama/CodeLlama-34b-Instruct-hf" @@ -71,7 +71,7 @@ class Llama2_7BModelConfig(Llama2ModelConfig): num_kv_heads: int = 32 embedding_dim: int = 4096 mlp_hidden_dim: int = 11008 - + @staticmethod def get_name(): return "meta-llama/Llama-2-7b-hf" @@ -84,7 +84,7 @@ class Llama2_70BModelConfig(Llama2ModelConfig): num_kv_heads: int = 8 embedding_dim: int = 8192 mlp_hidden_dim: int = 28672 - + @staticmethod def get_name(): return "meta-llama/Llama-2-70b-hf" @@ -121,6 +121,7 @@ class Llama3_70BModelConfig(Llama2ModelConfig): def get_name(): return "meta-llama/Meta-Llama-3-70B" + @dataclass class InternLM2ModelConfig(Llama2ModelConfig): max_position_embeddings: int = 32768 diff --git a/vidur/config/node_sku_config.py b/vidur/config/node_sku_config.py index 43dcdd9..34eb805 100644 --- a/vidur/config/node_sku_config.py +++ b/vidur/config/node_sku_config.py @@ -2,7 +2,7 @@ from vidur.config.base_fixed_config import BaseFixedConfig from vidur.logger import init_logger -from vidur.types import NodeSKUType, DeviceSKUType +from vidur.types import DeviceSKUType, NodeSKUType logger = init_logger(__name__) diff --git a/vidur/config/utils.py b/vidur/config/utils.py index 975caf5..8627dd3 100644 --- a/vidur/config/utils.py +++ b/vidur/config/utils.py @@ -1,4 +1,4 @@ -from dataclasses import is_dataclass, fields +from dataclasses import fields, is_dataclass from typing import Union, get_args, get_origin primitive_types = {int, str, float, bool, type(None)} @@ -78,10 +78,10 @@ def dataclass_to_dict(obj): if key not in data: data[key] = dataclass_to_dict(value) # Include the name of the class - if hasattr(obj, 'get_type') and callable(getattr(obj, 'get_type')): - data['name'] = str(obj.get_type()) - elif hasattr(obj, 'get_name') and callable(getattr(obj, 'get_name')): - data['name'] = obj.get_name() + if hasattr(obj, "get_type") and 
callable(getattr(obj, "get_type")): + data["name"] = str(obj.get_type()) + elif hasattr(obj, "get_name") and callable(getattr(obj, "get_name")): + data["name"] = obj.get_name() return data else: return obj diff --git a/vidur/entities/cluster.py b/vidur/entities/cluster.py index fa23171..013b86d 100644 --- a/vidur/entities/cluster.py +++ b/vidur/entities/cluster.py @@ -1,6 +1,6 @@ import json -from vidur.config import ClusterConfig, MetricsConfig, BaseRequestGeneratorConfig +from vidur.config import BaseRequestGeneratorConfig, ClusterConfig, MetricsConfig from vidur.entities.base_entity import BaseEntity from vidur.entities.replica import Replica from vidur.logger import init_logger @@ -9,7 +9,12 @@ class Cluster(BaseEntity): - def __init__(self, cluster_config: ClusterConfig, metrics_config: MetricsConfig, generator_config: BaseRequestGeneratorConfig) -> None: + def __init__( + self, + cluster_config: ClusterConfig, + metrics_config: MetricsConfig, + generator_config: BaseRequestGeneratorConfig, + ) -> None: self._id = Cluster.generate_id() self._config = cluster_config diff --git a/vidur/entities/replica.py b/vidur/entities/replica.py index 7b8425f..bda2293 100644 --- a/vidur/entities/replica.py +++ b/vidur/entities/replica.py @@ -1,6 +1,6 @@ from math import ceil -from vidur.config import ReplicaConfig, BaseRequestGeneratorConfig +from vidur.config import BaseRequestGeneratorConfig, ReplicaConfig from vidur.entities.base_entity import BaseEntity from vidur.logger import init_logger @@ -8,7 +8,11 @@ class Replica(BaseEntity): - def __init__(self, replica_config: ReplicaConfig, generator_config: BaseRequestGeneratorConfig) -> None: + def __init__( + self, + replica_config: ReplicaConfig, + generator_config: BaseRequestGeneratorConfig, + ) -> None: self._id = Replica.generate_id() self._replica_config = replica_config @@ -16,9 +20,13 @@ def __init__(self, replica_config: ReplicaConfig, generator_config: BaseRequestG self._device_config = replica_config.device_config self._generator_config = generator_config - assert self._model_config.num_layers % self._replica_config.num_pipeline_stages == 0 assert ( - self._model_config.embedding_dim % self._replica_config.tensor_parallel_size == 0 + self._model_config.num_layers % self._replica_config.num_pipeline_stages + == 0 + ) + assert ( + self._model_config.embedding_dim % self._replica_config.tensor_parallel_size + == 0 ) @property @@ -67,11 +75,15 @@ def attention_head_dim(self) -> int: @property def q_heads_per_tensor_parallel_worker(self) -> int: - return self._model_config.num_q_heads // self._replica_config.tensor_parallel_size + return ( + self._model_config.num_q_heads // self._replica_config.tensor_parallel_size + ) @property def kv_heads_per_tensor_parallel_worker(self) -> int: - return ceil(self._model_config.num_kv_heads / self._replica_config.tensor_parallel_size) + return ceil( + self._model_config.num_kv_heads / self._replica_config.tensor_parallel_size + ) @property def num_tensor_parallel_workers(self) -> int: diff --git a/vidur/execution_time_predictor/base_execution_time_predictor.py b/vidur/execution_time_predictor/base_execution_time_predictor.py index 0aafba3..f399c8e 100644 --- a/vidur/execution_time_predictor/base_execution_time_predictor.py +++ b/vidur/execution_time_predictor/base_execution_time_predictor.py @@ -2,19 +2,21 @@ from vidur.config import ( BaseExecutionTimePredictorConfig, - ReplicaConfig, BaseReplicaSchedulerConfig, - MetricsConfig + MetricsConfig, + ReplicaConfig, ) from vidur.entities import Batch, 
ExecutionTime class BaseExecutionTimePredictor(ABC): - def __init__(self, - predictor_config: BaseExecutionTimePredictorConfig, - replica_config: ReplicaConfig, - replica_scheduler_config: BaseReplicaSchedulerConfig, - metrics_config: MetricsConfig) -> None: + def __init__( + self, + predictor_config: BaseExecutionTimePredictorConfig, + replica_config: ReplicaConfig, + replica_scheduler_config: BaseReplicaSchedulerConfig, + metrics_config: MetricsConfig, + ) -> None: self._config = predictor_config self._replica_config = replica_config self._model_config = replica_config.model_config @@ -23,7 +25,9 @@ def __init__(self, self._replica_scheduler_provider = str(replica_scheduler_config.get_type()) self._block_size = replica_scheduler_config.block_size self._cache_dir = metrics_config.cache_dir - self._num_layers_per_pipeline_stage = self._model_config.num_layers // self._replica_config.num_pipeline_stages + self._num_layers_per_pipeline_stage = ( + self._model_config.num_layers // self._replica_config.num_pipeline_stages + ) def get_execution_time(self, batch: Batch, pipeline_stage: int) -> ExecutionTime: if pipeline_stage == self._replica_config.num_pipeline_stages - 1: diff --git a/vidur/execution_time_predictor/execution_time_predictor_registry.py b/vidur/execution_time_predictor/execution_time_predictor_registry.py index 71b3221..a48c110 100644 --- a/vidur/execution_time_predictor/execution_time_predictor_registry.py +++ b/vidur/execution_time_predictor/execution_time_predictor_registry.py @@ -13,6 +13,7 @@ class ExecutionTimePredictorRegistry(BaseRegistry): def get_key_from_str(cls, key_str: str) -> ExecutionTimePredictorType: return ExecutionTimePredictorType.from_str(key_str) + ExecutionTimePredictorRegistry.register( ExecutionTimePredictorType.RANDOM_FORREST, RandomForrestExecutionTimePredictor ) diff --git a/vidur/execution_time_predictor/linear_regression_execution_time_predictor.py b/vidur/execution_time_predictor/linear_regression_execution_time_predictor.py index 8f9f261..8dd32b7 100644 --- a/vidur/execution_time_predictor/linear_regression_execution_time_predictor.py +++ b/vidur/execution_time_predictor/linear_regression_execution_time_predictor.py @@ -3,10 +3,10 @@ from sklearn.preprocessing import PolynomialFeatures from vidur.config import ( + BaseReplicaSchedulerConfig, LinearRegressionExecutionTimePredictorConfig, + MetricsConfig, ReplicaConfig, - BaseReplicaSchedulerConfig, - MetricsConfig ) from vidur.execution_time_predictor.sklearn_execution_time_predictor import ( SklearnExecutionTimePredictor, @@ -14,17 +14,19 @@ class LinearRegressionExecutionTimePredictor(SklearnExecutionTimePredictor): - def __init__(self, - predictor_config: LinearRegressionExecutionTimePredictorConfig, - replica_config: ReplicaConfig, - replica_scheduler_config: BaseReplicaSchedulerConfig, - metrics_config: MetricsConfig) -> None: + def __init__( + self, + predictor_config: LinearRegressionExecutionTimePredictorConfig, + replica_config: ReplicaConfig, + replica_scheduler_config: BaseReplicaSchedulerConfig, + metrics_config: MetricsConfig, + ) -> None: # will trigger model training super().__init__( predictor_config=predictor_config, replica_config=replica_config, replica_scheduler_config=replica_scheduler_config, - metrics_config=metrics_config + metrics_config=metrics_config, ) def _get_grid_search_params(self): diff --git a/vidur/execution_time_predictor/random_forrest_execution_time_predictor.py b/vidur/execution_time_predictor/random_forrest_execution_time_predictor.py index 7e8f232..27fd748 
100644 --- a/vidur/execution_time_predictor/random_forrest_execution_time_predictor.py +++ b/vidur/execution_time_predictor/random_forrest_execution_time_predictor.py @@ -1,10 +1,10 @@ from sklearn.ensemble import RandomForestRegressor from vidur.config import ( + BaseReplicaSchedulerConfig, + MetricsConfig, RandomForrestExecutionTimePredictorConfig, ReplicaConfig, - BaseReplicaSchedulerConfig, - MetricsConfig ) from vidur.execution_time_predictor.sklearn_execution_time_predictor import ( SklearnExecutionTimePredictor, @@ -12,17 +12,19 @@ class RandomForrestExecutionTimePredictor(SklearnExecutionTimePredictor): - def __init__(self, - predictor_config: RandomForrestExecutionTimePredictorConfig, - replica_config: ReplicaConfig, - replica_scheduler_config: BaseReplicaSchedulerConfig, - metrics_config: MetricsConfig) -> None: + def __init__( + self, + predictor_config: RandomForrestExecutionTimePredictorConfig, + replica_config: ReplicaConfig, + replica_scheduler_config: BaseReplicaSchedulerConfig, + metrics_config: MetricsConfig, + ) -> None: # will trigger model training super().__init__( predictor_config=predictor_config, replica_config=replica_config, replica_scheduler_config=replica_scheduler_config, - metrics_config=metrics_config + metrics_config=metrics_config, ) def _get_grid_search_params(self): diff --git a/vidur/execution_time_predictor/sklearn_execution_time_predictor.py b/vidur/execution_time_predictor/sklearn_execution_time_predictor.py index f936288..a5a9646 100644 --- a/vidur/execution_time_predictor/sklearn_execution_time_predictor.py +++ b/vidur/execution_time_predictor/sklearn_execution_time_predictor.py @@ -14,9 +14,9 @@ from vidur.config import ( BaseExecutionTimePredictorConfig, - ReplicaConfig, BaseReplicaSchedulerConfig, - MetricsConfig + MetricsConfig, + ReplicaConfig, ) from vidur.entities import Batch from vidur.execution_time_predictor.base_execution_time_predictor import ( @@ -28,40 +28,44 @@ class SklearnExecutionTimePredictor(BaseExecutionTimePredictor): - def __init__(self, - predictor_config: BaseExecutionTimePredictorConfig, - replica_config: ReplicaConfig, - replica_scheduler_config: BaseReplicaSchedulerConfig, - metrics_config: MetricsConfig) -> None: + def __init__( + self, + predictor_config: BaseExecutionTimePredictorConfig, + replica_config: ReplicaConfig, + replica_scheduler_config: BaseReplicaSchedulerConfig, + metrics_config: MetricsConfig, + ) -> None: super().__init__( predictor_config=predictor_config, replica_config=replica_config, replica_scheduler_config=replica_scheduler_config, - metrics_config=metrics_config + metrics_config=metrics_config, ) os.makedirs(self._cache_dir, exist_ok=True) # These overheads are only for GQA models self._attention_prefill_batching_overhead_fraction = ( - ( - self._config.attention_prefill_batching_overhead_fraction - ) + (self._config.attention_prefill_batching_overhead_fraction) if self._model_config.num_q_heads > self._model_config.num_kv_heads else 0 ) self._attention_decode_batching_overhead_fraction = ( - ( - self._config.attention_decode_batching_overhead_fraction - ) + (self._config.attention_decode_batching_overhead_fraction) if self._model_config.num_q_heads > self._model_config.num_kv_heads else 0 ) if self._replica_scheduler_provider == "orca": - self._max_tokens = self._config.prediction_max_tokens_per_request * self._config.prediction_max_batch_size + self._max_tokens = ( + self._config.prediction_max_tokens_per_request + * self._config.prediction_max_batch_size + ) else: self._max_tokens = 
self._config.prediction_max_tokens_per_request - num_workers = self._replica_config.num_pipeline_stages * self._replica_config.tensor_parallel_size + num_workers = ( + self._replica_config.num_pipeline_stages + * self._replica_config.tensor_parallel_size + ) devices_per_node = self._replica_config.node_config.num_devices_per_node assert ( num_workers < devices_per_node or num_workers % devices_per_node == 0 @@ -74,7 +78,7 @@ def __init__(self, self._attention_input_file, self._all_reduce_input_file, self._send_recv_input_file, - self._cpu_overhead_input_file + self._cpu_overhead_input_file, ) = self._get_input_files() self._models = self._train_models() @@ -89,10 +93,12 @@ def _get_input_files(self) -> Tuple[str, str, str, str, str]: self._config.cpu_overhead_input_file, ] for i in range(len(input_files)): - input_files[i] = input_files[i].replace( - "{DEVICE}", self._replica_config.device).replace( - "{MODEL}", self._model_config.get_name()).replace( - "{NETWORK_DEVICE}", self._replica_config.network_device) + input_files[i] = ( + input_files[i] + .replace("{DEVICE}", self._replica_config.device) + .replace("{MODEL}", self._model_config.get_name()) + .replace("{NETWORK_DEVICE}", self._replica_config.network_device) + ) return tuple(input_files) @@ -117,7 +123,10 @@ def _load_compute_df(self, file_path: str) -> pd.DataFrame: & (df["n_expanded_embd"] == self._model_config.mlp_hidden_dim) & (df["use_gated_mlp"] == self._model_config.use_gated_mlp) & (df["vocab_size"] == self._model_config.vocab_size) - & (df["num_tensor_parallel_workers"] == self._replica_config.tensor_parallel_size) + & ( + df["num_tensor_parallel_workers"] + == self._replica_config.tensor_parallel_size + ) ] for column in [ @@ -148,7 +157,10 @@ def _load_attention_df(self, file_path: str) -> pd.DataFrame: & (df["n_q_head"] == self._model_config.num_q_heads) & (df["n_kv_head"] == self._model_config.num_kv_heads) & (df["block_size"] == self._block_size) - & (df["num_tensor_parallel_workers"] == self._replica_config.tensor_parallel_size) + & ( + df["num_tensor_parallel_workers"] + == self._replica_config.tensor_parallel_size + ) ] def _load_all_reduce_df(self, file_path: str) -> pd.DataFrame: @@ -176,7 +188,10 @@ def _load_cpu_overhead_df(self, file_path: str) -> pd.DataFrame: df = self._read_input_file(file_path) filtered_df = df[ (df["model_name"] == self._model_config.get_name()) - & (df["tensor_parallel_degree"] == self._replica_config.tensor_parallel_size) + & ( + df["tensor_parallel_degree"] + == self._replica_config.tensor_parallel_size + ) ] return filtered_df @@ -627,9 +642,13 @@ def _predict_for_cpu_overhead_models(self) -> Dict[str, Any]: def _predict_for_attention_layer_models(self) -> Dict[str, Any]: predictions = {} - decode_batch_size_range = np.arange(1, self._config.prediction_max_batch_size + 1) + decode_batch_size_range = np.arange( + 1, self._config.prediction_max_batch_size + 1 + ) decode_kv_cache_size_range = np.arange( - 0, self._config.prediction_max_tokens_per_request + 1, self._config.kv_cache_prediction_granularity + 0, + self._config.prediction_max_tokens_per_request + 1, + self._config.kv_cache_prediction_granularity, ) decode_prefill_chunk_size_range = [0] decode_batch_size, decode_kv_cache_size, decode_prefill_chunk_size = zip( @@ -642,7 +661,9 @@ def _predict_for_attention_layer_models(self) -> Dict[str, Any]: prefill_batch_size_range = [1] prefill_kv_cache_size_range = np.arange( - 0, self._config.prediction_max_tokens_per_request + 1, self._config.kv_cache_prediction_granularity + 0, + 
self._config.prediction_max_tokens_per_request + 1, + self._config.kv_cache_prediction_granularity, ) prefill_prefill_chunk_size_range = np.arange( 1, self._config.prediction_max_prefill_chunk_size + 1 @@ -718,7 +739,11 @@ def _get_batch_decode_attention_params(self, batch: Batch) -> Tuple[int, int]: decode_batch_size = len(decode_kv_cache_sizes) decode_avg_kv_cache_size = int(np.mean(decode_kv_cache_sizes)) decode_avg_kv_cache_size = ( - (decode_avg_kv_cache_size + self._config.kv_cache_prediction_granularity - 1) + ( + decode_avg_kv_cache_size + + self._config.kv_cache_prediction_granularity + - 1 + ) // self._config.kv_cache_prediction_granularity ) * self._config.kv_cache_prediction_granularity diff --git a/vidur/metrics/cdf_sketch.py b/vidur/metrics/cdf_sketch.py index 3083ce1..50aeebf 100644 --- a/vidur/metrics/cdf_sketch.py +++ b/vidur/metrics/cdf_sketch.py @@ -1,9 +1,9 @@ import numpy as np import pandas as pd import plotly_express as px +import wandb from ddsketch.ddsketch import DDSketch -import wandb from vidur.logger import init_logger logger = init_logger(__name__) diff --git a/vidur/metrics/data_series.py b/vidur/metrics/data_series.py index 7770496..51be848 100644 --- a/vidur/metrics/data_series.py +++ b/vidur/metrics/data_series.py @@ -4,8 +4,8 @@ import numpy as np import pandas as pd import plotly_express as px - import wandb + from vidur.logger import init_logger logger = init_logger(__name__) diff --git a/vidur/metrics/metrics_store.py b/vidur/metrics/metrics_store.py index c41a954..c1e76b0 100644 --- a/vidur/metrics/metrics_store.py +++ b/vidur/metrics/metrics_store.py @@ -6,7 +6,7 @@ import plotly_express as px import wandb -from vidur.config import MetricsConfig, ClusterConfig +from vidur.config import ClusterConfig, MetricsConfig from vidur.entities import Batch, BatchStage, ExecutionTime, Request from vidur.logger import init_logger from vidur.metrics.cdf_sketch import CDFSketch @@ -632,9 +632,9 @@ def _push_metric( def on_batch_end( self, time: float, batch: Batch, replica_id: int, memory_usage_percent: int ) -> None: - if (self._config.min_batch_index and batch.id < self._config.min_batch_index) or ( - self._config.max_batch_index and batch.id > self._config.max_batch_index - ): + if ( + self._config.min_batch_index and batch.id < self._config.min_batch_index + ) or (self._config.max_batch_index and batch.id > self._config.max_batch_index): return for request in batch.completed_requests: diff --git a/vidur/metrics/series_average_meter.py b/vidur/metrics/series_average_meter.py index 92b0498..8f6679e 100644 --- a/vidur/metrics/series_average_meter.py +++ b/vidur/metrics/series_average_meter.py @@ -1,6 +1,7 @@ import json import wandb + from vidur.logger import init_logger logger = init_logger(__name__) diff --git a/vidur/profiling/common/model_config.py b/vidur/profiling/common/model_config.py index 057717e..7ee132b 100644 --- a/vidur/profiling/common/model_config.py +++ b/vidur/profiling/common/model_config.py @@ -1,7 +1,7 @@ +from dataclasses import asdict from typing import Any, Dict, Optional import torch -from dataclasses import asdict from sarathi.config import ParallelConfig from vidur.config.model_config import BaseModelConfig diff --git a/vidur/request_generator/synthetic_request_generator.py b/vidur/request_generator/synthetic_request_generator.py index 596f768..46ebcb4 100644 --- a/vidur/request_generator/synthetic_request_generator.py +++ b/vidur/request_generator/synthetic_request_generator.py @@ -2,9 +2,7 @@ from vidur.config import 
SyntheticRequestGeneratorConfig from vidur.entities import Request -from vidur.request_generator.base_request_generator import ( - BaseRequestGenerator, -) +from vidur.request_generator.base_request_generator import BaseRequestGenerator from vidur.request_generator.request_interval_generator_registry import ( RequestIntervalGeneratorRegistry, ) diff --git a/vidur/request_generator/trace_replay_request_generator.py b/vidur/request_generator/trace_replay_request_generator.py index 971f2da..5ef9397 100644 --- a/vidur/request_generator/trace_replay_request_generator.py +++ b/vidur/request_generator/trace_replay_request_generator.py @@ -5,9 +5,7 @@ from vidur.config import TraceRequestGeneratorConfig from vidur.entities import Request -from vidur.request_generator.base_request_generator import ( - BaseRequestGenerator, -) +from vidur.request_generator.base_request_generator import BaseRequestGenerator logger = logging.getLogger(__name__) diff --git a/vidur/scheduler/replica_scheduler/base_replica_scheduler.py b/vidur/scheduler/replica_scheduler/base_replica_scheduler.py index 7f9b321..2db5c33 100644 --- a/vidur/scheduler/replica_scheduler/base_replica_scheduler.py +++ b/vidur/scheduler/replica_scheduler/base_replica_scheduler.py @@ -1,7 +1,11 @@ from abc import ABC, abstractmethod from typing import List -from vidur.config import BaseRequestGeneratorConfig, ReplicaConfig, BaseReplicaSchedulerConfig +from vidur.config import ( + BaseReplicaSchedulerConfig, + BaseRequestGeneratorConfig, + ReplicaConfig, +) from vidur.entities import Batch, Replica, Request from vidur.execution_time_predictor import BaseExecutionTimePredictor from vidur.logger import init_logger diff --git a/vidur/scheduler/replica_scheduler/lightllm_replica_scheduler.py b/vidur/scheduler/replica_scheduler/lightllm_replica_scheduler.py index 567524a..34263bf 100644 --- a/vidur/scheduler/replica_scheduler/lightllm_replica_scheduler.py +++ b/vidur/scheduler/replica_scheduler/lightllm_replica_scheduler.py @@ -14,9 +14,7 @@ def __init__(self, *args, **kwargs): self._preempted_requests: List[Request] = [] self._num_running_batches = 0 - self._max_micro_batch_size = ( - self._config.batch_size_cap // self._num_stages - ) + self._max_micro_batch_size = self._config.batch_size_cap // self._num_stages assert ( self._config.block_size == 1 ), "LightLLM scheduler only supports block size of 1." 
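
The LightLLM hunk above and the Sarathi/vLLM scheduler hunks that follow all read block_size, num_blocks, batch_size_cap, and watermark_blocks_fraction directly from the flattened self._config. A minimal standalone sketch of that block-accounting logic, under the assumption that it mirrors the diff; the class and function names here are hypothetical and not part of this patch:

from dataclasses import dataclass
from math import ceil


@dataclass
class SchedulerConfig:
    # Illustrative stand-in for the flattened replica scheduler config;
    # the field names mirror the ones read via self._config in the hunks.
    block_size: int = 16
    num_blocks: int = 4096
    batch_size_cap: int = 128
    watermark_blocks_fraction: float = 0.01


def max_micro_batch_size(config: SchedulerConfig, num_stages: int) -> int:
    # Loose per-stage cap, as in the vLLM/Sarathi-style schedulers:
    # memory is accounted for separately via block allocation.
    return config.batch_size_cap // num_stages


def can_admit_new_request(
    config: SchedulerConfig, num_allocated_blocks: int, num_prefill_tokens: int
) -> bool:
    # A new request reserves enough KV-cache blocks for its whole prefill,
    # while keeping a small watermark of blocks free.
    num_required_blocks = ceil(num_prefill_tokens / config.block_size)
    watermark_blocks = int(config.watermark_blocks_fraction * config.num_blocks)
    free_blocks = config.num_blocks - num_allocated_blocks
    return free_blocks - num_required_blocks >= watermark_blocks


if __name__ == "__main__":
    cfg = SchedulerConfig()
    print(max_micro_batch_size(cfg, num_stages=4))  # 32
    # 96 free blocks, 128 needed, 40 kept as watermark -> not admitted
    print(can_admit_new_request(cfg, num_allocated_blocks=4000, num_prefill_tokens=2048))
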
diff --git a/vidur/scheduler/replica_scheduler/sarathi_replica_scheduler.py b/vidur/scheduler/replica_scheduler/sarathi_replica_scheduler.py index 68aaeb0..c033a18 100644 --- a/vidur/scheduler/replica_scheduler/sarathi_replica_scheduler.py +++ b/vidur/scheduler/replica_scheduler/sarathi_replica_scheduler.py @@ -15,9 +15,7 @@ def __init__(self, *args, **kwargs): self._preempted_requests = [] # For vLLM and its derivatives, we only need to set a loose max batch size # Memory requirements are handled explicitly by the scheduler - self._max_micro_batch_size = ( - self._config.batch_size_cap // self._num_stages - ) + self._max_micro_batch_size = self._config.batch_size_cap // self._num_stages self._watermark_blocks = int( self._config.watermark_blocks_fraction * self._config.num_blocks ) @@ -25,7 +23,9 @@ def __init__(self, *args, **kwargs): def _can_allocate_request(self, request: Request) -> bool: if request.id not in self._allocation_map: # new request - num_required_blocks = ceil(request.num_prefill_tokens / self._config.block_size) + num_required_blocks = ceil( + request.num_prefill_tokens / self._config.block_size + ) return ( self._config.num_blocks - self._num_allocated_blocks @@ -39,7 +39,9 @@ def _can_allocate_request(self, request: Request) -> bool: def _allocate_request(self, request: Request) -> None: if request.id not in self._allocation_map: # new request - num_required_blocks = ceil(request.num_prefill_tokens / self._config.block_size) + num_required_blocks = ceil( + request.num_prefill_tokens / self._config.block_size + ) self.allocate(request.id, num_required_blocks) return diff --git a/vidur/scheduler/replica_scheduler/vllm_replica_scheduler.py b/vidur/scheduler/replica_scheduler/vllm_replica_scheduler.py index a05043c..eaa871c 100644 --- a/vidur/scheduler/replica_scheduler/vllm_replica_scheduler.py +++ b/vidur/scheduler/replica_scheduler/vllm_replica_scheduler.py @@ -15,9 +15,7 @@ def __init__(self, *args, **kwargs): self._num_running_batches = 0 # For vLLM and its derivatives, we only need to set a loose max batch size # Memory requirements are handled explicitly by the scheduler - self._max_micro_batch_size = ( - self._config.batch_size_cap // self._num_stages - ) + self._max_micro_batch_size = self._config.batch_size_cap // self._num_stages self._watermark_blocks = int( self._config.watermark_blocks_fraction * self._config.num_blocks ) @@ -34,7 +32,9 @@ def on_batch_end(self, batch: Batch) -> None: def _can_allocate_request(self, request: Request) -> bool: if request.id not in self._allocation_map: # new request - num_required_blocks = ceil((request.num_prefill_tokens) / self._config.block_size) + num_required_blocks = ceil( + (request.num_prefill_tokens) / self._config.block_size + ) return ( self._config.num_blocks - self._num_allocated_blocks @@ -48,7 +48,9 @@ def _can_allocate_request(self, request: Request) -> bool: def _allocate_request(self, request: Request) -> None: if request.id not in self._allocation_map: # new request - num_required_blocks = ceil((request.num_prefill_tokens) / self._config.block_size) + num_required_blocks = ceil( + (request.num_prefill_tokens) / self._config.block_size + ) self.allocate(request.id, num_required_blocks) return diff --git a/vidur/simulator.py b/vidur/simulator.py index 4628662..f20c7f8 100644 --- a/vidur/simulator.py +++ b/vidur/simulator.py @@ -29,8 +29,14 @@ def __init__(self, config: SimulationConfig) -> None: self._event_trace = [] self._event_chrome_trace = [] - self._cluster = Cluster(self._config.cluster_config, 
self._config.metrics_config, self._config.request_generator_config) - self._metric_store = MetricsStore(self._config.metrics_config, self._config.cluster_config) + self._cluster = Cluster( + self._config.cluster_config, + self._config.metrics_config, + self._config.request_generator_config, + ) + self._metric_store = MetricsStore( + self._config.metrics_config, self._config.cluster_config + ) self._request_generator = RequestGeneratorRegistry.get( self._config.request_generator_config.get_type(), self._config.request_generator_config, @@ -38,7 +44,7 @@ def __init__(self, config: SimulationConfig) -> None: self._scheduler = GlobalSchedulerRegistry.get( self._config.cluster_config.global_scheduler_config.get_type(), self._config, - self._cluster.replicas + self._cluster.replicas, ) self._init_event_queue() diff --git a/vidur/types/__init__.py b/vidur/types/__init__.py index f21104d..da67d95 100644 --- a/vidur/types/__init__.py +++ b/vidur/types/__init__.py @@ -1,16 +1,15 @@ +from vidur.types.activation_type import ActivationType from vidur.types.base_int_enum import BaseIntEnum +from vidur.types.device_sku_type import DeviceSKUType from vidur.types.event_type import EventType from vidur.types.execution_time_predictor_type import ExecutionTimePredictorType from vidur.types.global_scheduler_type import GlobalSchedulerType +from vidur.types.node_sku_type import NodeSKUType +from vidur.types.norm_type import NormType from vidur.types.replica_scheduler_type import ReplicaSchedulerType from vidur.types.request_generator_type import RequestGeneratorType from vidur.types.request_interval_generator_type import RequestIntervalGeneratorType from vidur.types.request_length_generator_type import RequestLengthGeneratorType -from vidur.types.device_sku_type import DeviceSKUType -from vidur.types.node_sku_type import NodeSKUType -from vidur.types.norm_type import NormType -from vidur.types.activation_type import ActivationType - __all__ = [ EventType, diff --git a/vidur/utils/param_counter.py b/vidur/utils/param_counter.py index fbc961a..5ef348f 100644 --- a/vidur/utils/param_counter.py +++ b/vidur/utils/param_counter.py @@ -8,15 +8,26 @@ def __init__(self, replica_config: ReplicaConfig) -> None: self._replica_config = replica_config self._model_config = self._replica_config.model_config - assert self._model_config.num_q_heads % self._replica_config.tensor_parallel_size == 0 - assert self._model_config.num_layers % self._replica_config.num_pipeline_stages == 0 - assert self._model_config.embedding_dim % self._replica_config.tensor_parallel_size == 0 + assert ( + self._model_config.num_q_heads % self._replica_config.tensor_parallel_size + == 0 + ) + assert ( + self._model_config.num_layers % self._replica_config.num_pipeline_stages + == 0 + ) + assert ( + self._model_config.embedding_dim % self._replica_config.tensor_parallel_size + == 0 + ) assert self._model_config.embedding_dim % self._model_config.num_q_heads == 0 self._num_layers_per_pipeline_stage = ( self._model_config.num_layers // self._replica_config.num_pipeline_stages ) - self._attention_head_dim = self._model_config.embedding_dim // self._model_config.num_q_heads + self._attention_head_dim = ( + self._model_config.embedding_dim // self._model_config.num_q_heads + ) self._q_heads_per_tensor_parallel_worker = ( self._model_config.num_q_heads // self._replica_config.tensor_parallel_size )
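
The reformatted ParamCounter.__init__ above keeps the same divisibility assertions before deriving per-worker shard sizes. A small worked example of that arithmetic, using the Llama-2-70B shapes from model_config.py; the helper below is illustrative only and not code from this patch:

from math import ceil


def shard_summary(
    num_layers: int,
    num_q_heads: int,
    num_kv_heads: int,
    embedding_dim: int,
    tensor_parallel_size: int,
    num_pipeline_stages: int,
) -> dict:
    # Mirrors the divisibility checks in ParamCounter/Replica before
    # deriving the per-stage and per-worker shard sizes.
    assert num_q_heads % tensor_parallel_size == 0
    assert num_layers % num_pipeline_stages == 0
    assert embedding_dim % tensor_parallel_size == 0
    assert embedding_dim % num_q_heads == 0
    return {
        "layers_per_stage": num_layers // num_pipeline_stages,
        "attention_head_dim": embedding_dim // num_q_heads,
        "q_heads_per_worker": num_q_heads // tensor_parallel_size,
        # GQA models can have fewer KV heads than TP workers, hence ceil.
        "kv_heads_per_worker": ceil(num_kv_heads / tensor_parallel_size),
    }


# Llama-2-70B shapes on TP=8, PP=4:
# {'layers_per_stage': 20, 'attention_head_dim': 128,
#  'q_heads_per_worker': 8, 'kv_heads_per_worker': 1}
print(shard_summary(80, 64, 8, 8192, tensor_parallel_size=8, num_pipeline_stages=4))
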