diff --git a/simdistserve/README.md b/simdistserve/README.md
index c31cbcd..84c1c6c 100644
--- a/simdistserve/README.md
+++ b/simdistserve/README.md
@@ -48,7 +48,30 @@ Ideally you should get the following result:
 Best per GPU rate: 1.56
 Best config: pp_cross=1, tp_prefill=2, pp_prefill=1, tp_decode=1, pp_decode=1
 ```
-
+### Ratio search
+Given the parallelism strategy of the prefill and decode instances, search for the best prefill:decode instance ratio M:N.
+```bash
+python -m simdistserve.simulate_ratio \
+    --prefill-tp 8 \
+    --prefill-pp 1 \
+    --decode-tp 8 \
+    --decode-pp 1 \
+    --max-prefill-instances 8 \
+    --max-decode-instances 8 \
+    --kv-cache-mem-per-gpu 64 \
+    --kv-transfer-bw 600 \
+    --model-type "facebook/opt-66b" \
+    --workload sharegpt --backend distserve \
+    --prefill-target 200 --decode-target 100 \
+    --prefill-percentage 90 --decode-percentage 90 \
+    --max-per-gpu-rate 5 \
+    --esp 0.25 \
+    --N 300
+```
+Output:
+```text
+Best config: prefill_instance=8, decode_instance=3, per_gpu_rate=2.1875
+```
 ## Architecture
 The simulator is written on top of `simpy`, a discrete event simulator built natively in Python.
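For programmatic use, the same search is exposed as `simulate_bisect_ratio_search` (added in `simdistserve/benchmarks/parallel_ratio_bisect.py` below); it returns a dict mapping `(prefill_instance, decode_instance)` to the best per-GPU rate. A minimal sketch, mirroring the CLI example above and the selection loop in `simdistserve/simulate_ratio.py` (the `"opt_66b"` model string and the parameter values are illustrative assumptions, not part of this diff):

```python
from simdistserve.benchmarks.parallel_ratio_bisect import simulate_bisect_ratio_search
from simdistserve.constants import ModelTypes

if __name__ == "__main__":  # guard needed: the search spawns worker processes
    result = simulate_bisect_ratio_search(
        prefill_tp=8, prefill_pp=1, decode_tp=8, decode_pp=1,
        model_type=ModelTypes.model_str_to_object("opt_66b"),  # assumed model string
        backend="distserve",
        attainment=(200, 100, 90, 90),  # (prefill_target, decode_target, prefill %, decode %)
        max_prefill_instance=8, max_decode_instance=8,
        max_per_gpu_rate=5,
        kv_cache_mem_per_gpu=64, kv_transfer_bw=600,
        esp=0.25, N=300,
    )
    # result: {(prefill_instance, decode_instance): best_per_gpu_rate}
    (m, n), rate = max(result.items(), key=lambda kv: kv[1])
    print(f"Best config: prefill_instance={m}, decode_instance={n}, per_gpu_rate={rate}")
```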
diff --git a/simdistserve/base/request.py b/simdistserve/base/request.py
index efdf010..d72c22c 100644
--- a/simdistserve/base/request.py
+++ b/simdistserve/base/request.py
@@ -7,6 +7,8 @@
 E_DO_PREFILL = "do_prefill"
 E_WAIT_DECODE = "wait_decode"
 E_DO_DECODE = "do_decode"
+E_WAIT_KVCACHE_MIGRATION = "wait_kvcache_migration"
+E_DO_KVCACHE_MIGRATION = "do_kvcache_migration"
 E_FINISH_PREFILL = "finish_prefill"
 E_FINISH_DECODE = "finish_decode"
 E_EXIT_SYSTEM = "exit_system"
@@ -61,10 +63,27 @@ def __init__(
         # set this value if a request belongs to a particular chunk
         # The last worker in the pipeline unset this value at a chunk's end.
         self.chunk_id = None
+        # set to `True` once the request has finished prefill and its kv-cache exists.
+        self.kvcache_generated = False
+        self.prefill_is_finished = False
+        # workers that ran prefill/decode for this request (one per PP stage).
+        self.prefill_workers = []
+        self.decode_workers = []
+        # workers that have already sent/received the kv-cache.
+        self.migrated_prefill_workers = []
+        self.migrated_decode_workers = []
+        # event that fires when kv-cache migration is finished
+        self.migrate_event = None
+        # migration time for one prefill worker
+        self.migrate_time = 0

     @property
     def current_context_len(self):
         return self.prefill_lens + max(0, self.counter)
+
+    @property
+    def kvcache_migrate_is_done(self):
+        return len(self.migrated_prefill_workers) == len(self.prefill_workers)

     def _log_event(self, event, wid=-1):
         if not self.env:
@@ -88,6 +107,12 @@ def wait_decode(self, wid=None):

     def do_decode(self, wid=None):
         self._log_event(E_DO_DECODE, wid=wid)
+
+    def wait_kvcache_migration(self, wid=None):
+        self._log_event(E_WAIT_KVCACHE_MIGRATION, wid=wid)
+
+    def do_kvcache_migration(self, wid=None):
+        self._log_event(E_DO_KVCACHE_MIGRATION, wid=wid)

     def _reset_chunked_prefill_metadata(self):
         """Reset the metadata of chunked prefill."""
@@ -111,6 +136,7 @@ def finish_prefill(self, is_finished_one_round=False, wid=None, next_wid=None):
             # Reset counter to 0
             # TODO: Should we do self.counter += 1?
             self.counter = 0
+            self.prefill_is_finished = True
             # Hack to ensure "wait_decode" appears at least once.
             self.wait_decode(wid=next_wid)
         if not self.should_finish():
diff --git a/simdistserve/base/worker.py b/simdistserve/base/worker.py
index 72b7f38..7371892 100644
--- a/simdistserve/base/worker.py
+++ b/simdistserve/base/worker.py
@@ -8,6 +8,7 @@
 from uuid import UUID

 from simdistserve.estimators.time_estimator import get_prefill_time, get_decode_time
+from simdistserve.utils import cal_kvcache_slots, cal_kvcache_token_size

 if TYPE_CHECKING:
     from simdistserve.base.scheduler import Scheduler
@@ -26,6 +27,8 @@ class WorkerConfig(TypedDict):
     decode_max_tokens: int  # Max tokens in a iteration forward (default = 10**7)
     enable_chunked_prefill: Optional[bool]  # Enable memory pressure simulation (default = False)
     engine_type: Literal["distserve", "vllm"]  # Engine type for prefill/decode time calculation (default = "distserve")
+    kv_cache_mem_per_gpu: int  # KV cache memory per GPU in GB (default = 54)
+    kv_transfer_bw: int  # KV transfer bandwidth in Gbps (default = 80)

     # TODO: Deprecated
     TP: Optional[int]  # Tensor parallelism (default = 1)
@@ -46,9 +49,12 @@ def __init__(
         TP: int = 1,
         TP_Prefill: int = None,
         TP_Decode: int = None,
+        PP: int = 1,
         enable_chunked_prefill=False,
         prefill_max_tokens=10 ** 7,
         decode_max_tokens=10 ** 7,
+        kv_cache_mem_per_gpu=54,
+        kv_transfer_bw=80,
         decode_back_pressure: float = 0.9,
         engine_type: Literal["distserve", "vllm"] = "distserve",
     ):
@@ -62,6 +68,7 @@ def __init__(
         # TODO: (Deprecate) TP should be deprecate in favor of TP_prefill and TP_decode.
         self.TP = TP
+        self.PP = PP
         self.TP_Prefill = TP_Prefill
         self.TP_Decode = TP_Decode
         if (self.TP_Prefill is None) and (self.TP_Decode is None):
@@ -94,10 +101,20 @@ def __init__(

         self.prefill_queue: 'deque[Request]' = deque()
         self.decode_queue: 'deque[Request]' = deque()
+        # Queue of requests whose kv-cache must be transferred to other workers. Every request
+        # is pushed to this queue after prefill and waits for a decode worker to receive it.
+        self.migrate_queue: 'deque[Request]' = deque()
         self._prefill_ips: int = 0  # Elements in progress for prefill
         self._decode_ips: int = 0  # Elements in progress for decode
         self._wakeup_event = env.event()
         self.log: 'list[tuple[float, str, int, int, int, list[int], list[int]]]' = []
+
+        kvcache_slot_num_per_gpu = cal_kvcache_slots(model_type, kv_cache_mem_per_gpu)
+
+        self.free_mem_slots_num = kvcache_slot_num_per_gpu * max(self.TP_Prefill, self.TP_Decode)
+        self.max_mem_slot_num = self.free_mem_slots_num
+        self.mem_slot_lower_bound = 0.1 * self.free_mem_slots_num  # 10% of free slots
+        self.per_token_kvcache_transfertime = cal_kvcache_token_size(model_type) / (kv_transfer_bw * 1024)  # in ms (TODO: check the unit)

         # Simulate scheduler delay in terms of number of decode rounds.
         self._prefill_sched_delay: int = 0
@@ -108,10 +125,6 @@ def __init__(
     def is_first_in_pipeline(self):
         return self.pipe_rank == 0

-    @property
-    def has_back_pressure(self) -> bool:
-        threshold = int(self.decode_max_batch_size * self.decode_back_pressure)
-        return sum(r.current_context_len for r in self.decode_queue) > threshold

     def __repr__(self):
         return f"Worker {self.wid}"
@@ -127,20 +140,49 @@ def _log_event(self, event, num_tokens: int = 0, prefill_bs=0, decode_bs=0,
             # print(item)
         return

-    def run(self):
+    def run_compute(self):  # compute loop
         while True:
             if not (self.prefill_queue or self.decode_queue):
                 yield self._wakeup_event

-            if self.prefill_queue and not self.has_back_pressure:
-                yield from self.do_prefill()
+            if self.prefill_queue:
+                if self.mem_is_enough():
+                    yield from self.do_prefill()
+                else:
+                    yield self.env.timeout(0.1)  # avoid deadlock when memory is low
             else:
                 yield from self.do_decode()
+            self._log_event("compute_wait")
-            self._log_event("wait")
-        pass

+    def run_migrate(self):  # migrate loop
+        while True:
+            if not self.migrate_queue:
+                yield self._wakeup_event
+
+            migrate_queue_len = len(self.migrate_queue)
+            migrate_time = 0
+            for i in range(migrate_queue_len):
+                request = self.migrate_queue.popleft()
+                if request.kvcache_migrate_is_done:
+                    # all prefill workers have already sent this kv-cache; this should never happen
+                    raise RuntimeError("kv-cache was already migrated by all prefill workers")
+                migrate_time += request.migrate_time
+                assert self.wid not in request.migrated_prefill_workers
+                request.migrated_prefill_workers.append(self.wid)
+                if request.kvcache_migrate_is_done:  # all prefill workers have sent the kv-cache
+                    request.migrate_event.succeed()
+                    for worker in request.prefill_workers:  # free the kv-cache
+                        yield worker.env.process(worker.prefill_free_kvcache([request, ]))
+
+            yield self.env.timeout(migrate_time)
+            self._log_event("migrate_wait")
+
+    def run(self):
+        compute_task = self.env.process(self.run_compute())
+        migrate_task = self.env.process(self.run_migrate())
+        yield self.env.all_of([compute_task, migrate_task])

-        pass
     def add_ray_overhead(self, sum_of_tokens) -> int:
         base_overhead = 2
@@ -182,21 +224,65 @@ def forward_decode(self, items: Union['Request', Iterable['Request']], to_schedu
         return

     def _enter_decodes(self, remaining_tok_in_batch: int) -> 'List[Request]':
-        # decode_max_tokens
-
-        # Acceptable decode requests is capped by the remaining allowed tokens in this batch.
-        # TODO: Hack: Must revert this to use the max token given
-        # watermark = 0.9
-        # decode_max_tokens = self.decode_max_tokens * watermark
-        decode_max_tokens = 50000
         _decode_len = min(remaining_tok_in_batch, len(self.decode_queue))
-        decode_reqs = []
-        for i in range(_decode_len):
-            req = self.decode_queue[0]
-            if (req.current_context_len + 1) > decode_max_tokens:
-                break
-            decode_max_tokens -= (req.current_context_len + 1)
-            decode_reqs.append(self.decode_queue.popleft())
+        decode_reqs: 'List[Request]' = []
+
+        # If memory is not enough, only schedule requests that have been decoded here before,
+        # because their kv-cache is already allocated; new requests' kv-cache would cost too much free memory.
+        decode_all_kinds_requests = self.mem_is_enough()
+
+        # A request is given up if
+        #   1. it needs kv-cache migration but free memory is below mem_slot_lower_bound, or
+        #   2. the available memory is not enough.
+        requests_give_up = deque()
+
+        if self.mem_is_enough() or decode_all_kinds_requests:
+            available_slots = self.free_mem_slots_num // 2  # avoid using all free_mem_slots_num in one batch
+        else:
+            available_slots = self.free_mem_slots_num // 128  # batch size control
+
+        for _ in range(_decode_len):  # choose batch
+            left_req = self.decode_queue[0]
+            self.decode_queue.popleft()
+            if self.wid not in left_req.migrated_decode_workers:  # request needs to migrate its kv-cache
+                kv_cache_size = left_req.current_context_len / self.PP
+                if not decode_all_kinds_requests:
+                    requests_give_up.append(left_req)  # needs kv-cache migration; give up
+                    continue
+                elif available_slots > kv_cache_size:  # memory is enough, migrate kv-cache
+                    available_slots -= kv_cache_size
+                    decode_reqs.append(left_req)
+                else:  # memory is not enough, give up the request
+                    requests_give_up.append(left_req)
+            else:  # requests whose kv-cache is already on this worker
+                if available_slots < 1:
+                    requests_give_up.append(left_req)
+                    continue
+                decode_reqs.append(left_req)
+                available_slots -= 1
+
+        # put the given-up requests back at the front of the queue
+        while len(requests_give_up) > 0:
+            self.decode_queue.appendleft(requests_give_up.pop())
+        assert len(self.decode_queue) + len(decode_reqs) == _decode_len
+
+        # kv-cache transfer
+        migrate_requests = []
+
+        for r in list(decode_reqs):
+            if self.wid not in r.migrated_decode_workers:  # need to migrate kv-cache
+                if self.wid not in [w.wid for w in r.prefill_workers]:  # if the kv-cache is not already on this worker, migrate it
+                    # r.migrate_time is the time each prefill worker needs to migrate its part of the kv-cache
+                    r.migrate_time = self.per_token_kvcache_transfertime * r.current_context_len / len(r.prefill_workers)
+                    migrate_requests.append(r)
+                    for worker in r.prefill_workers:  # make sure all prefill workers are woken up to migrate the kv-cache
+                        worker.wakeup()
+                    yield self.env.process(self.migrate_alloc_kvcache([r, ]))
+
+        if decode_reqs:  # allocate kv-cache for the requests
+            self.decode_alloc_kvcache(decode_reqs)
+
         for r in decode_reqs:
             r.do_decode(wid=self.wid)
         return decode_reqs
@@ -204,6 +290,12 @@ def _enter_prefill(self) -> 'List[Request]':
         result: 'List[Request]' = []

+        # check whether free_mem_slots_num has hit the lower bound
+        if not self.mem_is_enough():
+            return result
+
+        available_slots = self.free_mem_slots_num
+
         # Limit the maximum prefill requests to handle.
         max_request_size = min(self.prefill_max_batch_size, len(self.prefill_queue))
@@ -216,7 +308,10 @@ def _enter_prefill(self) -> 'List[Request]':
                 candidate: 'Request' = self.prefill_queue[0]
                 if candidate.chunk_id != chunk_id:
                     break
+                if available_slots < candidate.current_prefill_lens:  # TODO: not sure if this is correct
+                    break
                 result.append(self.prefill_queue.popleft())
+                available_slots -= candidate.current_prefill_lens
                 pass
             else:
@@ -248,6 +343,9 @@ def _enter_prefill(self) -> 'List[Request]':
                         break
                     pass
+                if available_slots < sched_size:
+                    break
+                available_slots -= sched_size
                 # Candidate is picked. Now fill in the chunked-prefill information.
candidate.current_prefill_lens = sched_size candidate.remain_prefill_lens -= sched_size @@ -259,37 +357,68 @@ def _enter_prefill(self) -> 'List[Request]': pass for i in result: i.do_prefill(wid=self.wid) + + if result: + self.prefill_alloc_kvcache(result) + return result def _exit_prefill(self, prefill_items: List['Request']): + # if a request finished prefill, it should be migrated to other workers + requests_need_migrate = [] + forward_decode_requests = [] + for item in prefill_items: next_wid = self.next_worker.wid if self.next_worker else None item.finish_prefill(is_finished_one_round=self.is_last_in_pipeline, wid=self.wid, next_wid=next_wid) + # add the worker self to the prefill_workers list + item.prefill_workers.append(self) + item.prefill_workers = list(set(item.prefill_workers)) + if not self.is_last_in_pipeline or (item.remain_prefill_lens > 0): # Finish one chunk of prefill. Now forward to the next worker # (or head of worker) to do the rest of the parts. self.forward_prefill(item) continue - + if item.prefill_is_finished: # need to migrate kv-cache + requests_need_migrate.append(item) # Arrive at worker who is at the last of pipeline. if item.should_finish(): # ... just a sanity check to avoid any infinite loop. continue + forward_decode_requests.append(item) + + for r in requests_need_migrate: # prefill is finished, migrate the kv-cache + r.migrate_event = self.env.event() + for worker in r.prefill_workers: + worker.migrate_kvcache([r,]) + for item in forward_decode_requests: self.forward_decode(item, to_scheduler=(not self.should_request_stay)) return - def _exit_decode(self, decode_reqs): + def _exit_decode(self, decode_reqs: 'List[Request]'): + finished_requests = [] # if the request is finished, its kv-cache should be freed + if not decode_reqs: return next_wid = self.next_worker.wid if self.next_worker else None for r in decode_reqs: r.finish_decode(is_finished_one_round=self.is_last_in_pipeline, next_wid=next_wid) + r.decode_workers.append(self) # add the worker self to the decode_workers list + r.decode_workers = list(set(r.decode_workers)) + if r._terminated: + finished_requests.append(r) next_decode_batch = tuple(r for r in decode_reqs if not r.should_finish()) + for r in finished_requests: # free all the kv-cache: prefilled and decoded + for worker in r.decode_workers: + worker.decode_free_kvcache([r,]) self.forward_decode(next_decode_batch) return def do_prefill(self): prefill_items: 'List[Request]' = self._enter_prefill() + if not prefill_items: + return if self.enable_chunked_prefill: remaining_tok_in_batch = self.prefill_max_tokens - sum(x.current_prefill_lens for x in prefill_items) decode_reqs = self._enter_decodes(remaining_tok_in_batch) @@ -332,8 +461,11 @@ def do_prefill(self): return def do_decode(self): - decode_reqs = self._enter_decodes(self.decode_max_tokens) - batch_size = len(decode_reqs) + decode_reqs = yield self.env.process(self._enter_decodes(self.decode_max_tokens)) + batch_size = len(list(decode_reqs)) + if batch_size == 0: + return + self._log_event( "do_decode", num_tokens=batch_size, decode_bs=batch_size, decode_len_list=[x.current_context_len for x in decode_reqs], @@ -350,4 +482,69 @@ def do_decode(self): self._exit_decode(decode_reqs) return - pass + + + def mem_is_enough(self) -> bool: + assert self.free_mem_slots_num <= self.max_mem_slot_num, f"Worker {self.wid} free_mem_slots_num: {self.free_mem_slots_num}, max_mem_slot_num: {self.max_mem_slot_num}" + assert self.free_mem_slots_num >= 0 + + return self.free_mem_slots_num >= 
self.mem_slot_lower_bound + + + def migrate_kvcache(self, requests: 'list[Request]') -> None: + # called by prefill worker, push the requests to the migrate queue, + # and waiting for the decode worker to receive it + # TODO: if the request's output_len == 1, decode won't happen, the kv-cache should be freed + for i in requests: + self.migrate_queue.append(i) + i.wait_kvcache_migration(wid=self.wid) + + def prefill_alloc_kvcache(self, requests: 'list[Request]') -> bool: + # allocate slots for kv-cache + if not self.mem_is_enough(): + return False + for i in requests: + i.kvcache_generated = True + kvcache_size = sum([request.current_prefill_lens for request in requests]) / self.PP + self.free_mem_slots_num -= kvcache_size + self._log_event('prefill_alloc_kvcache') + return True + + + def prefill_free_kvcache(self, requests: 'list[Request]') -> None: + # called by prefill worker, free the slots immediately after kv-cache migration finished + for r in requests: # waiting for all prefill workers to migrate the kv-cache + yield r.migrate_event + kvcache_size = 0 + for r in requests: + assert self.wid in [w.wid for w in r.prefill_workers] + if self.wid in [w.wid for w in r.prefill_workers] : + kvcache_size += r.prefill_lens / self.PP + self.free_mem_slots_num += kvcache_size + self._log_event('prefill_free_kvcache') + return + + def migrate_alloc_kvcache(self, requests: 'list[Request]') -> bool: + # called by the decode worker, prepare for the kv-cache migration + kvcache_size = sum([request.current_context_len for request in requests]) / self.PP + self.free_mem_slots_num -= kvcache_size + for r in requests: + yield r.migrate_event + r.migrated_decode_workers.append(self.wid) # mark the decode worker has received the kv-cache + + def decode_alloc_kvcache(self, requests: 'list[Request]') -> bool: + # for each request, decode once cost one slot + kvcache_size = len(requests) / self.PP + self.free_mem_slots_num -= kvcache_size + self._log_event('decode_alloc_kvcache') + return True + + def decode_free_kvcache(self, requests: 'list[Request]') -> None: + # free the slots immediately after decoding finished + kvcache_size = sum([(request.current_context_len) for request in requests]) / self.PP + self.free_mem_slots_num += kvcache_size + self._log_event('decode_free_kvcache') + return + + def __del__(self): + assert self.free_mem_slots_num == self.max_mem_slot_num, f"worker:{self.wid} free_mem_slots_num: {self.free_mem_slots_num}, max_mem_slot_num: {self.max_mem_slot_num}" diff --git a/simdistserve/benchmarks/parallel_bisect.py b/simdistserve/benchmarks/parallel_bisect.py index 67f3f72..700f92f 100644 --- a/simdistserve/benchmarks/parallel_bisect.py +++ b/simdistserve/benchmarks/parallel_bisect.py @@ -20,7 +20,9 @@ def main( num_node: int, num_gpu_per_node: int, model_type: ModelTypes, is_dist_high: bool = True, backend: str = "distserve", attainment=(200, 100, 90, 90), - max_per_gpu_rate=5, esp=0.25, N=1000, + max_per_gpu_rate=5, + kv_cache_mem_per_gpu=54, kv_transfer_bw=80, + esp=0.25, N=1000, max_cpu_count=MAX_CPU_COUNT, ): """ @@ -48,6 +50,8 @@ def main( backend, attainment, ), kwargs=dict( + kv_cache_mem_per_gpu=kv_cache_mem_per_gpu, + kv_transfer_bw=kv_transfer_bw, max_per_gpu_rate=max_per_gpu_rate, pid=pid, esp=esp, N=N, result=result, diff --git a/simdistserve/benchmarks/parallel_ratio_bisect.py b/simdistserve/benchmarks/parallel_ratio_bisect.py new file mode 100644 index 0000000..fb2c67a --- /dev/null +++ b/simdistserve/benchmarks/parallel_ratio_bisect.py @@ -0,0 +1,126 @@ +import os +import 
time +from multiprocessing import Process, Manager +from time import sleep + +import pandas as pd +from tqdm import tqdm +from fractions import Fraction + +from simdistserve.benchmarks.search_binary import run_binary_search +from simdistserve.benchmarks.search_configs import get_distserve_configs, get_vllm_config +from simdistserve.constants import ModelTypes + +# Restrict runtime to <= 32 CPU core. +# RunPod encounters problem when using `os.cpu_count()` +# to query the number of CPUs +MAX_CPU_COUNT = min(os.cpu_count() - 2, 32) + + +def main( + prefill_pp: int, prefill_tp: int, + decode_pp: int, decode_tp: int, + model_type: ModelTypes, + is_dist_high: bool = True, + backend: str = "distserve", attainment=(200, 100, 90, 90), + max_prefill_instance=8, max_decode_instance=8, + max_per_gpu_rate=5, + kv_cache_mem_per_gpu=54, kv_transfer_bw=80, + esp=0.25, N=1000, + max_cpu_count=MAX_CPU_COUNT, +): + """ + :return result: dict that maps config to the best_per_gpu_rate (int) + """ + configs = [(1, prefill_tp, prefill_pp, decode_tp, decode_pp)] + + if backend == "distserve": + ratios = [] + for prefill_instance in range(1, max_prefill_instance + 1): + for decode_instance in range(1, max_decode_instance + 1): + frac = Fraction(prefill_instance, decode_instance) + ratios.append((frac.numerator, frac.denominator)) + ratios = list(set(ratios)) + else: + raise ValueError(f"Unsupported backend for ratio search: {backend}") + + processes = [] + # Add a multiproc shared dict + with Manager() as manager: + result = manager.dict() + pbar = tqdm(enumerate(ratios), total=len(ratios)) + for pid, ratio in pbar: + proc = Process( + target=run_binary_search, + args=( + model_type, configs[0], + backend, attainment, + ), + kwargs=dict( + kv_cache_mem_per_gpu=kv_cache_mem_per_gpu, + kv_transfer_bw=kv_transfer_bw, + max_per_gpu_rate=max_per_gpu_rate, + prefill_instance=ratio[0], + decode_instance=ratio[1], + pid=pid, esp=esp, + N=N, result=result, + ratio_search=True, + debug=True, + ) + ) + if len(processes) >= max_cpu_count: + # Pop a process that has finished running + found = False + while not found: + for i in range(len(processes)): + if not processes[i].is_alive(): + processes[i].join() + processes.pop(i) + found = True + pbar.update(1) + break + sleep(0.2) + + proc.start() + processes.append(proc) + pass + for proc in processes: + pbar.update(1) + proc.join() + result = dict(result) + return result + + +simulate_bisect_ratio_search = main + +if __name__ == '__main__': + data = [] + for ngpu in [2, 4, 8, 16, 32]: + start = time.perf_counter() + main(ngpu, 1, is_dist_high=True) + end = time.perf_counter() + duration = end - start + data.append({ + "name": "DistHigh", + "ngpu": ngpu, + "duration": duration + }) + print(f"DistHigh({ngpu=}):{duration}s") + + for ngpu_per_node, num_node in [(2, 1), (4, 1), (8, 1), (8, 2), (8, 4)]: + ngpu = ngpu_per_node * num_node + start = time.perf_counter() + main(num_node, ngpu_per_node, is_dist_high=False) + end = time.perf_counter() + duration = end - start + data.append({ + "name": "DistLow", + "ngpu": ngpu, + "duration": duration + }) + print(f"DistLow({ngpu_per_node=},{num_node=}):{duration}s") + + df = pd.DataFrame(data) + df.to_csv("parallel_bisect.csv", index=False) + + pass diff --git a/simdistserve/benchmarks/search_binary.py b/simdistserve/benchmarks/search_binary.py index 49e43e5..9c8cbb0 100644 --- a/simdistserve/benchmarks/search_binary.py +++ b/simdistserve/benchmarks/search_binary.py @@ -13,6 +13,11 @@ def run_binary_search( backend: str, containment_targets: 
        '(prefill_target, decode_target, prefill_containment, decode_containment)',
     max_per_gpu_rate: int = 16,
+    kv_cache_mem_per_gpu=54,
+    kv_transfer_bw=80,
+    prefill_instance=1,
+    decode_instance=1,
+    ratio_search=False,
     pid=0,
     esp=0.5,
     N=1000,
@@ -32,6 +37,8 @@ run_binary_search(
             '--pp-prefill', f'{pp_cross * pp_prefill}',
             '--tp-decode', f'{tp_decode}',
             '--pp-decode', f'{pp_cross * pp_decode}',
+            '--prefill-instance', prefill_instance,
+            '--decode-instance', decode_instance,
         ]
     else:
         (tp, pp) = config
@@ -61,6 +68,8 @@ run_binary_search(
         '--decode-containment', decode_containment,  # P90
         '--decode-target', decode_target,  # ms
         '--model', ModelTypes.formalize_model_name(model_type),
+        '--kv-cache-mem-per-gpu', kv_cache_mem_per_gpu,
+        '--kv-transfer-bw', kv_transfer_bw,
         '--workload', 'sharegpt',
         '--slas', '[]',
         '--slo-scales', '[1]',
@@ -102,7 +111,10 @@ run_binary_search(
         pass
     # print(best_per_gpu_rate)
     if result is not None:
-        result[config] = best_per_gpu_rate
+        if ratio_search:
+            result[(prefill_instance, decode_instance)] = best_per_gpu_rate
+        else:
+            result[config] = best_per_gpu_rate
     return best_per_gpu_rate
diff --git a/simdistserve/benchmarks/simulate_dist.py b/simdistserve/benchmarks/simulate_dist.py
index efca6e9..3a73b57 100644
--- a/simdistserve/benchmarks/simulate_dist.py
+++ b/simdistserve/benchmarks/simulate_dist.py
@@ -65,6 +65,12 @@ def parse_args(args_=None):
     parser.add_argument('--output-request-latency', type=str, default=None, help='Output per-request latency (csv)')
     parser.add_argument('--output-worker', type=str, default=None,
                         help='Output per-worker per-iteration time (csv)')
+    parser.add_argument('--kv-cache-mem-per-gpu', type=int, default=54,
+                        help='KV cache memory per GPU in GB (default 54GB)')
+    parser.add_argument('--kv-transfer-bw', type=int, default=80,
+                        help='KV transfer bandwidth in Gbps (default 80Gbps)')
+    parser.add_argument('--prefill-instance', type=int, default=1, help='Number of prefill instances (used in DistServe)')
+    parser.add_argument('--decode-instance', type=int, default=1, help='Number of decode instances (used in DistServe)')
     parser.add_argument('--prefill-containment', type=int, default=None,
                         help='Containment target for prefill')
     parser.add_argument('--prefill-target', type=int, default=200,
@@ -145,7 +151,13 @@ def main(args, outputs=None):
     PP_prefill = args.pp_prefill
     TP_Decode = args.tp_decode
     PP_decode = args.pp_decode
-
+
+    prefill_instance = args.prefill_instance
+    decode_instance = args.decode_instance
+
+    kv_cache_mem_per_gpu = args.kv_cache_mem_per_gpu
+    kv_transfer_bw = args.kv_transfer_bw
+
     #
     # Handle vllm in data processing
     #
@@ -176,6 +188,8 @@ def main(args, outputs=None):
             decode_max_batch_size=10 ** 7,  # inf
             prefill_max_tokens=prefill_max_tokens,
             decode_max_tokens=decode_max_tokens,
+            kv_cache_mem_per_gpu=kv_cache_mem_per_gpu,
+            kv_transfer_bw=kv_transfer_bw,
             enable_chunked_prefill=False,
             engine_type=args.backend,
         )
@@ -191,12 +205,15 @@ def main(args, outputs=None):
             decode_max_batch_size=10 ** 7,  # inf
             prefill_max_tokens=prefill_max_tokens,
             decode_max_tokens=decode_max_tokens,
+            kv_cache_mem_per_gpu=kv_cache_mem_per_gpu,
+            kv_transfer_bw=kv_transfer_bw,
             enable_chunked_prefill=False,
             engine_type=args.backend,
         )
         cluster = DisaggCluster(
-            env=env, PP_prefill=PP_prefill, PP_decode=PP_decode,
+            env=env, N_prefill_instance=prefill_instance, N_decode_instance=decode_instance,
+            PP_prefill=PP_prefill, PP_decode=PP_decode,
             worker_configs=worker_config,
         )
     else:
diff --git a/simdistserve/clusters/disagg.py b/simdistserve/clusters/disagg.py
index 1d0e1c0..4a04e63 100644
--- a/simdistserve/clusters/disagg.py
+++ b/simdistserve/clusters/disagg.py
@@ -33,7 +33,7 @@ def __init__(
         for inst_id in range(N_prefill_instance):
             instance = []
             for i, p in enumerate(range(PP_prefill)):
-                worker = Worker(env, worker_id, cluster=self, pipe_rank=i, **worker_kwargs)
+                worker = Worker(env, worker_id, cluster=self, PP=PP_prefill, pipe_rank=i, **worker_kwargs)
                 instance.append(worker)
                 worker_id += 1
@@ -47,7 +47,7 @@ def __init__(
         for inst_id in range(N_decode_instance):
             instance = []
             for i, p in enumerate(range(PP_decode)):
-                worker = Worker(env, worker_id, cluster=self, pipe_rank=i, **worker_kwargs)
+                worker = Worker(env, worker_id, cluster=self, PP=PP_decode, pipe_rank=i, **worker_kwargs)
                 instance.append(worker)
                 worker_id += 1
diff --git a/simdistserve/clusters/vllm.py b/simdistserve/clusters/vllm.py
index ec32fad..af6c56c 100644
--- a/simdistserve/clusters/vllm.py
+++ b/simdistserve/clusters/vllm.py
@@ -29,7 +29,7 @@ def __init__(
         for inst_id in range(N_instance):
             instance = []
             for i, p in enumerate(range(PP)):
-                worker = Worker(env, worker_id, cluster=self, pipe_rank=i, **worker_kwargs)
+                worker = Worker(env, worker_id, cluster=self, PP=PP, pipe_rank=i, **worker_kwargs)
                 instance.append(worker)
                 worker_id += 1
diff --git a/simdistserve/simulate_ratio.py b/simdistserve/simulate_ratio.py
new file mode 100644
index 0000000..8c72410
--- /dev/null
+++ b/simdistserve/simulate_ratio.py
@@ -0,0 +1,117 @@
+import argparse
+
+from simdistserve.benchmarks.parallel_ratio_bisect import simulate_bisect_ratio_search
+from simdistserve.constants import ModelTypes
+
+
+def parse_args():
+    parser = argparse.ArgumentParser("Simulate DistServe to find the optimal prefill:decode instance ratio.")
+    parser.add_argument("--prefill-tp", type=int, default=8,
+                        help="Prefill TP num (default 8)")
+    parser.add_argument("--prefill-pp", type=int, default=1,
+                        help="Prefill PP num (default 1)")
+    parser.add_argument("--decode-tp", type=int, default=8,
+                        help="Decode TP num (default 8)")
+    parser.add_argument("--decode-pp", type=int, default=1,
+                        help="Decode PP num (default 1)")
+    parser.add_argument("--is-high-affinity", action="store_true")
+    parser.add_argument("--backend", type=str, default="distserve",
+                        help="Choose from: distserve, vllm")
+    parser.add_argument("--workload", type=str, default="sharegpt",
+                        help="Choose from: sharegpt, humaneval, longbench")
+    parser.add_argument("--prefill-target", type=int, default=200,
+                        help="Prefill TTFT attainment target in ms (default 200ms)")
+    parser.add_argument("--decode-target", type=int, default=100,
+                        help="Decode TPOT attainment target in ms (default 100ms)")
+    parser.add_argument("--max-prefill-instances", type=int, default=8,
+                        help="Max prefill instances to search (default 8)")
+    parser.add_argument("--max-decode-instances", type=int, default=8,
+                        help="Max decode instances to search (default 8)")
+    parser.add_argument("--prefill-percentage", type=int, default=90,
+                        help="Percentage of prefill target (default P90)")
+    parser.add_argument("--decode-percentage", type=int, default=90,
+                        help="Percentage of decode target (default P90)")
+    parser.add_argument("--max-per-gpu-rate", type=int, default=5,
+                        help="Max per GPU rate to search (default 5)")
+    parser.add_argument("--kv-cache-mem-per-gpu", type=int, default=10,
+                        help="KV cache memory per GPU in GB (default 10GB)")
+    parser.add_argument("--kv-transfer-bw", type=int, default=80,
+                        help="KV transfer bandwidth in Gbps (default 80Gbps)")
parser.add_argument("--esp", type=float, default=0.25, + help="Stopping criteria: `high - low < esp` (default esp = 0.25)") + parser.add_argument("--N", type=int, default=300, + help="Number of samples to simulate (default 1000)") + parser.add_argument("--model-type", type=str, default="opt_13b", + help="Model type to simulate (opt_13b, opt_66b, opt_175b)") + + args = parser.parse_args() + args.model_type = ModelTypes.model_str_to_object(args.model_type) + return args + + +def find_best_config(config_to_best_per_gpu_rate, backend): + best_config = None + best_ngpu = float("inf") + best_per_gpu_rate = 0 + num_gpu = 0 + for config, per_gpu_rate in config_to_best_per_gpu_rate.items(): + if backend == 'distserve': + pp_cross, tp_prefill, pp_prefill, tp_decode, pp_decode = config + num_gpu = pp_cross * (tp_prefill * pp_prefill + tp_decode * pp_decode) + elif backend == 'vllm': + tp, pp = config + num_gpu = tp * pp + + if per_gpu_rate > best_per_gpu_rate or (per_gpu_rate == best_per_gpu_rate and num_gpu < best_ngpu): + best_config = config + best_per_gpu_rate = per_gpu_rate + best_ngpu = num_gpu + + return best_config, best_per_gpu_rate + + +def check_dataset_env_var(): + import os + if "DATASET" in os.environ: + return + raise KeyError( + "Please set the environment variable `DATASET` to the path of the workload datasets. " + "For user who started the environment with `DistServe-AE-GPU` docker image, " + "simply do:\nexport DATASET=`/app/dataset`\n" + "See the `repro-dataset.md` to prepare for workload dataset if you are using your custom environment." + ) + + +if __name__ == '__main__': + # def main(num_node, num_gpu_per_node, is_dist_high: bool = True): + args = parse_args() + print(args) + + if args.backend != "distserve": + raise ValueError(f"Unsupported backend: {args.backend }") + + + result = simulate_bisect_ratio_search( + prefill_tp=args.prefill_tp, + prefill_pp=args.prefill_pp, + decode_tp=args.decode_tp, + decode_pp=args.decode_pp, + model_type=args.model_type, + is_dist_high=args.is_high_affinity, + backend=args.backend, + attainment=(args.prefill_target, args.decode_target, args.prefill_percentage, args.decode_percentage), + max_prefill_instance=args.max_prefill_instances, + max_decode_instance=args.max_decode_instances, + max_per_gpu_rate=args.max_per_gpu_rate, + kv_cache_mem_per_gpu=args.kv_cache_mem_per_gpu, + kv_transfer_bw=args.kv_transfer_bw, + esp=args.esp, + N=args.N, + ) + best_config = None + max_per_gpu_rate = 0 + for config, per_gpu_rate in result.items(): + if per_gpu_rate > max_per_gpu_rate: + best_config = config + max_per_gpu_rate = per_gpu_rate + print(f"Best config: prefill_instance={best_config[0]}, decode_instance={best_config[1]}, per_gpu_rate={max_per_gpu_rate}") diff --git a/simdistserve/utils.py b/simdistserve/utils.py index 5b23a40..ae417a3 100644 --- a/simdistserve/utils.py +++ b/simdistserve/utils.py @@ -4,6 +4,8 @@ from functools import reduce from itertools import chain from typing import List +from transformers import PretrainedConfig +from simdistserve.constants import ModelTypes _verbose = True @@ -65,3 +67,21 @@ def irange(*args): x, y, z = args return range(x, y + 1, z) raise ValueError(f"args={args}") + +def cal_kvcache_token_size(model_name: str): + # in KB + model_name = ModelTypes.formalize_model_name(model_name) + model_config = PretrainedConfig.from_pretrained(model_name) + + hidden_size = model_config.hidden_size + layer_num = model_config.num_hidden_layers + return 2 * 2 * hidden_size * layer_num / 1024 + + + +def cal_kvcache_slots( + 
+        model_name: str,
+        memory: int,  # in GB
+):
+    # number of kv-cache slots (tokens) that fit into `memory` GB of KV-cache memory
+    token_size = cal_kvcache_token_size(model_name)  # in KB
+    return 1024 * 1024 * memory // token_size
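For a quick sanity check of the slot math above, here is a standalone back-of-envelope sketch. The 2-byte (fp16) KV entry assumption and the OPT-13B configuration values (`hidden_size=5120`, `num_hidden_layers=40`) are quoted here only for illustration; the actual code reads them from the model config at runtime:

```python
# Back-of-envelope version of cal_kvcache_token_size / cal_kvcache_slots,
# assuming fp16 KV entries (2 bytes) and the facebook/opt-13b config.
hidden_size = 5120
num_hidden_layers = 40

# Per-token KV cache size in KB: 2 (K and V) * 2 bytes * hidden_size * layers / 1024.
token_size_kb = 2 * 2 * hidden_size * num_hidden_layers / 1024  # = 800.0 KB/token

# Slots that fit into the worker default of 54 GB of KV-cache memory per GPU.
kv_cache_mem_per_gpu_gb = 54
slots = 1024 * 1024 * kv_cache_mem_per_gpu_gb // token_size_kb  # ~= 70778 slots

print(f"{token_size_kb} KB/token, {slots:.0f} slots")
```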