From ac83cf1609ae07623cd222d6760078b4a512f4c4 Mon Sep 17 00:00:00 2001
From: jiangshen
Date: Wed, 1 Jan 2025 16:05:07 +0800
Subject: [PATCH] Fix a tiny bug in pipeline parallelism

---
 src/csrc/model/gpt/gpt.cc   | 1 +
 src/csrc/util/nccl_utils.cc | 2 ++
 2 files changed, 3 insertions(+)

diff --git a/src/csrc/model/gpt/gpt.cc b/src/csrc/model/gpt/gpt.cc
index 3a8edf2..b28983c 100644
--- a/src/csrc/model/gpt/gpt.cc
+++ b/src/csrc/model/gpt/gpt.cc
@@ -560,6 +560,7 @@ std::vector Gpt::forward(
         sync_check_cuda_error();
     }
     else {
+        d_position_ids.remalloc(num_tokens);
         if (parallelism_param.is_stage_leader()){
             st::util::stNcclRecv(
                 d_decoder_input.ptr,
diff --git a/src/csrc/util/nccl_utils.cc b/src/csrc/util/nccl_utils.cc
index b7bb460..2be1502 100644
--- a/src/csrc/util/nccl_utils.cc
+++ b/src/csrc/util/nccl_utils.cc
@@ -1,4 +1,5 @@
 #include "nccl_utils.h"
+#include "util/cuda_utils.h"
 #include <nccl.h>

 #define NCCL_CHECK(cmd) \
@@ -99,6 +100,7 @@ void stNcclRecv(
     }
     NCCL_CHECK(ncclRecv(buff, count, datatype, recv_from, comm.comm, stream));
+    CUDA_CHECK(cudaStreamSynchronize(stream));
 }

 void stNcclSendRecv(