diff --git a/sam2/csrc/connected_components.cu b/sam2/csrc/connected_components.cu index ced21eb32..cd814fefa 100644 --- a/sam2/csrc/connected_components.cu +++ b/sam2/csrc/connected_components.cu @@ -212,6 +212,7 @@ __global__ void final_counting( std::vector get_connected_componnets( const torch::Tensor& inputs) { + AT_ASSERTM(inputs.is_cuda(), "inputs must be a CUDA tensor"); AT_ASSERTM(inputs.ndimension() == 4, "inputs must be [N, 1, H, W] shape"); AT_ASSERTM( @@ -240,7 +241,8 @@ std::vector get_connected_componnets( dim3 grid_count = dim3((W + BLOCK_COLS) / BLOCK_COLS, (H + BLOCK_ROWS) / BLOCK_ROWS); dim3 block_count = dim3(BLOCK_COLS, BLOCK_ROWS); - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + auto device_idx = inputs.device().index(); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(device_idx); for (int n = 0; n < N; n++) { uint32_t offset = n * H * W;