diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 4dbf153208f..4b0d5d18b4d 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -3228,14 +3228,16 @@ def update_weights(self, version: str = None, verify_checksum: bool = False): cache_rebuild_cost = 0.0 if release_cache: clear_start = time.perf_counter() - self._clear_cache_for_gdr_weight_update() + self._maybe_clear_memory_before_weight_update(clear_kv_cache=True, clear_cuda_graph=self.use_cudagraph) cache_clear_cost = time.perf_counter() - clear_start result = self.dynamic_weight_manager.update_weights_by_gdr(version, verify_checksum) if release_cache: rebuild_start = time.perf_counter() - self._rebuild_cache_after_gdr_weight_update() + self._maybe_rebuild_memory_after_weight_update( + rebuild_kv_cache=True, rebuild_cuda_graph=self.use_cudagraph + ) cache_rebuild_cost = time.perf_counter() - rebuild_start result["release_cache"] = release_cache @@ -3244,19 +3246,37 @@ def update_weights(self, version: str = None, verify_checksum: bool = False): self.dynamic_weight_manager.finalize_update() return result else: + self._maybe_clear_memory_before_weight_update(clear_kv_cache=False, clear_cuda_graph=self.use_cudagraph) result = self.dynamic_weight_manager.update_weights_by_rdma(version, verify_checksum) + self._maybe_rebuild_memory_after_weight_update( + rebuild_kv_cache=False, rebuild_cuda_graph=self.use_cudagraph + ) self.dynamic_weight_manager.finalize_update() return result - def _clear_cache_for_gdr_weight_update(self): - cache_flag = ( - self.fd_config.cache_config.num_cpu_blocks > 0 - or self.fd_config.cache_config.kvcache_storage_backend is not None - ) - kv_cache_status = self.kv_cache_status if cache_flag else None - if kv_cache_status: - kv_cache_status.value[0] = KVCacheStatus.CLEARING - if self.use_cudagraph: + def _maybe_clear_memory_before_weight_update(self, clear_kv_cache=False, clear_cuda_graph=False): + if clear_kv_cache: + # Clear cache on cache transfer manager + cache_flag = ( + self.fd_config.cache_config.num_cpu_blocks > 0 + or self.fd_config.cache_config.kvcache_storage_backend is not None + ) + kv_cache_status = self.kv_cache_status if cache_flag else None + if kv_cache_status: + kv_cache_status.value[0] = KVCacheStatus.CLEARING + + # Clear cache on model runner + if self.speculative_decoding and self.spec_method == SpecMethod.MTP: + self.proposer.clear_mtp_cache() + self.clear_cache() + + # Wait for cache cleared on both side + if kv_cache_status: + while kv_cache_status.value[0] != KVCacheStatus.CLEARED: + time.sleep(0.01) + paddle.device.cuda.empty_cache() + + if clear_cuda_graph: self.model.clear_graph_opt_backend() if envs.FD_USE_BLOCK_WISE_CUDA_GRAPH: from fastdeploy.model_executor.graph_optimization.cuda_graph_op import ( @@ -3270,40 +3290,43 @@ def _clear_cache_for_gdr_weight_update(self): and self.graph_opt_config.draft_model_use_cudagraph ): self.proposer.model.clear_graph_opt_backend() - if self.speculative_decoding and self.spec_method == SpecMethod.MTP: - self.proposer.clear_mtp_cache() - self.clear_cache() - if kv_cache_status: - while kv_cache_status.value[0] != KVCacheStatus.CLEARED: - time.sleep(0.01) - paddle.device.cuda.empty_cache() + self._cached_model_output_data = None self._cached_sampler_output = None self._cached_post_process_event = None self._cached_launch_token_num = -1 self._cached_real_bsz = -1 - def _rebuild_cache_after_gdr_weight_update(self): - cache_flag = ( - self.fd_config.cache_config.num_cpu_blocks > 0 - or self.fd_config.cache_config.kvcache_storage_backend is not None - ) - kv_cache_status = self.kv_cache_status if cache_flag else None - if kv_cache_status: - kv_cache_status.value[0] = KVCacheStatus.UPDATING + def _maybe_rebuild_memory_after_weight_update(self, rebuild_kv_cache=False, rebuild_cuda_graph=False): + if rebuild_kv_cache: + # Rebuild cache on cache transfer manager + cache_flag = ( + self.fd_config.cache_config.num_cpu_blocks > 0 + or self.fd_config.cache_config.kvcache_storage_backend is not None + ) + kv_cache_status = self.kv_cache_status if cache_flag else None + if kv_cache_status: + kv_cache_status.value[0] = KVCacheStatus.UPDATING + + # Rebuild cache on model runner + if not self.enable_cache_manager_v1: + self.proposer.initialize_kv_cache(main_model_num_blocks=self.num_gpu_blocks) + self.initialize_kv_cache() + + # Wait for cache rebuilt on both side + if kv_cache_status: + while kv_cache_status.value[0] != KVCacheStatus.NORMAL: + time.sleep(0.01) + self.share_inputs.reset_share_inputs() if self.spec_method == SpecMethod.MTP: self.proposer.model_inputs.reset_model_inputs() - if not self.enable_cache_manager_v1: - self.proposer.initialize_kv_cache(main_model_num_blocks=self.num_gpu_blocks) - self.initialize_kv_cache() - if self.use_cudagraph: + + if rebuild_cuda_graph: self.capture_model() + if self.fd_config.routing_replay_config.enable_routing_replay: self.routing_replay_manager.update_suspend_routing_replay() - if kv_cache_status: - while kv_cache_status.value[0] != KVCacheStatus.NORMAL: - time.sleep(0.01) def sleep(self, tags):