diff --git a/.gitignore b/.gitignore index 1d407a7..6ea5777 100644 --- a/.gitignore +++ b/.gitignore @@ -167,3 +167,27 @@ cython_debug/ # Use wildcards as well *~ *.o +# Miscallenous files generated by DGraph data processing +skbuild/ +.vscode/ +logs/ +torchrun_* +*.png +rdvz +*.pt +*.core +*.graph +*.out +*.gz +data_processed +*.zip +cache +graph_cache +*.nsys-rep +*.nsys +*.pth +*.pyc +*.npy +*.npz +*.sqlite +*.csv \ No newline at end of file diff --git a/DGraph/distributed/Engine.py b/DGraph/distributed/Engine.py index 19e7774..547aada 100644 --- a/DGraph/distributed/Engine.py +++ b/DGraph/distributed/Engine.py @@ -50,7 +50,7 @@ def scatter( output_size: int, rank_mappings: Optional[torch.Tensor] = None, *args, - **kwargs + **kwargs, ) -> torch.Tensor: raise NotImplementedError @@ -60,7 +60,7 @@ def gather( indices: Union[torch.Tensor, torch.LongTensor], rank_mappings: Optional[torch.Tensor] = None, *args, - **kwargs + **kwargs, ) -> torch.Tensor: raise NotImplementedError diff --git a/DGraph/distributed/RankLocalOps.py b/DGraph/distributed/RankLocalOps.py index c4b6de0..b7302f1 100644 --- a/DGraph/distributed/RankLocalOps.py +++ b/DGraph/distributed/RankLocalOps.py @@ -16,9 +16,15 @@ """ import torch +import torch.distributed as dist try: - from DGraph.torch_local import local_masked_gather, local_masked_scatter + from DGraph.torch_local import ( + local_masked_gather, + local_masked_scatter, + local_masked_scatter_gather, + local_masked_scatter_add_gather, + ) _LOCAL_OPT_KERNELS_AVAILABLE = True except ImportError: @@ -81,6 +87,93 @@ def OptimizedRankLocalMaskedGather( return output +def OptimizedLocalScatterGather( + src: torch.Tensor, + src_indices: torch.Tensor, + dst_indices: torch.Tensor, + output: torch.Tensor, +): + """ + Performs the operation + + for i in range(len(src_indices)): + output[dst_indices[i]] = src[src_indices[i]] + Args: + src (torch.Tensor): Source tensor + src_indices (torch.Tensor): Source indices + dst_indices (torch.Tensor): Destination indices + output (torch.Tensor): Output tensor + Returns: + torch.Tensor: Output tensor after scatter-gather + """ + + if not _LOCAL_OPT_KERNELS_AVAILABLE: + warnings.warn( + "Optimized local kernels are not available. Falling back to the default implementation." + ) + output[dst_indices] = src[src_indices] + else: + bs = src.shape[0] + num_src_rows = src.shape[1] + num_features = src.shape[-1] + num_output_rows = output.shape[1] + local_masked_scatter_gather( + src, + src_indices.cuda(), + dst_indices.cuda(), + output, + bs, + num_src_rows, + num_features, + num_output_rows, + ) + return output + + +def OptimizedLocalScatterSumGather( + src: torch.Tensor, + src_indices: torch.Tensor, + dst_indices: torch.Tensor, + output: torch.Tensor, +): + """ + Performs the operation + + for i in range(len(src_indices)): + output[dst_indices[i]] += src[src_indices[i]] + Args: + src (torch.Tensor): Source tensor + src_indices (torch.Tensor): Source indices + dst_indices (torch.Tensor): Destination indices + output (torch.Tensor): Output tensor + Returns: + torch.Tensor: Output tensor after scatter-gather + """ + + if not _LOCAL_OPT_KERNELS_AVAILABLE: + warnings.warn( + "Optimized local kernels are not available. Falling back to the default implementation." + ) + for i in range(src_indices.shape[0]): + output[:, dst_indices[i], :] += src[:, src_indices[i], :] + else: + bs = src.shape[0] + num_src_rows = src.shape[1] + num_features = src.shape[-1] + num_output_rows = output.shape[1] + local_masked_scatter_add_gather( + src, + src_indices.cuda(), + dst_indices.cuda(), + output, + bs, + num_src_rows, + num_features, + num_output_rows, + ) + return output + + def OutOfPlaceRankLocalMaskedGather( _src: torch.Tensor, indices: torch.Tensor, rank_mapping: torch.Tensor, rank: int ) -> torch.Tensor: @@ -140,7 +233,9 @@ def RankLocalRenumberingWithMapping(_indices, rank_mapping): unique_indices, inverse_indices = torch.unique(_indices, return_inverse=True) rank_mapping = rank_mapping.to(_indices.device) renumbered_indices = inverse_indices - unique_rank_mapping = torch.zeros_like(unique_indices, dtype=rank_mapping.dtype, device=rank_mapping.device) + unique_rank_mapping = torch.zeros_like( + unique_indices, dtype=rank_mapping.dtype, device=rank_mapping.device + ) unique_rank_mapping.scatter_(0, inverse_indices, rank_mapping) return renumbered_indices, unique_indices, unique_rank_mapping diff --git a/DGraph/distributed/csrc/local_data_kernels.cuh b/DGraph/distributed/csrc/local_data_kernels.cuh index f12ca4a..1b2ea2b 100644 --- a/DGraph/distributed/csrc/local_data_kernels.cuh +++ b/DGraph/distributed/csrc/local_data_kernels.cuh @@ -251,4 +251,144 @@ namespace Local } } } + + + + template + struct FloatAtomicAddOp + { + __device__ __forceinline__ void operator()(T *cur_addr, const T new_val) + { + atomicAdd(cur_addr, new_val); + } + }; + + template + struct FloatSetOp + { + __device__ __forceinline__ void operator()(T *cur_addr, const T new_val) + { + *cur_addr = new_val; + } + }; + + + /** + * + * Masked Gather Kernel operation that performs the operation: + Y [mask[i]] = Op(Y [mask[i]], X [indices[i]]) + + where Y is the output matrix, X is the input matrix, indices is the index matrix, and mask is the mask matrix. + */ + + template + __global__ void Masked_Scatter_Gather_Kernel( + const float *__restrict__ values, + const long *__restrict__ indices, + const long *__restrict__ mask, + float *__restrict__ output, + const int mini_batch_size, + const int num_indices, + const int num_cols, + const int num_output_rows) + { + const size_t gidx = threadIdx.x + blockIdx.x * blockDim.x; + const size_t gidy = threadIdx.y + blockIdx.y * blockDim.y; + const size_t gidz = threadIdx.z + blockIdx.z * blockDim.z; + + const size_t nthreadsx = gridDim.x * blockDim.x; + const size_t nthreadsy = gridDim.y * blockDim.y; + const size_t nthreadsz = gridDim.z * blockDim.z; + + Op op; + + for (size_t mb_i = gidz; mb_i < mini_batch_size; mb_i += nthreadsz) + { + const auto values_offset = mb_i * num_cols * num_indices; + const auto output_offset = mb_i * num_cols * num_output_rows; + const auto ind_offset = mb_i * num_indices; + const auto mask_offset = mb_i * num_indices; + + for (size_t row = gidy; row < num_indices; row += nthreadsy) + { + const auto output_row = mask[mask_offset + row]; + const auto input_row = indices[ind_offset + row]; + + for (size_t col = gidx; col < num_cols; col += nthreadsx) + { + auto *output_addr = &output[output_offset + output_row * num_cols + col]; + const auto input_val = values[values_offset + input_row * num_cols + col]; + op(output_addr, input_val); + } + } + } + } + + /* + * + Optimized masked scatter gather kernel that performs the operation: + Y [mask[i]] = X [indices[i]] + + This kernel is optimized for the case where the num_cols is a multiple of 4. + + where Y is the output matrix, X is the input matrix, indices is the index matrix, and mask is the mask matrix. + */ + template + __global__ void Optimized_Masked_Scatter_Gather_Kernel( + const float *__restrict__ values, + const long *__restrict__ indices, + const long *__restrict__ mask, + float *__restrict__ output, + const int mini_batch_size, + const int num_indices, + const int num_cols, + const int num_output_rows) + { + const size_t gidx = threadIdx.x + blockIdx.x * blockDim.x; + const size_t gidy = threadIdx.y + blockIdx.y * blockDim.y; + const size_t gidz = threadIdx.z + blockIdx.z * blockDim.z; + + const size_t nthreadsx = gridDim.x * blockDim.x; + const size_t nthreadsy = gridDim.y * blockDim.y; + const size_t nthreadsz = gridDim.z * blockDim.z; + + // Grid-stride loop over mini-batches + + Op binary_operator; + for (size_t mb_i = gidz; mb_i < mini_batch_size; mb_i += nthreadsz) + { + const auto values_offset = mb_i * num_cols / 4 * num_indices; + const auto output_offset = mb_i * num_cols / 4 * num_output_rows; + const auto ind_offset = mb_i * num_indices; + const auto mask_offset = mb_i * num_indices; + + // Grid-stride loop over rows + for (size_t row = gidy; row < num_indices; row += nthreadsy) + { + long output_row, input_row; + + if (threadIdx.x == 0) + { + output_row = mask[mask_offset + row]; + input_row = indices[ind_offset + row]; + } + + output_row = __shfl_sync(0xFFFFFFFF, output_row, 0); + input_row = __shfl_sync(0xFFFFFFFF, input_row, 0); + + output_row = mask[mask_offset + row]; + input_row = indices[ind_offset + row]; + + size_t col = gidx; + + for (; col < num_cols / 4; col += nthreadsx) + { + const float4 values_vec = reinterpret_cast(values)[values_offset + input_row * num_cols / 4 + col]; + float4* output_addr = &reinterpret_cast(output)[output_offset + output_row * num_cols / 4 + col]; + binary_operator(output_addr, values_vec); + } + } + } + } + } // namespace Local \ No newline at end of file diff --git a/DGraph/distributed/csrc/torch_local_bindings.cpp b/DGraph/distributed/csrc/torch_local_bindings.cpp index a91f516..fe685b6 100644 --- a/DGraph/distributed/csrc/torch_local_bindings.cpp +++ b/DGraph/distributed/csrc/torch_local_bindings.cpp @@ -21,4 +21,6 @@ PYBIND11_MODULE(torch_local, m) { m.def("local_masked_gather", &local_masked_gather, "Masked Gather"); m.def("local_masked_scatter", &local_masked_scatter, "Masked Scatter"); + m.def("local_masked_scatter_gather", &local_masked_scatter_gather, "Masked Scatter Gather"); + m.def("local_masked_scatter_add_gather", &local_masked_scatter_add_gather, "Masked Scatter Add Gather"); } diff --git a/DGraph/distributed/csrc/torch_local_kernels.cu b/DGraph/distributed/csrc/torch_local_kernels.cu index b70bf36..896050f 100644 --- a/DGraph/distributed/csrc/torch_local_kernels.cu +++ b/DGraph/distributed/csrc/torch_local_kernels.cu @@ -114,4 +114,120 @@ torch::Tensor local_masked_scatter(torch::Tensor input, rank); CUDACHECK(cudaGetLastError()); return output; +} + +torch::Tensor local_masked_scatter_gather(torch::Tensor input, + torch::Tensor indices, + torch::Tensor mask, + torch::Tensor output, + const int num_batches, + const int num_values_rows, + const int num_cols, + const int num_output_rows) +{ + CHECK_INPUT(input); + CHECK_INPUT(indices); + CHECK_INPUT(mask); + CHECK_INPUT(output); + + const float *input_ptr = input.data_ptr(); + const long *indices_ptr = indices.data_ptr(); + const long *mask_ptr = mask.data_ptr(); + float *output_ptr = output.data_ptr(); + + dim3 block_dims, grid_dims; + block_dims.x = 32; + block_dims.y = 32; + block_dims.z = 1; + + const auto num_grids_needed = (num_output_rows + block_dims.y - 1) / block_dims.y; + const auto num_col_grids_needed = (num_cols + block_dims.x - 1) / block_dims.x; + grid_dims.x = num_col_grids_needed < 65535 ? num_col_grids_needed : 65535; + grid_dims.y = num_grids_needed < 65535 ? num_grids_needed : 65535; + grid_dims.z = 1; + + at::cuda::CUDAStream defaultStream = at::cuda::getDefaultCUDAStream(input.device().index()); + + if (num_cols % 4 != 0) + { + Local::Masked_Scatter_Gather_Kernel><<>>(input_ptr, + indices_ptr, + mask_ptr, + output_ptr, + num_batches, + num_values_rows, + num_cols, + num_output_rows); + } + else + { + Local::Optimized_Masked_Scatter_Gather_Kernel><<>>(input_ptr, + indices_ptr, + mask_ptr, + output_ptr, + num_batches, + num_values_rows, + num_cols, + num_output_rows); + } + CUDACHECK(cudaGetLastError()); + return output; +} + +torch::Tensor local_masked_scatter_add_gather(torch::Tensor input, + torch::Tensor indices, + torch::Tensor mask, + torch::Tensor output, + const int num_batches, + const int num_values_rows, + const int num_cols, + const int num_output_rows) +{ + CHECK_INPUT(input); + CHECK_INPUT(indices); + CHECK_INPUT(mask); + CHECK_INPUT(output); + + const float *input_ptr = input.data_ptr(); + const long *indices_ptr = indices.data_ptr(); + const long *mask_ptr = mask.data_ptr(); + float *output_ptr = output.data_ptr(); + + dim3 block_dims, grid_dims; + block_dims.x = 32; + block_dims.y = 32; + block_dims.z = 1; + + const auto num_grids_needed = (num_output_rows + block_dims.y - 1) / block_dims.y; + const auto num_col_grids_needed = (num_cols + block_dims.x - 1) / block_dims.x; + grid_dims.x = num_col_grids_needed < 65535 ? num_col_grids_needed : 65535; + grid_dims.y = num_grids_needed < 65535 ? num_grids_needed : 65535; + grid_dims.z = 1; + + at::cuda::CUDAStream defaultStream = at::cuda::getDefaultCUDAStream(input.device().index()); + + if (num_cols % 4 != 0) + { + Local::Masked_Scatter_Gather_Kernel><<>>(input_ptr, + indices_ptr, + mask_ptr, + output_ptr, + num_batches, + num_values_rows, + num_cols, + num_output_rows); + } + else + { + Local::Optimized_Masked_Scatter_Gather_Kernel><<>>(input_ptr, + indices_ptr, + mask_ptr, + output_ptr, + num_batches, + num_values_rows, + num_cols, + num_output_rows); + } + CUDACHECK(cudaGetLastError()); + return output; } \ No newline at end of file diff --git a/DGraph/distributed/include/torch_local.hpp b/DGraph/distributed/include/torch_local.hpp index f780160..7a4a258 100644 --- a/DGraph/distributed/include/torch_local.hpp +++ b/DGraph/distributed/include/torch_local.hpp @@ -19,4 +19,22 @@ torch::Tensor local_masked_scatter(torch::Tensor input, const int num_values_rows, const int num_cols, const int num_output_rows, - const int rank); \ No newline at end of file + const int rank); + +torch::Tensor local_masked_scatter_gather(torch::Tensor input, + torch::Tensor indices, + torch::Tensor rank_local_placement, + torch::Tensor output, + const int num_batches, + const int num_values_rows, + const int num_cols, + const int num_output_rows); + +torch::Tensor local_masked_scatter_add_gather(torch::Tensor input, + torch::Tensor indices, + torch::Tensor rank_local_placement, + torch::Tensor output, + const int num_batches, + const int num_values_rows, + const int num_cols, + const int num_output_rows); \ No newline at end of file diff --git a/DGraph/distributed/nccl/NCCLBackendEngine.py b/DGraph/distributed/nccl/NCCLBackendEngine.py index b3ea11a..b8d2fd0 100644 --- a/DGraph/distributed/nccl/NCCLBackendEngine.py +++ b/DGraph/distributed/nccl/NCCLBackendEngine.py @@ -16,488 +16,22 @@ import torch import torch.distributed as dist from DGraph.distributed.Engine import BackendEngine -from DGraph.distributed.nccl._indices_utils import ( - _generate_local_rank_mapping, - _get_local_unique_recv_placement, -) -from DGraph.distributed.nccl._nccl_cache import NCCLGatherCache, NCCLScatterCache -from DGraph.distributed.nccl.alltoallv_impl import ( - _nccl_alltoall_v, - _nccl_alltoallv_with_dict, -) -from DGraph.distributed.RankLocalOps import ( - RankLocalMaskedGather, - RankLocalMaskedScatter, - RankLocalRenumberingWithMapping, - OptimizedRankLocalMaskedGather, +from DGraph.distributed.nccl._NCCLCommPlan import NCCLGraphCommPlan +from DGraph.distributed.nccl._torch_func_impl import ( + GatherFunction, + ScatterFunction, + CommPlan_ScatterFunction, + CommPlan_GatherFunction, ) + from torch.autograd import Function from DGraph.utils import largest_split +from typing import overload TIMINGS = {"Gather_Index_Forward": [], "Gather_Forward_Local": []} -class GatherFunction(Function): - @staticmethod - def forward( - ctx, - local_send_tensor: torch.Tensor, - indices: torch.LongTensor, - # vertex_ranks: torch.Tensor, - edge_rank_loc: torch.Tensor, - edge_dest_ranks: torch.Tensor, - rank: int, - world_size: int, - cache: Optional[NCCLGatherCache] = None, - ): - num_local_input_rows = local_send_tensor.shape[1] - - if cache is not None: - # We have a cache, use it, don't need to save anything - ctx.has_cache = True - ctx.cache = cache - # TODO: Should we cash the indices as well? - S.Z - else: - ctx.has_cache = False - - ctx.save_for_backward( - indices, - edge_rank_loc, - edge_dest_ranks, - torch.tensor(num_local_input_rows), - torch.tensor(rank), - torch.tensor(world_size), - ) - - # Since NCCL is two-sided, we need to push from local rank and pull from - # remote rank to get the global gather - - # TODO: One possible optmization is cache all these calculations - # and only do the gather when the cache is invalidated. Essentially - # if we are working with static graphs, the indices and distribution pattern - # will not change and we can cache the communication pattern. - S.Z - - # We can also pre-compute this on the data ingestion side. Might - # be worth looking to some kind of cached communication pattern store - # that can be passed to the communicator. - S.Z - - batch_size = 1 - num_features = local_send_tensor.shape[2] - - if cache is not None: - local_indices = cache.gather_local_indices % local_send_tensor.shape[1] - local_gather_mask = cache.gather_local_comm_mask - needs_comm = cache.gather_needs_comm - local_output_rows = cache.gather_num_output_rows - local_rank_mapping = cache.gather_local_remapped_ranks - recv_tensor = torch.zeros(batch_size, local_output_rows, num_features).to( - local_send_tensor.device - ) - local_recv_tensor = cache.gather_local_recv_mapping - else: - # Get the edges that are local to the rank - - local_slice_mask = edge_rank_loc == rank - - num_local_output_rows = int(local_slice_mask.sum().item()) - - recv_tensor = torch.zeros( - batch_size, num_local_output_rows, num_features - ).to(local_send_tensor.device) - - local_indices_slice = indices[local_slice_mask.unsqueeze(0)] - local_rank_mapping = edge_rank_loc[local_slice_mask] - local_recv_tensor = edge_dest_ranks[local_slice_mask] - - # assert torch.all(local_recv_tensor == rank), local_recv_tensor - - local_indices = local_indices_slice % local_send_tensor.shape[1] - - needs_comm = (local_recv_tensor != rank).any() - - recv_tensor = OptimizedRankLocalMaskedGather( - local_send_tensor, - local_indices, - local_rank_mapping, - recv_tensor, - rank, - ) - - if needs_comm: - - recv_tensor = _nccl_alltoall_v( - local_send_tensor=local_send_tensor, - local_recv_tensor=recv_tensor, - indices=indices, - local_rank_mapping=local_recv_tensor, - edge_rank_loc=edge_rank_loc, - src_rank_loc=edge_dest_ranks, - rank=rank, - world_size=world_size, - cache=cache, - ) - - return recv_tensor - - @staticmethod - def backward(ctx, grad_output): - # We need to switch the send and recv ranks - ( - indices, - recv_ranks, - send_ranks, - # vertices_per_rank, - num_local_input_rows, - rank, - world_size, - ) = ctx.saved_tensors - - if ctx.has_cache: - cache: Optional[NCCLGatherCache] = ctx.cache - else: - cache = None - - num_local_output_rows = num_local_input_rows.item() - rank = rank.item() - world_size = world_size.item() - send_tensor = grad_output - - # Now it's a scatter operation - num_features = send_tensor.shape[-1] - device = send_tensor.device - local_rank_output = torch.zeros(1, num_local_output_rows, num_features).to( - device - ) - - indices = indices.view(-1) - local_slice_mask = recv_ranks == rank - local_indices_slice = indices[local_slice_mask] - local_dest_ranks = send_ranks[local_slice_mask] - - local_rank_output = RankLocalMaskedScatter( - send_tensor, - local_rank_output, - local_indices_slice, - local_dest_ranks, - rank, - ) - - if cache is not None: - local_comm_mask = cache.scatter_local_comm_mask - else: - local_comm_mask = local_dest_ranks != rank - - send_buffer_dict = {} - if torch.any(local_comm_mask): - # These rows need to be sent to other ranks - # First aggregate these into a single buffer - - if cache is not None: - num_remote_rows = cache.scatter_num_remote_rows - remapped_ranks = cache.scatter_local_remapped_ranks - renumbered_indices = cache.scatter_renumbered_indices - receiving_ranks = cache.scatter_remote_send_to_ranks - - else: - - local_comm_indices = local_indices_slice[local_comm_mask] - local_remote_dest_mappings = local_dest_ranks[local_comm_mask] - - renumbered_indices, unique_indices, remapped_ranks = ( - RankLocalRenumberingWithMapping( - local_comm_indices, local_remote_dest_mappings - ) - ) - receiving_ranks = torch.unique(local_dest_ranks[local_comm_mask]) - num_remote_rows = len(unique_indices) - - buffer = torch.zeros(1, num_remote_rows, num_features).to(device) - buffer.scatter_add_( - 1, - renumbered_indices.view(1, -1, 1).expand(1, -1, num_features), - send_tensor[:, local_comm_mask, :], - ) - - for _recv_rank in receiving_ranks: - _recv_indices = remapped_ranks == _recv_rank - send_buffer_dict[_recv_rank.item()] = buffer[:, _recv_indices, :] - - # Now we need to receive the data from the remote ranks - - recv_buffer_dict = {} - - recv_placement = {} - - if cache is not None: - recv_placement = cache.scatter_recv_local_placement - - # Allocate the receive buffers for the communication based on the - # size of the recv_placement indices. - for key, unique_send_indices in recv_placement.items(): - num_elements = unique_send_indices.shape[0] - recv_buffer_dict[key] = torch.zeros(1, num_elements, num_features).to( - device - ) - else: - send_to_rank = send_ranks # Pedantic variable name change - all_comm_mask = send_to_rank != recv_ranks - reciever_mask = send_to_rank == rank - receive_from_remote = all_comm_mask & reciever_mask - - if torch.any(receive_from_remote): - receive_from_ranks = recv_ranks[receive_from_remote] - - for _sender in range(world_size): - if _sender == rank: - continue - if torch.any(receive_from_ranks == _sender): - _send_mask = (recv_ranks == _sender) & receive_from_remote - _send_indices = indices[_send_mask] % num_local_output_rows - # TODO: This is brittle, look into a better way to do this - S.Z - - unique_send_indices = torch.unique(_send_indices) - num_elements = unique_send_indices.shape[0] - recv_buffer_dict[_sender] = torch.zeros( - 1, num_elements, num_features - ).cuda() - recv_placement[_sender] = unique_send_indices - - recv_buffer_dict = _nccl_alltoallv_with_dict( - send_buffer_dict, recv_buffer_dict, rank, world_size - ) - for key, recv_buffer in recv_buffer_dict.items(): - local_rank_output.scatter_add_( - 1, - recv_placement[key].view(1, -1, 1).expand(1, -1, num_features), - recv_buffer, - ) - - send_tensor_grad = local_rank_output - indices_grad = None - send_ranks_grad = None - recv_ranks_grad = None - rank_grad = None - world_size_grad = None - cache_grad = None - - return ( - send_tensor_grad, - indices_grad, - send_ranks_grad, - recv_ranks_grad, - rank_grad, - world_size_grad, - cache_grad, - ) - - -class ScatterFunction(Function): - @staticmethod - def forward( - ctx, - send_tensor: torch.Tensor, - indices: torch.Tensor, - edge_src_ranks: torch.Tensor, - edge_dest_ranks: torch.Tensor, - num_local_output_rows: int, - rank: int, - world_size: int, - scatter_cache: Optional[NCCLScatterCache] = None, - ) -> torch.Tensor: - - ctx.save_for_backward( - indices, - edge_src_ranks, - edge_dest_ranks, - torch.tensor(num_local_output_rows), - torch.tensor(rank), - torch.tensor(world_size), - ) - use_cache = scatter_cache is not None - if use_cache: - ctx.scatter_cache = scatter_cache - ctx.has_cache = True - else: - ctx.has_cache = False - - num_features = send_tensor.shape[-1] - device = send_tensor.device - - local_rank_output = torch.zeros(1, num_local_output_rows, num_features).to( - device - ) - - indices = indices.view(-1) - - local_edge_mask = edge_src_ranks == rank - - local_indices_slice = indices[local_edge_mask] - local_dest_ranks = edge_dest_ranks[local_edge_mask] - - local_rank_output = RankLocalMaskedScatter( - send_tensor, - local_rank_output, - local_indices_slice, - local_dest_ranks, - rank, - ) - - if use_cache: - local_comm_mask = scatter_cache.scatter_local_comm_mask - else: - local_comm_mask = local_dest_ranks != rank - - all_comm_mask = edge_src_ranks != edge_dest_ranks - reciever_mask = edge_dest_ranks == rank - receive_from_remote_mask = all_comm_mask & reciever_mask - - send_buffer_dict = {} - - if torch.any(local_comm_mask): - - if use_cache: - num_remote_rows = scatter_cache.scatter_num_remote_rows - remapped_ranks = scatter_cache.scatter_local_remapped_ranks - renumbered_indices = scatter_cache.scatter_local_renumbered_indices - receving_ranks = scatter_cache.scatter_remote_send_to_ranks - - else: - # These rows need to be sent to other ranks - # First aggregate these into a single buffer - local_comm_indices = local_indices_slice[local_comm_mask] - local_remote_dest_mappings = local_dest_ranks[local_comm_mask] - # TODO: This is very slow, look into a better way to do this - S.Z - # Uncached is slow, should look into augmenting torch functions - # to speed this up - S.Z - renumbered_indices, unique_indices, remapped_ranks = ( - RankLocalRenumberingWithMapping( - local_comm_indices, local_remote_dest_mappings - ) - ) - num_remote_rows = len(unique_indices) - receving_ranks = torch.unique(local_dest_ranks[local_comm_mask]) - - buffer = torch.zeros(1, num_remote_rows, num_features).to(device) - buffer.scatter_add_( - 1, - renumbered_indices.view(1, -1, 1).expand(1, -1, num_features), - send_tensor[:, local_comm_mask, :], - ) - - for _recv_rank in receving_ranks: - _recv_indices = remapped_ranks == _recv_rank - send_buffer_dict[_recv_rank.item()] = buffer[:, _recv_indices, :] - - recv_buffer_dict = {} - recv_placement = {} - if use_cache: - recv_placement = scatter_cache.scatter_recv_local_placement - else: - recv_placement = _get_local_unique_recv_placement( - indices, - edge_src_ranks, - receive_from_remote_mask, - num_local_output_rows, - rank, - world_size, - ) - - # Allocate the receive buffers for the communication based on the - # size of the recv_placement indices. - for key, unique_send_indices in recv_placement.items(): - num_elements = unique_send_indices.shape[0] - recv_buffer_dict[key] = torch.zeros(1, num_elements, num_features).to( - device - ) - recv_buffer_dict = _nccl_alltoallv_with_dict( - send_buffer_dict, recv_buffer_dict, rank, world_size - ) - for key, recv_buffer in recv_buffer_dict.items(): - local_rank_output.scatter_add_( - 1, - recv_placement[key].view(1, -1, 1).expand(1, -1, num_features), - recv_buffer, - ) - return local_rank_output - - @staticmethod - def backward(ctx, grad_output): - # We need to switch the send and recv ranks - indices, recv_ranks, send_ranks, num_input_rows, rank, world_size = ( - ctx.saved_tensors - ) - - local_mask = recv_ranks == rank - if ctx.has_cache: - cache: NCCLScatterCache = ctx.scatter_cache - num_local_output_rows = cache.gather_num_output_rows - - else: - rank = int(rank.item()) - world_size = int(world_size.item()) - - indices = indices.view(1, -1) - - # Now it's a gather operation - - num_local_output_rows = int(local_mask.sum().item()) - - batch_size = 1 - num_features = grad_output.shape[2] - - recv_tensor = torch.zeros(batch_size, num_local_output_rows, num_features).to( - grad_output.device - ) - - local_indices_slice = indices[0][local_mask] - local_rank_mapping = send_ranks[local_mask] - - local_indices = local_indices_slice % grad_output.shape[1] - - if len(local_indices_slice) > 0: - - recv_tensor[:, local_rank_mapping == rank, :] = RankLocalMaskedGather( - grad_output, local_indices, local_rank_mapping, rank - ) - - recv_tensor = _nccl_alltoall_v( - local_send_tensor=grad_output, - local_recv_tensor=recv_tensor, - indices=indices, - local_rank_mapping=local_rank_mapping, - edge_rank_loc=send_ranks, - src_rank_loc=recv_ranks, - rank=rank, - world_size=world_size, - ) - - # if rank == 0: - # breakpoint() - # dist.barrier() - # NOTE: even if the inputs are non-tensors, the number of backward outputs - # must be the same as the number of inputs. - send_tensor_grad = recv_tensor - indices_grad = None - send_ranks_grad = None - recv_ranks_grad = None - num_local_output_rows_grad = None - rank_grad = None - world_size_grad = None - scatter_cache_grad = None - - return ( - send_tensor_grad, - indices_grad, - send_ranks_grad, - recv_ranks_grad, - num_local_output_rows_grad, - rank_grad, - world_size_grad, - scatter_cache_grad, - ) - - class NCCLBackendEngine(BackendEngine): _is_initialized = False _rank = -1 @@ -559,66 +93,99 @@ def get_local_rank_slice(self, tensor: torch.Tensor, dim: int) -> torch.Tensor: end_index = start_index + local_size return tensor[:, start_index:end_index] + @overload def scatter( self, local_send_tensor: torch.Tensor, indices: torch.Tensor, rank_mappings: torch.Tensor, output_size: int, - cache: Optional[NCCLScatterCache] = None, - *args, - **kwargs, + ) -> torch.Tensor: ... + + @overload + def scatter( + self, + local_send_tensor: torch.Tensor, + *, + comm_plan: NCCLGraphCommPlan, + ) -> torch.Tensor: ... + + def scatter( + self, + local_send_tensor: torch.Tensor, + indices: Optional[torch.Tensor] = None, + rank_mappings: Optional[torch.Tensor] = None, + output_size: Optional[int] = None, + comm_plan: Optional[NCCLGraphCommPlan] = None, ) -> torch.Tensor: - send_tensor_shape = local_send_tensor.shape - b_size = send_tensor_shape[0] - world_size = self.get_world_size() - rank = self.get_rank() - assert b_size == 1, "Multi-batch gather disabled for testing" - assert len(send_tensor_shape) == 3, "Currently only support 3D tensors" - assert indices.shape[-1] == rank_mappings.shape[-1], ( - f"Indices shape: {indices.shape} and rank mappings shape: " - + f" {rank_mappings.shape} must match" - ) - assert rank_mappings.shape[0] == 2, ( - "Rank mappings shape[0] expected to be 2, " - + f"but got {rank_mappings.shape[0]}" - ) - assert ( - local_send_tensor.device.type == "cuda" - ), f"Device: {local_send_tensor.device.type} expected cuda" - assert output_size > 0, "Output size must be greater than 0" + if comm_plan is not None: + return CommPlan_ScatterFunction.apply(local_send_tensor, comm_plan) # type: ignore + else: + if indices is None or rank_mappings is None or output_size is None: + raise ValueError( + "Indices, rank mappings, and output size must be provided for NCCL backend" + ) - src_ranks = rank_mappings[0] - dest_ranks = rank_mappings[1] + send_tensor_shape = local_send_tensor.shape + b_size = send_tensor_shape[0] - use_cache = cache is not None + world_size = self.get_world_size() + rank = self.get_rank() + assert b_size == 1, "Multi-batch gather disabled for testing" + assert len(send_tensor_shape) == 3, "Currently only support 3D tensors" + assert indices.shape[-1] == rank_mappings.shape[-1], ( + f"Indices shape: {indices.shape} and rank mappings shape: " + + f" {rank_mappings.shape} must match" + ) + assert rank_mappings.shape[0] == 2, ( + "Rank mappings shape[0] expected to be 2, " + + f"but got {rank_mappings.shape[0]}" + ) + assert ( + local_send_tensor.device.type == "cuda" + ), f"Device: {local_send_tensor.device.type} expected cuda" + assert output_size > 0, "Output size must be greater than 0" - if use_cache: - assert type(cache) == NCCLScatterCache - scatter_cache = cache - else: - scatter_cache = None + src_ranks = rank_mappings[0] + dest_ranks = rank_mappings[1] - output_tensor = ScatterFunction.apply( - local_send_tensor, - indices, - src_ranks, - dest_ranks, - output_size, - rank, - world_size, - scatter_cache, - ) + output_tensor = ScatterFunction.apply( + local_send_tensor, + indices, + src_ranks, + dest_ranks, + output_size, + rank, + world_size, + ) return output_tensor # type: ignore + @overload + def gather( + self, + local_send_tensor: torch.Tensor, + indices: torch.Tensor, + rank_mappings: torch.Tensor, + **kwargs, + ) -> torch.Tensor: ... + + @overload + def gather( + self, + local_send_tensor: torch.Tensor, + *, + comm_plan: NCCLGraphCommPlan, + **kwargs, + ) -> torch.Tensor: ... + def gather( self, local_send_tensor: torch.Tensor, indices: torch.Tensor, rank_mappings: torch.Tensor, - cache: Optional[NCCLGatherCache] = None, + comm_plan: Optional[NCCLGraphCommPlan] = None, **kwargs, ) -> torch.Tensor: """Gather the distributed tensor across all ranks according to the indices @@ -644,6 +211,9 @@ def gather( rank_mappings (torch.Tensor): The rank mappings for the gather operation """ + if comm_plan is not None: + return CommPlan_GatherFunction.apply(local_send_tensor, comm_plan) # type: ignore + send_tensor_shape = local_send_tensor.shape b_size = send_tensor_shape[0] world_size = self.get_world_size() @@ -667,14 +237,6 @@ def gather( send_rank = rank_mappings[0] recv_rank = rank_mappings[1] - use_cache = cache is not None - - if use_cache: - assert type(cache) == NCCLGatherCache, f"Invalid cache type {type(cache)}" - gather_cache = cache - else: - gather_cache = None - output_tensor = GatherFunction.apply( local_send_tensor, indices, @@ -682,7 +244,6 @@ def gather( recv_rank, rank, world_size, - gather_cache, ) dist.barrier() diff --git a/DGraph/distributed/nccl/_NCCLCommPlan.py b/DGraph/distributed/nccl/_NCCLCommPlan.py new file mode 100644 index 0000000..ad5a4e7 --- /dev/null +++ b/DGraph/distributed/nccl/_NCCLCommPlan.py @@ -0,0 +1,272 @@ +import torch +from dataclasses import dataclass +from typing import List, Optional +import torch.distributed as dist + + +@dataclass +class NCCLGraphCommPlan: + """ + Class to store communication plan for distributed gather-scatter (vector addressing) + + Attributes: + rank (int): Local rank + world_size (int): World size + local_num_vertices (int): Number of local vertices + local_src_idx (torch.Tensor): Local source indices for scatter-sum + local_dst_idx (torch.Tensor): Local destination indices for scatter-sum + send_src_idx (torch.Tensor): Source indices to send to other ranks + send_buffer_idx (torch.Tensor): Buffer indices to store data to send to other ranks + send_comm_vector (torch.Tensor): Communication vector of shape [world_size] of messages to send to each rank + recv_dst_idx (torch.Tensor): Destination indices to receive from other ranks + recv_comm_vector (torch.Tensor): Communication vector of shape [world_size] of messages to + """ + + rank: int + world_size: int + + # Allocation meta data + num_local_vertices: int + num_local_edges: int + + # Local edge-vertex mapping + # + # Used for: + # 1) Local scatter-sum (edge -> vertex aggregation) + # y[local_vertex_idx] += x[local_edge_idx] + # 2) Local gather (vertex -> edge gathering) + # y[local_edge_idx] = x[local_vertex_idx] + + local_edge_idx: torch.Tensor + local_vertex_idx: torch.Tensor + + # Boundary edges (data must be sent/received to/from other ranks for gather/scatter) + + boundary_edge_idx: torch.Tensor + boundary_edge_buffer_map: torch.Tensor + boundary_edge_splits: List[int] + + # Boundary vertices (vertices that have edges on other ranks) + boundary_vertex_idx: torch.Tensor + boundary_vertex_splits: List[int] + + def to(self, device: torch.device): + self.local_edge_idx = self.local_edge_idx.to(device) + self.local_vertex_idx = self.local_vertex_idx.to(device) + self.boundary_edge_idx = self.boundary_edge_idx.to(device) + self.boundary_edge_buffer_map = self.boundary_edge_buffer_map.to(device) + self.boundary_vertex_idx = self.boundary_vertex_idx.to(device) + return self + + +@dataclass +class NCCLEdgeConditionedGraphCommPlan: + """ + Class to store communication plan for distributed gather-scatter for edge-conditioned + graphs where both source and destination vertices are needed. + + Attributes: + rank (int): Local rank + world_size (int): World size + + source_graph_plan (NCCLGraphCommPlan): Communication plan for source vertices + dest_graph_plan (NCCLGraphCommPlan): Communication plan for destination vertices + """ + + rank: int + world_size: int + + source_graph_plan: NCCLGraphCommPlan + dest_graph_plan: Optional[NCCLGraphCommPlan] = None + + def to(self, device: torch.device): + self.source_graph_plan = self.source_graph_plan.to(device) + if self.dest_graph_plan is not None: + self.dest_graph_plan = self.dest_graph_plan.to(device) + return self + + +def compute_edge_slices(dest_ranks, rank, my_dst_global, offset): + + is_internal = dest_ranks == rank + internal_dst_global = my_dst_global[is_internal] + internal_node_idx = internal_dst_global - offset[rank + 1] + + internal_edge_indices = torch.nonzero(is_internal, as_tuple=True)[0] + + remote_mask = ~is_internal + + boundary_edge_indices = torch.nonzero(remote_mask, as_tuple=True)[0] + + b_dst_global = my_dst_global[remote_mask] + b_dest_ranks = dest_ranks[remote_mask] + + return ( + internal_node_idx, + internal_edge_indices, + b_dst_global, + b_dest_ranks, + boundary_edge_indices, + ) + + +def fast_2D_unique(indices_1, indices_2): + packed_keys = indices_1.to(torch.int64) << 32 | indices_2.to(torch.int64) + unique_packed, inverse_indices = torch.unique( + packed_keys, return_inverse=True, sorted=False + ) + unique_1 = unique_packed >> 32 + unique_2 = unique_packed & 0xFFFFFFFF + return unique_1, unique_2, inverse_indices + + +def COO_to_NCCLCommPlan( + rank: int, + world_size: int, + global_edges_dst: torch.Tensor, + local_edge_list: torch.Tensor, + offset: torch.Tensor, +) -> NCCLGraphCommPlan: + """ + + Convert COO (Coordinate List) format graph to NCCLGraphCommPlan for distributed gather-scatter operations. + + Args: + rank (int): Local rank + world_size (int): World size + global_edges_src (torch.Tensor): Global source indices of edges + global_edges_dst (torch.Tensor): Global destination indices of edges + vertex_rank_placement (torch.Tensor): Rank placement of vertices + local_edge_list (torch.Tensor): List of indices of local edges + offset (torch.Tensor): Offset for each rank. + The vertices are partitioned among ranks in a contiguous manner. + All vertices in the range [offset[rank], offset[rank + 1]) are assigned to the rank. + + """ + device = local_edge_list.device + my_dst_global = global_edges_dst[local_edge_list].to(device) + + if int(offset[-1].item()) > (2**32): + raise ValueError( + f"{offset[-1]}, Number of vertices exceeding {2**32}, which is not supported" + ) + + my_start = offset[rank].item() + my_end = offset[rank + 1].item() + num_local_vertices = int(my_end - my_start) + num_local_edges = local_edge_list.size(0) + + dest_ranks = torch.bucketize(my_dst_global, offset, right=True) - 1 + + # Seperate this out to reduce memory usage + ( + internal_node_idx, + internal_edge_indices, + b_dst_global, + b_dest_ranks, + boundary_edge_indices, + ) = compute_edge_slices(dest_ranks, rank, my_dst_global, offset) + + unique_ranks, unique_global_ids, inverse_indices = fast_2D_unique( + b_dest_ranks, b_dst_global + ) + + print(f"Rank {rank} has {len(boundary_edge_indices)} edges to send ") + print(f"Rank {rank} has {len(unique_ranks)} unique messages to send ") + + if len(unique_ranks) > 0: + print( + f"Rank {rank} message reduction ratio: {len(boundary_edge_indices)/len(unique_ranks)}" + ) + + boundary_edge_buffer_map = inverse_indices + + boundary_edge_splits = torch.bincount(unique_ranks, minlength=world_size).tolist() + + recv_counts_tensor = torch.zeros(world_size, dtype=torch.long, device=device) + send_counts_tensor = torch.tensor( + boundary_edge_splits, dtype=torch.long, device=device + ) + dist.all_to_all_single(recv_counts_tensor, send_counts_tensor) + boundary_node_splits = recv_counts_tensor.tolist() + + total_recv_nodes = sum(boundary_node_splits) + recv_global_ids = torch.empty(total_recv_nodes, dtype=torch.long, device=device) + + dist.all_to_all_single( + recv_global_ids, + unique_global_ids, + output_split_sizes=boundary_node_splits, + input_split_sizes=boundary_edge_splits, + ) + + boundary_node_idx = recv_global_ids - my_start + + return NCCLGraphCommPlan( + rank=rank, + world_size=world_size, + num_local_vertices=num_local_vertices, + num_local_edges=num_local_edges, + local_edge_idx=internal_edge_indices, + local_vertex_idx=internal_node_idx, + boundary_edge_idx=boundary_edge_indices, + boundary_edge_buffer_map=boundary_edge_buffer_map, + boundary_edge_splits=boundary_edge_splits, + boundary_vertex_idx=boundary_node_idx, + boundary_vertex_splits=boundary_node_splits, + ) + + +def COO_to_NCCLEdgeConditionedCommPlan( + rank: int, + world_size: int, + global_edges_src: torch.Tensor, + global_edges_dst: torch.Tensor, + local_edge_list: torch.Tensor, + src_offset: torch.Tensor, + dest_offset: Optional[torch.Tensor], +) -> NCCLEdgeConditionedGraphCommPlan: + """ + + Convert COO (Coordinate List) format graph to NCCLEdgeConditionedGraphCommPlan for distributed gather-scatter operations. + + Args: + rank (int): Local rank + world_size (int): World size + global_edges_src (torch.Tensor): Global source indices of edges + global_edges_dst (torch.Tensor): Global destination indices of edges + local_edge_list (torch.Tensor): List of indices of local edges + src_offset (torch.Tensor): Offset for each rank for source vertices. + The vertices are partitioned among ranks in a contiguous manner. + All vertices in the range [src_offset[rank], src_offset[rank + 1]) are assigned to the rank. + dest_offset (Optional[torch.Tensor]): Offset for each rank for destination vertices. + The vertices are partitioned among ranks in a contiguous manner. + All vertices in the range [dest_offset[rank], dest_offset[rank + 1]) are assigned to the rank. + """ + device = local_edge_list.device + + source_plan = COO_to_NCCLCommPlan( + rank, + world_size, + global_edges_src, + local_edge_list, + src_offset, + ) + + if dest_offset is None: + dest_offset = src_offset + + dest_plan = COO_to_NCCLCommPlan( + rank, + world_size, + global_edges_dst, + local_edge_list, + dest_offset, + ) + + return NCCLEdgeConditionedGraphCommPlan( + rank=rank, + world_size=world_size, + source_graph_plan=source_plan, + dest_graph_plan=dest_plan, + ) diff --git a/DGraph/distributed/nccl/__init__.py b/DGraph/distributed/nccl/__init__.py index cf28164..aae0291 100644 --- a/DGraph/distributed/nccl/__init__.py +++ b/DGraph/distributed/nccl/__init__.py @@ -12,9 +12,9 @@ # # SPDX-License-Identifier: (Apache-2.0) from DGraph.distributed.nccl.NCCLBackendEngine import NCCLBackendEngine, TIMINGS -from DGraph.distributed.nccl._nccl_cache import ( - NCCLGatherCache, - NCCLScatterCache, - NCCLScatterCacheGenerator, - NCCLGatherCacheGenerator, +from DGraph.distributed.nccl._NCCLCommPlan import ( + NCCLGraphCommPlan, + NCCLEdgeConditionedGraphCommPlan, + COO_to_NCCLCommPlan, + COO_to_NCCLEdgeConditionedCommPlan, ) diff --git a/DGraph/distributed/nccl/_nccl_cache.py b/DGraph/distributed/nccl/_nccl_cache.py index 28a2d01..b378f1c 100644 --- a/DGraph/distributed/nccl/_nccl_cache.py +++ b/DGraph/distributed/nccl/_nccl_cache.py @@ -64,6 +64,9 @@ class NCCLScatterCache: world_size: int +# @dataclass +# class + def all_to_all_cache_helper( indices, edge_placement, edge_vertex_ranks, num_rows, rank, world_size ): @@ -200,7 +203,6 @@ def NCCLScatterCacheGenerator( recv_placement = _get_local_unique_recv_placement( indices, edge_placement, remote_recv_mask, num_output_rows, rank, world_size ) - # Information for the backward pass # It's a gather operation so quite a bit simpler @@ -253,7 +255,6 @@ def NCCLGatherCacheGenerator( indices, edge_placement, edge_dest_ranks, num_input_rows, rank, world_size ) ) - local_slice_mask = edge_placement == rank local_mask = edge_placement[local_slice_mask] diff --git a/DGraph/distributed/nccl/_torch_func_impl.py b/DGraph/distributed/nccl/_torch_func_impl.py new file mode 100644 index 0000000..71880b7 --- /dev/null +++ b/DGraph/distributed/nccl/_torch_func_impl.py @@ -0,0 +1,673 @@ +import torch +from typing import Optional +from torch.autograd import Function +import torch.distributed as dist +from dataclasses import dataclass +from DGraph.distributed.nccl._nccl_cache import NCCLGatherCache, NCCLScatterCache +from DGraph.distributed.RankLocalOps import ( + OptimizedRankLocalMaskedGather, + OptimizedLocalScatterGather, + OptimizedLocalScatterSumGather, +) +from DGraph.distributed.nccl._NCCLCommPlan import NCCLGraphCommPlan + + +class CommPlan_GatherFunction(Function): + @staticmethod + def forward( + ctx, + local_send_tensor: torch.Tensor, + comm_plan: NCCLGraphCommPlan, + ) -> torch.Tensor: + """ + Forward pass for distributed gather using the common plan to effectively perform: + y[i] = x[indices[i]] + + The process is as follows: + 1) Perform local gather from local vertices to local edges + 2) Gather + + Args: + ctx (torch.autograd.FunctionContext): Context object + local_send_tensor (torch.Tensor): Local send tensor + comm_plan (GatherCommPlan): Communication plan + """ + assert ( + len(local_send_tensor.shape) == 3 + ), "Local send tensor must be of shape (batch_size, num_rows, num_features)" + ctx.comm_plan = comm_plan + + num_features = local_send_tensor.shape[-1] + num_batches = local_send_tensor.shape[0] + + output_tensor = torch.zeros( + num_batches, comm_plan.num_local_edges, num_features + ).to(local_send_tensor.device) + + # Local vertex to edge gather + output_tensor = OptimizedLocalScatterGather( + src=local_send_tensor, + src_indices=comm_plan.local_edge_idx, + dst_indices=comm_plan.local_vertex_idx, + output=output_tensor, + ) + + # To do: Combine this with the local gather above to reduce kernel launches + send_buf = local_send_tensor[:, comm_plan.boundary_edge_idx, :] + + total_recv = sum(comm_plan.boundary_edge_splits) + + recv_buffer = torch.empty(num_batches, total_recv, num_features).to( + local_send_tensor.device + ) + dist.all_to_all_single( + recv_buffer, + send_buf, + output_split_sizes=comm_plan.boundary_edge_splits, + input_split_sizes=comm_plan.boundary_edge_splits, + ) + + output_tensor = OptimizedLocalScatterGather( + src=recv_buffer, + src_indices=comm_plan.boundary_edge_buffer_map, + dst_indices=comm_plan.boundary_edge_idx, + output=output_tensor, + ) + + return output_tensor + + @staticmethod + def backward(ctx, grad_output): + """ + Backward pass for distributed gather + + Args: + ctx (torch.autograd.FunctionContext): Context object + grad_output (torch.Tensor): Gradient of the output tensor. + Shape: (batch_size, num_local_edges, num_features) + """ + comm_plan = ctx.comm_plan + num_features = grad_output.shape[-1] + num_batches = grad_output.shape[0] + device = grad_output.device + + grad_input = torch.zeros( + num_batches, comm_plan.num_local_vertices, num_features, device=device + ) + + grad_input = OptimizedLocalScatterSumGather( + src=grad_output, + output=grad_input, + src_indices=comm_plan.local_edge_idx, + dst_indices=comm_plan.local_vertex_idx, + ) + + send_buf = grad_output[:, comm_plan.boundary_vertex_idx, :] + total_recv = sum(comm_plan.boundary_vertex_splits) + recv_buffer = torch.empty(num_batches, total_recv, num_features).to(device) + dist.all_to_all_single( + recv_buffer, + send_buf, + output_split_sizes=comm_plan.boundary_vertex_splits, + input_split_sizes=comm_plan.boundary_edge_splits, + ) + grad_input = OptimizedLocalScatterSumGather( + src=recv_buffer, + output=grad_input, + src_indices=comm_plan.boundary_edge_buffer_map, + dst_indices=comm_plan.boundary_vertex_idx, + ) + + return grad_input, None + + +class CommPlan_ScatterFunction(Function): + @staticmethod + def forward( + ctx, + local_send_tensor: torch.Tensor, + comm_plan: NCCLGraphCommPlan, + ) -> torch.Tensor: + """ + Forward pass for distributed scatter + + Args: + ctx (torch.autograd.FunctionContext): Context object + local_send_tensor (torch.Tensor): Local send tensor + comm_plan (NCCLGraphCommPlan): Communication plan + """ + assert ( + len(local_send_tensor.shape) == 3 + ), "Local send tensor must be of shape (batch_size, num_rows, num_features)" + ctx.comm_plan = comm_plan + + num_features = local_send_tensor.shape[-1] + num_batches = local_send_tensor.shape[0] + + output_tensor = torch.zeros( + num_batches, comm_plan.num_local_vertices, num_features + ).to(local_send_tensor.device) + + output_tensor = OptimizedLocalScatterSumGather( + src=local_send_tensor, + output=output_tensor, + src_indices=comm_plan.local_edge_idx, + dst_indices=comm_plan.local_vertex_idx, + ) + + total_send_rows = sum(comm_plan.boundary_edge_splits) + + send_buf = torch.zeros( + num_batches, total_send_rows, num_features, device=local_send_tensor.device + ) + + send_buf = OptimizedLocalScatterSumGather( + src=local_send_tensor, + output=send_buf, + src_indices=comm_plan.boundary_edge_idx, + dst_indices=comm_plan.boundary_edge_buffer_map, + ) + + total_recv_rows = sum(comm_plan.boundary_vertex_splits) + recv_buffer = torch.empty( + num_batches, total_recv_rows, num_features, device=local_send_tensor.device + ) + dist.all_to_all_single( + recv_buffer, + send_buf, + output_split_sizes=comm_plan.boundary_vertex_splits, + input_split_sizes=comm_plan.boundary_edge_splits, + ) + output_tensor = OptimizedLocalScatterSumGather( + src=recv_buffer, + output=output_tensor, + src_indices=comm_plan.boundary_edge_buffer_map, + dst_indices=comm_plan.boundary_vertex_idx, + ) + + return output_tensor + + @staticmethod + def backward(ctx, grad_output): + """ + Backward pass for distributed scatter + + Args: + ctx (torch.autograd.FunctionContext): Context object + grad_output (torch.Tensor): Gradient of the output tensor + """ + comm_plan = ctx.comm_plan + num_features = grad_output.shape[-1] + num_batches = grad_output.shape[0] + device = grad_output.device + num_output_rows = comm_plan.num_local_edges + + grad_input = torch.zeros( + num_batches, num_output_rows, num_features, device=device + ) + + grad_input = OptimizedLocalScatterGather( + src=grad_output, + src_indices=comm_plan.local_vertex_idx, + dst_indices=comm_plan.local_edge_idx, + output=grad_input, + ) + + num_send_rows = sum(comm_plan.boundary_vertex_splits) + send_buf_locs = torch.arange(num_send_rows, device=device) + send_buf = torch.zeros(num_batches, num_send_rows, num_features, device=device) + send_buf = OptimizedLocalScatterGather( + src=grad_output, + src_indices=comm_plan.boundary_vertex_idx, + dst_indices=send_buf_locs, + output=send_buf, + ) + total_recv_rows = sum(comm_plan.boundary_edge_splits) + recv_buffer = torch.empty( + num_batches, total_recv_rows, num_features, device=device + ) + dist.all_to_all_single( + recv_buffer, + send_buf, + output_split_sizes=comm_plan.boundary_edge_splits, + input_split_sizes=comm_plan.boundary_vertex_splits, + ) + + grad_input = OptimizedLocalScatterGather( + src=recv_buffer, + src_indices=comm_plan.boundary_edge_idx, + dst_indices=comm_plan.boundary_edge_buffer_map, + output=grad_input, + ) + + return grad_input, None + + +class GatherFunction(Function): + @staticmethod + def forward( + ctx, + local_send_tensor: torch.Tensor, + indices: torch.LongTensor, + # vertex_ranks: torch.Tensor, + edge_rank_loc: torch.Tensor, + edge_dest_ranks: torch.Tensor, + rank: int, + world_size: int, + ): + num_local_input_rows = local_send_tensor.shape[1] + + ctx.save_for_backward( + indices, + edge_rank_loc, + edge_dest_ranks, + torch.tensor(num_local_input_rows), + torch.tensor(rank), + torch.tensor(world_size), + ) + + # Since NCCL is two-sided, we need to push from local rank and pull from + # remote rank to get the global gather + + # TODO: One possible optmization is cache all these calculations + # and only do the gather when the cache is invalidated. Essentially + # if we are working with static graphs, the indices and distribution pattern + # will not change and we can cache the communication pattern. - S.Z + + # We can also pre-compute this on the data ingestion side. Might + # be worth looking to some kind of cached communication pattern store + # that can be passed to the communicator. - S.Z + + batch_size = 1 + num_features = local_send_tensor.shape[2] + + local_slice_mask = edge_rank_loc == rank + + num_local_output_rows = int(local_slice_mask.sum().item()) + + recv_tensor = torch.zeros(batch_size, num_local_output_rows, num_features).to( + local_send_tensor.device + ) + + local_indices_slice = indices[local_slice_mask.unsqueeze(0)] + local_rank_mapping = edge_rank_loc[local_slice_mask] + local_recv_tensor = edge_dest_ranks[local_slice_mask] + + # assert torch.all(local_recv_tensor == rank), local_recv_tensor + + local_indices = local_indices_slice % local_send_tensor.shape[1] + + needs_comm = (local_recv_tensor != rank).any() + + recv_tensor = OptimizedRankLocalMaskedGather( + local_send_tensor, + local_indices, + local_rank_mapping, + recv_tensor, + rank, + ) + + if needs_comm: + + recv_tensor = _nccl_alltoall_v( + local_send_tensor=local_send_tensor, + local_recv_tensor=recv_tensor, + indices=indices, + local_rank_mapping=local_recv_tensor, + edge_rank_loc=edge_rank_loc, + src_rank_loc=edge_dest_ranks, + rank=rank, + world_size=world_size, + cache=cache, + ) + + return recv_tensor + + @staticmethod + def backward(ctx, grad_output): + # We need to switch the send and recv ranks + ( + indices, + recv_ranks, + send_ranks, + # vertices_per_rank, + num_local_input_rows, + rank, + world_size, + ) = ctx.saved_tensors + + num_local_output_rows = num_local_input_rows.item() + rank = rank.item() + world_size = world_size.item() + send_tensor = grad_output + + # Now it's a scatter operation + num_features = send_tensor.shape[-1] + device = send_tensor.device + local_rank_output = torch.zeros(1, num_local_output_rows, num_features).to( + device + ) + + indices = indices.view(-1) + local_slice_mask = recv_ranks == rank + local_indices_slice = indices[local_slice_mask] + local_dest_ranks = send_ranks[local_slice_mask] + + local_rank_output = RankLocalMaskedScatter( + send_tensor, + local_rank_output, + local_indices_slice, + local_dest_ranks, + rank, + ) + + if cache is not None: + local_comm_mask = cache.scatter_local_comm_mask + else: + local_comm_mask = local_dest_ranks != rank + + send_buffer_dict = {} + if torch.any(local_comm_mask): + # These rows need to be sent to other ranks + # First aggregate these into a single buffer + + if cache is not None: + num_remote_rows = cache.scatter_num_remote_rows + remapped_ranks = cache.scatter_local_remapped_ranks + renumbered_indices = cache.scatter_renumbered_indices + receiving_ranks = cache.scatter_remote_send_to_ranks + + else: + + local_comm_indices = local_indices_slice[local_comm_mask] + local_remote_dest_mappings = local_dest_ranks[local_comm_mask] + + renumbered_indices, unique_indices, remapped_ranks = ( + RankLocalRenumberingWithMapping( + local_comm_indices, local_remote_dest_mappings + ) + ) + receiving_ranks = torch.unique(local_dest_ranks[local_comm_mask]) + num_remote_rows = len(unique_indices) + + buffer = torch.zeros(1, num_remote_rows, num_features).to(device) + buffer.scatter_add_( + 1, + renumbered_indices.view(1, -1, 1).expand(1, -1, num_features), + send_tensor[:, local_comm_mask, :], + ) + + for _recv_rank in receiving_ranks: + _recv_indices = remapped_ranks == _recv_rank + send_buffer_dict[_recv_rank.item()] = buffer[:, _recv_indices, :] + + # Now we need to receive the data from the remote ranks + + recv_buffer_dict = {} + + recv_placement = {} + + if cache is not None: + recv_placement = cache.scatter_recv_local_placement + + # Allocate the receive buffers for the communication based on the + # size of the recv_placement indices. + for key, unique_send_indices in recv_placement.items(): + num_elements = unique_send_indices.shape[0] + recv_buffer_dict[key] = torch.zeros(1, num_elements, num_features).to( + device + ) + else: + send_to_rank = send_ranks # Pedantic variable name change + all_comm_mask = send_to_rank != recv_ranks + reciever_mask = send_to_rank == rank + receive_from_remote = all_comm_mask & reciever_mask + + if torch.any(receive_from_remote): + receive_from_ranks = recv_ranks[receive_from_remote] + + for _sender in range(world_size): + if _sender == rank: + continue + if torch.any(receive_from_ranks == _sender): + _send_mask = (recv_ranks == _sender) & receive_from_remote + _send_indices = indices[_send_mask] % num_local_output_rows + # TODO: This is brittle, look into a better way to do this - S.Z + + unique_send_indices = torch.unique(_send_indices) + num_elements = unique_send_indices.shape[0] + recv_buffer_dict[_sender] = torch.zeros( + 1, num_elements, num_features + ).cuda() + recv_placement[_sender] = unique_send_indices + + recv_buffer_dict = _nccl_alltoallv_with_dict( + send_buffer_dict, recv_buffer_dict, rank, world_size + ) + for key, recv_buffer in recv_buffer_dict.items(): + local_rank_output.scatter_add_( + 1, + recv_placement[key].view(1, -1, 1).expand(1, -1, num_features), + recv_buffer, + ) + + send_tensor_grad = local_rank_output + indices_grad = None + send_ranks_grad = None + recv_ranks_grad = None + rank_grad = None + world_size_grad = None + cache_grad = None + + return ( + send_tensor_grad, + indices_grad, + send_ranks_grad, + recv_ranks_grad, + rank_grad, + world_size_grad, + cache_grad, + ) + + +class ScatterFunction(Function): + @staticmethod + def forward( + ctx, + send_tensor: torch.Tensor, + indices: torch.Tensor, + edge_src_ranks: torch.Tensor, + edge_dest_ranks: torch.Tensor, + num_local_output_rows: int, + rank: int, + world_size: int, + ) -> torch.Tensor: + + ctx.save_for_backward( + indices, + edge_src_ranks, + edge_dest_ranks, + torch.tensor(num_local_output_rows), + torch.tensor(rank), + torch.tensor(world_size), + ) + use_cache = scatter_cache is not None + if use_cache: + ctx.scatter_cache = scatter_cache + ctx.has_cache = True + else: + ctx.has_cache = False + + num_features = send_tensor.shape[-1] + device = send_tensor.device + + local_rank_output = torch.zeros(1, num_local_output_rows, num_features).to( + device + ) + + indices = indices.view(-1) + + local_edge_mask = edge_src_ranks == rank + + local_indices_slice = indices[local_edge_mask] + local_dest_ranks = edge_dest_ranks[local_edge_mask] + + local_rank_output = RankLocalMaskedScatter( + send_tensor, + local_rank_output, + local_indices_slice, + local_dest_ranks, + rank, + ) + + if use_cache: + local_comm_mask = scatter_cache.scatter_local_comm_mask + else: + local_comm_mask = local_dest_ranks != rank + + all_comm_mask = edge_src_ranks != edge_dest_ranks + reciever_mask = edge_dest_ranks == rank + receive_from_remote_mask = all_comm_mask & reciever_mask + + send_buffer_dict = {} + + if torch.any(local_comm_mask): + + if use_cache: + num_remote_rows = scatter_cache.scatter_num_remote_rows + remapped_ranks = scatter_cache.scatter_local_remapped_ranks + renumbered_indices = scatter_cache.scatter_local_renumbered_indices + receving_ranks = scatter_cache.scatter_remote_send_to_ranks + + else: + # These rows need to be sent to other ranks + # First aggregate these into a single buffer + local_comm_indices = local_indices_slice[local_comm_mask] + local_remote_dest_mappings = local_dest_ranks[local_comm_mask] + # TODO: This is very slow, look into a better way to do this - S.Z + # Uncached is slow, should look into augmenting torch functions + # to speed this up - S.Z + renumbered_indices, unique_indices, remapped_ranks = ( + RankLocalRenumberingWithMapping( + local_comm_indices, local_remote_dest_mappings + ) + ) + num_remote_rows = len(unique_indices) + receving_ranks = torch.unique(local_dest_ranks[local_comm_mask]) + + buffer = torch.zeros(1, num_remote_rows, num_features).to(device) + buffer.scatter_add_( + 1, + renumbered_indices.view(1, -1, 1).expand(1, -1, num_features), + send_tensor[:, local_comm_mask, :], + ) + + for _recv_rank in receving_ranks: + _recv_indices = remapped_ranks == _recv_rank + send_buffer_dict[_recv_rank.item()] = buffer[:, _recv_indices, :] + + recv_buffer_dict = {} + recv_placement = {} + if use_cache: + recv_placement = scatter_cache.scatter_recv_local_placement + else: + recv_placement = _get_local_unique_recv_placement( + indices, + edge_src_ranks, + receive_from_remote_mask, + num_local_output_rows, + rank, + world_size, + ) + + # Allocate the receive buffers for the communication based on the + # size of the recv_placement indices. + for key, unique_send_indices in recv_placement.items(): + num_elements = unique_send_indices.shape[0] + recv_buffer_dict[key] = torch.zeros(1, num_elements, num_features).to( + device + ) + recv_buffer_dict = _nccl_alltoallv_with_dict( + send_buffer_dict, recv_buffer_dict, rank, world_size + ) + for key, recv_buffer in recv_buffer_dict.items(): + local_rank_output.scatter_add_( + 1, + recv_placement[key].view(1, -1, 1).expand(1, -1, num_features), + recv_buffer, + ) + return local_rank_output + + @staticmethod + def backward(ctx, grad_output): + # We need to switch the send and recv ranks + indices, recv_ranks, send_ranks, num_input_rows, rank, world_size = ( + ctx.saved_tensors + ) + + local_mask = recv_ranks == rank + if ctx.has_cache: + cache: NCCLScatterCache = ctx.scatter_cache + num_local_output_rows = cache.gather_num_output_rows + + else: + rank = int(rank.item()) + world_size = int(world_size.item()) + + indices = indices.view(1, -1) + + # Now it's a gather operation + + num_local_output_rows = int(local_mask.sum().item()) + + batch_size = 1 + num_features = grad_output.shape[2] + + recv_tensor = torch.zeros(batch_size, num_local_output_rows, num_features).to( + grad_output.device + ) + + local_indices_slice = indices[0][local_mask] + local_rank_mapping = send_ranks[local_mask] + + local_indices = local_indices_slice % grad_output.shape[1] + + if len(local_indices_slice) > 0: + + recv_tensor[:, local_rank_mapping == rank, :] = RankLocalMaskedGather( + grad_output, local_indices, local_rank_mapping, rank + ) + + recv_tensor = _nccl_alltoall_v( + local_send_tensor=grad_output, + local_recv_tensor=recv_tensor, + indices=indices, + local_rank_mapping=local_rank_mapping, + edge_rank_loc=send_ranks, + src_rank_loc=recv_ranks, + rank=rank, + world_size=world_size, + cache=cache, + ) + + # NOTE: even if the inputs are non-tensors, the number of backward outputs + # must be the same as the number of inputs. + send_tensor_grad = recv_tensor + indices_grad = None + send_ranks_grad = None + recv_ranks_grad = None + num_local_output_rows_grad = None + rank_grad = None + world_size_grad = None + scatter_cache_grad = None + + return ( + send_tensor_grad, + indices_grad, + send_ranks_grad, + recv_ranks_grad, + num_local_output_rows_grad, + rank_grad, + world_size_grad, + scatter_cache_grad, + ) diff --git a/DGraph/distributed/nccl/alltoallv_impl.py b/DGraph/distributed/nccl/alltoallv_impl.py index 060c390..d549cb9 100644 --- a/DGraph/distributed/nccl/alltoallv_impl.py +++ b/DGraph/distributed/nccl/alltoallv_impl.py @@ -159,3 +159,23 @@ def _nccl_alltoallv_with_dict(send_buffer_dict, recv_buffer_dict, rank, world_si for key, recv_buffer in recv_buffer_dict.items(): recv_buffer_dict[key] = recv_buffer.float() return recv_buffer_dict + + +def torch_alltoallv_with_comm_map(contiguous_send_tensor: torch.Tensor, + contiguous_recv_tensor: torch.Tensor, + send_comm_map: torch.Tensor, + recv_comm_map: torch.Tensor, + rank: int, + world_size: int): + assert len(send_comm_map) == world_size, "Send comm map should be of size world_size" + assert len(recv_comm_map) == world_size, "Recv comm map should be of size world_size" + + send_sizes = send_comm_map.tolist() + recv_sizes = recv_comm_map.tolist() + + send_list = list(torch.split(contiguous_send_tensor, send_sizes, dim=1)) + recv_list = list(torch.split(contiguous_recv_tensor, recv_sizes, dim=1)) + + dist.all_to_all(recv_list, send_list) + return recv_list + diff --git a/experiments/Benchmarks/README.md b/experiments/Benchmarks/README.md index cabbb68..e60f14c 100644 --- a/experiments/Benchmarks/README.md +++ b/experiments/Benchmarks/README.md @@ -34,7 +34,8 @@ class ScatterGraphData: data_rank_mapping: torch.Tensor # Where each data is located edge_rank_placement: torch.Tensor # Where each edge is located edge_dst_rank: torch.Tensor # Rank of the destination vertex of each edge - edge_indices: torch.Tensor # Vertex index of the destination vertex of each num_local_vertices: int # Number of vertices on each rank + edge_indices: torch.Tensor # Vertex index of the destination vertex of each edge + num_local_vertices: int # Number of vertices on each rank ``` *** New communication patterns can be added to the benchmarking code by creating new instances of these dataclasses. *** diff --git a/experiments/OGB-LSC/CacheGenerator.py b/experiments/OGB-LSC/CacheGenerator.py new file mode 100644 index 0000000..e99ad06 --- /dev/null +++ b/experiments/OGB-LSC/CacheGenerator.py @@ -0,0 +1,184 @@ +# Copyright (c) 2014-2025, Lawrence Livermore National Security, LLC. +# Produced at the Lawrence Livermore National Laboratory. +# Written by the LBANN Research Team (B. Van Essen, et al.) listed in +# the CONTRIBUTORS file. See the top-level LICENSE file for details. +# +# LLNL-CODE-697807. +# All rights reserved. +# +# This file is part of LBANN: Livermore Big Artificial Neural Network +# Toolkit. For details, see http://software.llnl.gov/LBANN or +# https://github.com/LBANN and https://github.com/LLNL/LBANN. +# +# SPDX-License-Identifier: (Apache-2.0) + +import torch + +import os.path as osp +from DGraph.distributed.nccl._nccl_cache import ( + NCCLGatherCacheGenerator, + NCCLScatterCacheGenerator, +) + + +def get_cache( + src_gather_cache, + dest_gather_cache, + dest_scatter_cache, + src_gather_cache_file, + dest_gather_cache_file, + dest_scatter_cache_file, + rank, + world_size, + src_indices, + dest_indices, + edge_location, + src_data_mappings, + dest_data_mappings, + num_src_rows, + num_dest_rows, +): + """ """ + if src_gather_cache is None: + + _src_gather_cache = NCCLGatherCacheGenerator( + indices=src_indices, + edge_placement=edge_location, + edge_dest_ranks=src_data_mappings, + num_input_rows=num_src_rows, + rank=rank, + world_size=world_size, + ) + + torch.save(_src_gather_cache, src_gather_cache_file) + else: + _src_gather_cache = src_gather_cache + + if dest_scatter_cache is None: + _dest_scatter_cache = NCCLScatterCacheGenerator( + indices=src_indices, + edge_placement=edge_location, + edge_dest_ranks=src_data_mappings, + num_output_rows=num_src_rows, + rank=rank, + world_size=world_size, + ) + + torch.save(_dest_scatter_cache, dest_scatter_cache_file) + else: + _dest_scatter_cache = dest_scatter_cache + + if dest_gather_cache is None: + _dest_gather_cache = NCCLGatherCacheGenerator( + indices=dest_indices, + edge_placement=edge_location, + edge_dest_ranks=dest_data_mappings, + num_input_rows=num_dest_rows, + rank=rank, + world_size=world_size, + ) + + torch.save(_dest_gather_cache, dest_gather_cache_file) + else: + _dest_gather_cache = dest_gather_cache + + # Unit tests + + return _src_gather_cache, _dest_scatter_cache, _dest_gather_cache + + +if __name__ == "__main__": + from fire import Fire + from functools import partial + from config import SyntheticDatasetConfig + + # Use this script to generate the caches prior to running the main training script + # This is useful because cache generation can take a long time and could cause issues + # with timeouts on some systems. + + def main(dataset): + assert dataset in ["synthetic", "mag240m"] + if dataset == "synthetic": + from synthetic.synthetic_dataset import HeterogeneousDataset as Dataset + + synthetic_config = SyntheticDatasetConfig() + graph_dataset = partial( + Dataset, + num_papers=synthetic_config.num_papers, + num_authors=synthetic_config.num_authors, + num_institutions=synthetic_config.num_institutions, + num_features=synthetic_config.num_features, + num_classes=synthetic_config.num_classes, + ) + elif dataset == "mag240m": + from mag240m.DGraph_MAG240M import DGraph_MAG240M as Dataset + + graph_dataset = partial(Dataset, data_dir="data/MAG240M") + + rank = 0 + world_size = 4 + COMM = type( + "dummy_comm", + (object,), + {"get_rank": lambda self: rank, "get_world_size": lambda self: world_size}, + ) + comm = COMM() + + dataset = graph_dataset( + comm=comm, + ) + + dataset = dataset.add_batch_dimension() + dataset = dataset.to("cpu") + + xs, edge_indices, edge_types, rank_mappings = dataset[0] + + # for simulated_rank in range(world_size): + simulated_rank = 0 + for simulated_rank in [0, 1]: + rel = 0 + + for edge_index, edge_type, rank_mapping in zip( + edge_indices, edge_types, rank_mappings + ): + if rel != 3: + rel += 1 + continue + print(f"Edge index shape: {edge_index.shape}") + print(f"Edge type shape: {edge_type}") + print(f"Rank mapping shape: {rank_mapping[0].shape}") + print(f"Rank mapping shape: {rank_mapping[1].shape}") + + get_cache( + src_gather_cache=None, + dest_gather_cache=None, + dest_scatter_cache=None, + src_gather_cache_file=f"test_cache/synthetic_src_gather_cache_{rel}_{simulated_rank}_{world_size}.pt", + dest_gather_cache_file=f"test_cache/synthetic_dest_gather_cache_{rel}_{simulated_rank}_{world_size}.pt", + dest_scatter_cache_file=f"test_cache/synthetic_dest_scatter_cache_{rel}_{simulated_rank}_{world_size}.pt", + rank=simulated_rank, + world_size=world_size, + src_indices=edge_index[:, 0], + dest_indices=edge_index[:, 1], + edge_location=rank_mapping[0], + src_data_mappings=rank_mapping[0], + dest_data_mappings=rank_mapping[1], + num_src_rows=xs[edge_type[0]].shape[1], + num_dest_rows=xs[edge_type[1]].shape[1], + ) + + rel += 1 + rel = 3 + synthetic_scatter_cache_1 = torch.load( + f"test_cache/synthetic_dest_scatter_cache_{rel}_1_{world_size}.pt", + weights_only=False, + ) + synthetic_scatter_cache_0 = torch.load( + f"test_cache/synthetic_dest_scatter_cache_{rel}_0_{world_size}.pt", + weights_only=False, + ) + + print(synthetic_scatter_cache_1.scatter_recv_local_placement) + print(synthetic_scatter_cache_0.scatter_recv_local_placement) + + Fire(main) diff --git a/experiments/OGB-LSC/README.md b/experiments/OGB-LSC/README.md new file mode 100644 index 0000000..4a35ad5 --- /dev/null +++ b/experiments/OGB-LSC/README.md @@ -0,0 +1,26 @@ +# Directed Heterogeneous Graphs on DGraph + +`DGraph` supports arbitrary graph types, GNNs, and structures for distributed training. This example shows how to use `DGraph` to train a Relational Graph Attention Network ([RGAT](https://arxiv.org/abs/1703.06103)) on the [OGB-LSC MAG240M](https://ogb.stanford.edu/docs/lsc/mag240m/) dataset, which is a large-scale heterogeneous graph with three types of nodes (paper, author, institution) and three types of edges (paper->paper, paper->author, author->institution). + +## Requirements + +- fire + +## Data preparation +The dataset is fairly large (over 100GB). Please follow the instructions in the `mag240m` folder to download and preprocess the dataset. + +## Training +To train RGAT on a synthetic dataset, run the following command: + +```bash +torchrun-hpc -N -n main.py \ +--dataset synthetic --num_papers \ +--num_authors --num_institutions -n main.py --dataset mag240m \ +--data-path +``` diff --git a/experiments/OGB-LSC/RGAT.py b/experiments/OGB-LSC/RGAT.py new file mode 100644 index 0000000..778616e --- /dev/null +++ b/experiments/OGB-LSC/RGAT.py @@ -0,0 +1,466 @@ +# Copyright (c) 2014-2025, Lawrence Livermore National Security, LLC. +# Produced at the Lawrence Livermore National Laboratory. +# Written by the LBANN Research Team (B. Van Essen, et al.) listed in +# the CONTRIBUTORS file. See the top-level LICENSE file for details. +# +# LLNL-CODE-697807. +# All rights reserved. +# +# This file is part of LBANN: Livermore Big Artificial Neural Network +# Toolkit. For details, see http://software.llnl.gov/LBANN or +# https://github.com/LBANN and https://github.com/LLNL/LBANN. +# +# SPDX-License-Identifier: (Apache-2.0) + +import torch +import torch.nn as nn +import torch.distributed as dist +from distributed_layers import DistributedBatchNorm1D +import os.path as osp +from CacheGenerator import get_cache +import os +from typing import Any, Optional, overload +from DGraph.distributed.nccl import ( + NCCLBackendEngine, + NCCLGraphCommPlan, + NCCLEdgeConditionedGraphCommPlan, +) + + +class ConvLayer(nn.Module): + def __init__(self, in_channels, out_channels): + super(ConvLayer, self).__init__() + self.conv = nn.Linear(in_channels, out_channels) + self.act = nn.ReLU(inplace=True) + + def forward(self, x): + x = self.conv(x) + x = self.act(x) + return x + + +class CommAwareGAT(nn.Module): + def __init__( + self, + in_channels, + out_channels, + comm, + heads=1, + bias=True, + residual=False, + hetero=False, + ): + super(CommAwareGAT, self).__init__() + self.conv1 = nn.Linear(in_channels, out_channels, bias=False) + self.comm = comm + self.project_message = nn.Linear(2 * out_channels, 1) + self.leaky_relu = nn.LeakyReLU(0.2) + self.residual = residual + self.heads = heads + self.hetero = hetero + if self.residual: + self.res_net = nn.Linear(in_channels, out_channels, bias=False) + if bias: + self.bias = nn.Parameter(torch.empty(out_channels)) + nn.init.zeros_(self.bias) + else: + self.register_parameter("bias", None) + + @overload + def forward( + self, + x: torch.Tensor, + comm_plan: NCCLEdgeConditionedGraphCommPlan, + *, + x_j: Optional[torch.Tensor] = None, + ): ... + + @overload + def forward( + self, + x: torch.Tensor, + *, + edge_index: Any, + rank_mapping: Any, + x_j: Optional[torch.Tensor] = None, + src_gather_cache: Optional[Any] = None, + dest_gather_cache: Optional[Any] = None, + dest_scatter_cache: Optional[Any] = None, + ): ... + + def forward( + self, + x, + comm_plan=None, + *, + edge_index=None, + rank_mapping=None, + x_j=None, + src_gather_cache=None, + dest_gather_cache=None, + dest_scatter_cache=None, + ): + """Forward method that can use either a communication plan or COO format + + Args: + x: Node features tensor + comm_plan: Communication plan object (if available) + edge_index: Edge index tensor in COO format + rank_mapping: Rank mapping tensors + x_j: Optional source node features tensor (for hetero graphs) + src_gather_cache: Optional cache for source gather communication + dest_gather_cache: Optional cache for destination gather communication + dest_scatter_cache: Optional cache for destination scatter communication + + Returns: + out: Output node features tensor + """ + if comm_plan is not None: + return self._forward_comm_plan(x, comm_plan, x_j=x_j) + + return self._forward_coo( + x, + edge_index=edge_index, + rank_mapping=rank_mapping, + x_j=x_j, + src_gather_cache=src_gather_cache, + dest_gather_cache=dest_gather_cache, + dest_scatter_cache=dest_scatter_cache, + ) + + def _process_messages( + self, + h, + h_j, + ): + messages = torch.cat([h, h_j], dim=-1) + edge_scores = self.leaky_relu(self.project_message(messages)) + numerator = torch.exp(edge_scores) + return numerator + + def _calc_attention_messages( + self, + neighbor_features, + numerator, + denominator, + ): + alpha_ij = numerator / (denominator + 1e-16) + attention_messages = neighbor_features * alpha_ij + return attention_messages + + def _apply_res_and_bias(self, out, x): + if self.residual: + out = out + self.res_net(x) + if self.bias is not None: + out = out + self.bias + return out + + def _forward_comm_plan( + self, x, comm_plan: NCCLEdgeConditionedGraphCommPlan, x_j=None + ): + h = self.conv1(x) + + source_graph_plan = comm_plan.source_graph_plan + if self.hetero: + assert x_j is not None + h_j = self.conv1(x_j) + assert comm_plan.dest_graph_plan is not None + dest_graph_plan = comm_plan.dest_graph_plan + else: + h_j = h + dest_graph_plan = source_graph_plan + + assert isinstance(self.comm.__backend_engine, NCCLBackendEngine) + + h_i = self.comm.__backend_engine.gather(h, comm_plan=source_graph_plan) + + h_j = self.comm.__backend_engine.gather(h_j, comm_plan=dest_graph_plan) + + numerator = self._process_messages(h_i, h_j) + + denominator = self.comm.__backend_engine.scatter( + numerator, comm_plan=source_graph_plan + ) + + denominator = self.comm.__backend_engine.gather( + denominator, comm_plan=dest_graph_plan + ) + + attention_messages = self._calc_attention_messages(h_j, numerator, denominator) + + out = self.comm.__backend_engine.scatter( + attention_messages, comm_plan=source_graph_plan + ) + out = self._apply_res_and_bias(out, x) + + return out + + def _forward_coo( + self, + x, + edge_index, + rank_mapping, + x_j=None, + src_gather_cache=None, + dest_gather_cache=None, + dest_scatter_cache=None, + ): + h = self.conv1(x) + if self.hetero: + assert x_j is not None + h_j = self.conv1(x_j) + else: + h_j = h + + _src_indices = edge_index[:, 0, :] + _dst_indices = edge_index[:, 1, :] + _src_rank_mappings = torch.cat( + [rank_mapping[0].unsqueeze(0), rank_mapping[0].unsqueeze(0)], dim=0 + ) + _dst_rank_mappings = torch.cat( + [rank_mapping[0].unsqueeze(0), rank_mapping[1].unsqueeze(0)], dim=0 + ) + + h_i = self.comm.gather( + h, _dst_indices, _dst_rank_mappings, cache=dest_gather_cache + ) + + h_j = self.comm.gather( + h_j, _src_indices, _src_rank_mappings, cache=src_gather_cache + ) + + numerator = self._process_messages(h_i, h_j) + + denominator = self.comm.scatter( + numerator, + _dst_indices, + _dst_rank_mappings, + h.size(1), + cache=dest_scatter_cache, + ) + + denominator = self.comm.gather( + denominator, _src_indices, _src_rank_mappings, cache=dest_gather_cache + ) + + attention_messages = self._calc_attention_messages(h_j, numerator, denominator) + + out = self.comm.scatter( + attention_messages, + _dst_indices, + _dst_rank_mappings, + h.size(1), + cache=dest_scatter_cache, + ) + + out = self._apply_res_and_bias(out, x) + + return out + + +class CommAwareRGAT(nn.Module): + def __init__( + self, + in_channels, + out_channels, + hidden_channels, + num_relations, + num_layers, + heads, + comm, + dropout=0.5, + use_cache=True, + cache_file_path="rgat_cache", + ): + super(CommAwareRGAT, self).__init__() + self.layers = nn.ModuleList() + self.bn_layers = nn.ModuleList() + self.skip_layers = nn.ModuleList() + self.num_layers = num_layers + self.dropout = dropout + self.comm = comm + self.use_cache = use_cache + relation_specific_convs = [] + + for _ in range(num_relations): + relation_specific_convs.append( + CommAwareGAT( + in_channels, + hidden_channels, + heads=heads, + bias=True, + residual=True, + comm=comm, + hetero=True, + ) + ) + self.layers.append(nn.ModuleList(relation_specific_convs)) + + for _ in range(num_layers - 1): + relation_specific_convs = [] + for _ in range(num_relations): + relation_specific_convs.append( + CommAwareGAT( + hidden_channels, + hidden_channels, + heads=heads, + bias=True, + residual=True, + comm=comm, + hetero=True, + ) + ) + self.layers.append(nn.ModuleList(relation_specific_convs)) + + for _ in range(num_layers): + self.bn_layers.append(DistributedBatchNorm1D(hidden_channels)) + + self.skip_layers.append(nn.Linear(in_channels, hidden_channels)) + for _ in range(num_layers - 1): + self.skip_layers.append(nn.Linear(hidden_channels, hidden_channels)) + + self.mlp = nn.Sequential( + nn.Linear(hidden_channels, hidden_channels), + DistributedBatchNorm1D(hidden_channels), + nn.ReLU(inplace=True), + nn.Dropout(dropout), + nn.Linear(hidden_channels, out_channels), + ) + self.num_relations = num_relations + self._setup_caches(cache_file_path) + + def _setup_caches(self, cache_file_path): + num_relations = self.num_relations + comm = self.comm + # Caching for RGAT is a little bit tricky. There are three types of communication + # 1. Source gather (gathering source node features from source ranks) + # 2. Destination gather (gathering destination node features from destination ranks) + # 3. Destination scatter (scattering the messages to destination ranks) + # That gets repeated for each relation type. + # So we will have 3 * num_relations cache files + + self.src_gather_cache_files = [ + ( + f"{cache_file_path}_src_gather_cache_rel_{rel}_rank" + + f"_{comm.get_world_size()}_{comm.get_rank()}.pt" + ) + for rel in range(num_relations) + ] + + self.dest_scatter_cache_files = [ + ( + f"{cache_file_path}_dest_scatter_cache_rel_{rel}_rank" + + f"_{comm.get_world_size()}_{comm.get_rank()}.pt" + ) + for rel in range(num_relations) + ] + self.dest_gather_cache_files = [ + ( + f"{cache_file_path}_dest_gather_cache_rel_{rel}_rank" + + f"_{comm.get_world_size()}_{comm.get_rank()}.pt" + ) + for rel in range(num_relations) + ] + self.src_gather_caches = [] + self.dest_scatter_caches = [] + self.dest_gather_caches = [] + + if self.use_cache: + for caches in zip( + self.src_gather_cache_files, + self.dest_scatter_cache_files, + self.dest_gather_cache_files, + ): + ( + src_gather_cache_file, + dest_scatter_cache_file, + dest_gather_cache_file, + ) = caches + if ( + osp.exists(src_gather_cache_file) + and osp.exists(dest_scatter_cache_file) + and osp.exists(dest_gather_cache_file) + ): + _src_gather_cache = torch.load( + src_gather_cache_file, weights_only=False + ) + _dest_scatter_cache = torch.load( + dest_scatter_cache_file, weights_only=False + ) + _dest_gather_cache = torch.load( + dest_gather_cache_file, weights_only=False + ) + self.src_gather_caches.append(_src_gather_cache) + self.dest_scatter_caches.append(_dest_scatter_cache) + self.dest_gather_caches.append(_dest_gather_cache) + else: + self.src_gather_caches.append(None) + self.dest_scatter_caches.append(None) + self.dest_gather_caches.append(None) + + def forward(self, xs, adjts, edge_types, rank_mappings): + assert len(adjts) == len(edge_types) + assert len(adjts) == self.num_relations + + outs = xs + + for i in range(self.num_layers): + temp_outs = [self.skip_layers[i](outs[feat]) for feat in range(len(outs))] + for j, (edge_index, edge_type, rank_mapping) in enumerate( + zip(adjts, edge_types, rank_mappings) + ): + + src_edge_type, dst_edge_type = edge_type + if self.use_cache: + caches = get_cache( + src_gather_cache=self.src_gather_caches[j], + dest_gather_cache=self.dest_gather_caches[j], + dest_scatter_cache=self.dest_scatter_caches[j], + src_gather_cache_file=self.src_gather_cache_files[j], + dest_scatter_cache_file=self.dest_scatter_cache_files[j], + dest_gather_cache_file=self.dest_gather_cache_files[j], + rank=self.comm.get_rank(), + world_size=self.comm.get_world_size(), + src_indices=edge_index[:, 0, :], + dest_indices=edge_index[:, 1, :], + edge_location=rank_mapping[0], + src_data_mappings=rank_mapping[0], + dest_data_mappings=rank_mapping[1], + num_src_rows=outs[src_edge_type].size(1), + num_dest_rows=outs[dst_edge_type].size(1), + ) + src_gather_cache, dest_scatter_cache, dest_gather_cache = caches + else: + src_gather_cache = None + dest_scatter_cache = None + dest_gather_cache = None + + temp_outs[dst_edge_type] += self.layers[i][j]( + outs[dst_edge_type], + edge_index, + rank_mapping, + x_j=outs[src_edge_type], + src_gather_cache=src_gather_cache, + dest_gather_cache=dest_gather_cache, + dest_scatter_cache=dest_scatter_cache, + ) + outs = [ + self.bn_layers[i](temp_outs[feat]) for feat in range(len(temp_outs)) + ] + outs = [torch.relu(outs[feat]) for feat in range(len(outs))] + outs = [ + torch.dropout(outs[feat], p=self.dropout, train=self.training) + for feat in range(len(outs)) + ] + + dummy_prameters_use = bool(int(os.getenv("RGAT_DUMMY_ALL_PARAMS_USE", "0"))) + if dummy_prameters_use: + # Dummy operation to touch all outs to avoid DDP's 'unused parameters' + dummy = torch.zeros(1, device=outs[0].device, dtype=outs[0].dtype) + for t in outs: + dummy = dummy + ( + t[0].sum() * 0.0 + ) # zero-valued scalar that depends on t + outs[0][0] = outs[0][0] + dummy + + return self.mlp(outs[0]) diff --git a/experiments/OGB-LSC/Trainer.py b/experiments/OGB-LSC/Trainer.py new file mode 100644 index 0000000..b3fea9a --- /dev/null +++ b/experiments/OGB-LSC/Trainer.py @@ -0,0 +1,130 @@ +# Copyright (c) 2014-2025, Lawrence Livermore National Security, LLC. +# Produced at the Lawrence Livermore National Laboratory. +# Written by the LBANN Research Team (B. Van Essen, et al.) listed in +# the CONTRIBUTORS file. See the top-level LICENSE file for details. +# +# LLNL-CODE-697807. +# All rights reserved. +# +# This file is part of LBANN: Livermore Big Artificial Neural Network +# Toolkit. For details, see http://software.llnl.gov/LBANN or +# https://github.com/LBANN and https://github.com/LLNL/LBANN. +# +# SPDX-License-Identifier: (Apache-2.0) +import torch +from RGAT import CommAwareRGAT +from config import ModelConfig, TrainingConfig +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP +from distributed_layers import GetGlobalVal +import os + + +class Trainer: + def __init__(self, dataset, comm): + self.dataset = dataset + self.comm = comm + self.model_config = ModelConfig() + self.training_config = TrainingConfig() + # TODO: We need some better way to set the device but + # difficult to do that since systems have different bindings. + # self.device = torch.device(f"cuda:{comm.get_local_rank()}") + rank = comm.get_rank() + print(f"Rank {rank} using GPU {rank % torch.cuda.device_count()}") + num_gpus = torch.cuda.device_count() + torch.cuda.set_device(rank % num_gpus) + self.device = torch.device("cuda") + self.model = CommAwareRGAT( + in_channels=self.dataset.num_features, + out_channels=self.dataset.num_classes, + num_relations=self.dataset.num_relations, + hidden_channels=self.model_config.hidden_channels, + num_layers=self.model_config.num_layers, + heads=self.model_config.heads, + comm=comm, + dropout=self.model_config.dropout, + ).to(self.device) + # Enable unused-parameter detection only if requested (reduces sync errors with moderate overhead) + ddp_find_unused = bool(int(os.getenv("RGAT_DDP_FIND_UNUSED", "0"))) + self.model = DDP( + self.model, + device_ids=[rank % num_gpus], + find_unused_parameters=ddp_find_unused, + ) + self.optimizer = torch.optim.Adam( + self.model.parameters(), lr=self.training_config.lr, weight_decay=5e-4 + ) + + def prepare_data(self): + self.dataset = self.dataset.add_batch_dimension() + self.dataset = self.dataset.to(self.device) + + def train(self): + self.model.train() + + xs, edge_index, edge_type, rank_mapping = self.dataset[0] + + # Fetch once; masks/targets are static across epochs + train_mask = self.dataset.get_mask("train") + target = self.dataset.get_target("train") + + for epoch in range(1, self.training_config.epochs + 1): + # zero grads before forward to avoid dangling reduction state + self.optimizer.zero_grad(set_to_none=True) + + out = self.model(xs, edge_index, edge_type, rank_mapping) + local_train_vertices = out[:, train_mask, :].squeeze(0) + + loss = torch.nn.functional.cross_entropy( + local_train_vertices, target, reduction="sum" + ) + local_num_targets = target.size(0) + global_num_targets = GetGlobalVal(local_num_targets) + loss = loss / global_num_targets # Average the loss + + loss.backward() + self.optimizer.step() + if self.comm.get_rank() == 0: + print(f"Epoch {epoch:03d} | loss {loss.item():.4f}") + return loss.item() + + @torch.no_grad() + def evaluate(self): + self.model.eval() + + xs, edge_index, edge_type, rank_mapping = self.dataset[0] + out = self.model(xs, edge_index, edge_type, rank_mapping) + + y_pred = out.argmax(dim=-1, keepdim=True).cpu().numpy() + train_mask = self.dataset.get_mask("train").cpu().numpy() + val_mask = self.dataset.get_mask("val").cpu().numpy() + test_mask = self.dataset.get_mask("test").cpu().numpy() + y_true_train = self.dataset.get_target("train").cpu().numpy() + y_pred_val = self.dataset.get_target("val").cpu().numpy() + y_pred_test = self.dataset.get_target("test").cpu().numpy() + + train_acc = (y_pred[train_mask] == y_true_train).sum() / int(train_mask.sum()) + # Not guaranteed to have validation or test samples on every rank + num_local_val_samples = int(val_mask.sum()) + num_local_test_samples = int(test_mask.sum()) + if num_local_val_samples == 0: + val_acc = 0.0 + else: + val_acc = (y_pred[val_mask] == y_pred_val).sum().item() + val_acc = GetGlobalVal(val_acc) + + num_global_val_samples = GetGlobalVal(num_local_val_samples) + val_acc = val_acc / int(num_global_val_samples) + + if num_local_test_samples == 0: + test_acc = 0.0 + else: + test_acc = (y_pred[test_mask] == y_pred_test).sum().item() + + test_acc = GetGlobalVal(test_acc) + num_global_test_samples = GetGlobalVal(num_local_test_samples) + test_acc = test_acc / int(num_global_test_samples) + + # All ranks should have the same accuracy values + + return train_acc, val_acc, test_acc diff --git a/experiments/OGB-LSC/config.py b/experiments/OGB-LSC/config.py new file mode 100644 index 0000000..49d8164 --- /dev/null +++ b/experiments/OGB-LSC/config.py @@ -0,0 +1,45 @@ +# Copyright (c) 2014-2025, Lawrence Livermore National Security, LLC. +# Produced at the Lawrence Livermore National Laboratory. +# Written by the LBANN Research Team (B. Van Essen, et al.) listed in +# the CONTRIBUTORS file. See the top-level LICENSE file for details. +# +# LLNL-CODE-697807. +# All rights reserved. +# +# This file is part of LBANN: Livermore Big Artificial Neural Network +# Toolkit. For details, see http://software.llnl.gov/LBANN or +# https://github.com/LBANN and https://github.com/LLNL/LBANN. +# +# SPDX-License-Identifier: (Apache-2.0) + +from dataclasses import dataclass + + +@dataclass +class ModelConfig: + hidden_channels: int = 1024 + dropout: float = 0.5 + num_layers: int = 2 + heads: int = 4 + use_cache: bool = True + # Those numbers are available in the dataset classes (synthetic or mag240m) + # num_features: int = 768 + # num_relations: int = 5 + # num_classes: int = 153 + + +@dataclass +class TrainingConfig: + epochs: int = 100 + lr: float = 0.0001 + lr_step_size: int = 25 + lr_gamma: float = 0.25 + + +@dataclass +class SyntheticDatasetConfig: + num_papers: int = 2048 + num_authors: int = 512 + num_institutions: int = 16 + num_features: int = 16 + num_classes: int = 153 diff --git a/experiments/OGB-LSC/distributed_layers.py b/experiments/OGB-LSC/distributed_layers.py new file mode 100644 index 0000000..54408b1 --- /dev/null +++ b/experiments/OGB-LSC/distributed_layers.py @@ -0,0 +1,214 @@ +# Copyright (c) 2014-2025, Lawrence Livermore National Security, LLC. +# Produced at the Lawrence Livermore National Laboratory. +# Written by the LBANN Research Team (B. Van Essen, et al.) listed in +# the CONTRIBUTORS file. See the top-level LICENSE file for details. +# +# LLNL-CODE-697807. +# All rights reserved. +# +# This file is part of LBANN: Livermore Big Artificial Neural Network +# Toolkit. For details, see http://software.llnl.gov/LBANN or +# https://github.com/LBANN and https://github.com/LLNL/LBANN. +# +# SPDX-License-Identifier: (Apache-2.0) + +import torch +from torch import nn +import torch.distributed as dist +from torch.autograd import Function +from typing import Callable + + +def _compute_bn_forward(input, learned_gamma=None, learned_beta=None): + local_sum = torch.mean(input, dim=0) + global_sum = local_sum.clone() + num_rows = torch.tensor([input.size(0)], dtype=torch.float32, device=input.device) + + global_num_rows = num_rows.clone() + + dist.all_reduce(global_num_rows, op=dist.ReduceOp.SUM) + global_mean = global_sum / global_num_rows + local_var = ((input - global_mean) ** 2).sum(dim=0) + global_var = local_var.clone() + dist.all_reduce(global_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(global_var, op=dist.ReduceOp.SUM) + global_var = global_var / global_num_rows + + x_hat = (input - global_mean) / torch.sqrt(global_var + 1e-5) + if learned_gamma is not None and learned_beta is not None: + output = x_hat * learned_gamma + learned_beta + + return output, x_hat, global_mean, global_var, global_num_rows + + +def _compute_bn_backward( + grad_output, x, x_hat, mean, var, num_rows, learned_gamma=None, learned_beta=None +): + if learned_gamma is not None and learned_beta is not None: + local_dbeta = torch.sum(grad_output, dim=0) + global_dbeta = local_dbeta.clone().unsqueeze(0) + dist.all_reduce(global_dbeta, op=dist.ReduceOp.SUM) + local_dgamma = torch.sum(grad_output * x_hat, dim=0) + global_dgamma = local_dgamma.clone().unsqueeze(0) + dist.all_reduce(global_dgamma, op=dist.ReduceOp.SUM) + dx_hat = grad_output * learned_gamma + else: + dx_hat = grad_output + global_dgamma = None + global_dbeta = None + + local_dvar = torch.sum(dx_hat * (x - mean) * -0.5 * (var + 1e-5) ** 2, dim=0) + global_dvar = local_dvar.clone() + dist.all_reduce(global_dvar, op=dist.ReduceOp.SUM) + + local_dmean = torch.sum( + dx_hat * -1 / torch.sqrt(var + 1e-5), dim=0 + ) + global_dvar * torch.mean(-2 * (x - mean), dim=0) + global_dmean = local_dmean.clone() + dist.all_reduce(global_dmean, op=dist.ReduceOp.SUM) + dx = ( + (dx_hat / torch.sqrt(var + 1e-5)) + + (global_dvar * 2 * (x - mean) / num_rows) + + (global_dmean / num_rows) + ) + return dx, global_dgamma, global_dbeta + + +class DistributedBN_with_Recompute(Function): + @staticmethod + def forward(ctx, input, learned_gamma=None, learned_beta=None): + ctx.save_for_backward(input) + ctx.learned_gamma = learned_gamma + ctx.learned_beta = learned_beta + output, _, global_mean, global_var, global_num_rows = _compute_bn_forward( + input, learned_gamma, learned_beta + ) + ctx.mean = global_mean + ctx.var = global_var + ctx.input = input + ctx.num_rows = global_num_rows + return output, global_mean, global_var + + @staticmethod + def backward(ctx, grad_output, grad_mean, grad_var): + x = ctx.input + mean = ctx.mean + var = ctx.var + # recompute x_hat to save memory + x_hat = (x - mean) / torch.sqrt(var + 1e-5) + learned_gamma = ctx.learned_gamma + learned_beta = ctx.learned_beta + num_rows = ctx.num_rows + + dx, global_dgamma, global_dbeta = _compute_bn_backward( + grad_output, x, x_hat, mean, var, num_rows, learned_gamma, learned_beta + ) + + return dx, global_dgamma, global_dbeta + + +class DistributedBN_Impl(Function): + @staticmethod + def forward(ctx, input, learned_gamma=None, learned_beta=None): + output, x_hat, global_mean, global_var, global_num_rows = _compute_bn_forward( + input, learned_gamma, learned_beta + ) + + ctx.save_for_backward(x_hat) + ctx.learned_gamma = learned_gamma + ctx.learned_beta = learned_beta + ctx.mean = global_mean + ctx.var = global_var + ctx.num_rows = global_num_rows + ctx.input = input + ctx.x_hat = x_hat + return output, global_mean, global_var + + @staticmethod + def backward(ctx, grad_output, grad_mean, grad_var): + + learned_gamma = ctx.learned_gamma + learned_beta = ctx.learned_beta + mean = ctx.mean + var = ctx.var + x_hat = ctx.x_hat + num_rows = ctx.num_rows + x = ctx.input + dx, global_dgamma, global_dbeta = _compute_bn_backward( + grad_output, x, x_hat, mean, var, num_rows, learned_gamma, learned_beta + ) + + return dx, global_dgamma, global_dbeta + + +class DistributedBatchNorm1D(nn.Module): + def __init__( + self, + num_features, + eps=1e-5, + momentum=0.1, + affine=True, + track_running_stats=True, + recompute=False, + ): + super(DistributedBatchNorm1D, self).__init__() + if affine: + self.gamma = nn.Parameter(torch.ones(1, num_features)) + self.beta = nn.Parameter(torch.zeros(1, num_features)) + else: + self.register_parameter("gamma", None) + self.register_parameter("beta", None) + self.eps = eps + self.momentum = momentum + self.track_running_stats = track_running_stats + if self.track_running_stats: + self.register_buffer("running_mean", torch.zeros(1, num_features)) + self.register_buffer("running_var", torch.ones(1, num_features)) + self.register_buffer( + "num_batches_tracked", torch.tensor(0, dtype=torch.long) + ) + else: + self.register_parameter("running_mean", None) + self.register_parameter("running_var", None) + self.register_parameter("num_batches_tracked", None) + self.recompute = recompute + if recompute: + self.bn: Callable = DistributedBN_with_Recompute.apply + else: + self.bn: Callable = DistributedBN_Impl.apply + + def forward(self, x): + if x.dim() == 3: + assert x.size(0) == 1, "only mini-batch size 1 is supported" + x = x.squeeze(0) + elif x.dim() != 2: + raise ValueError("Expected 2D or 3D input (got {}D input)".format(x.dim())) + + if self.training: + if self.track_running_stats: + self.num_batches_tracked += 1 + y, mean, var = self.bn(x, self.gamma, self.beta) + + if self.track_running_stats: + with torch.no_grad(): + self.running_mean = ( + 1 - self.momentum + ) * self.running_mean + self.momentum * mean + self.running_var = ( + 1 - self.momentum + ) * self.running_var + self.momentum * var + else: + y = (x - self.running_mean) / torch.sqrt(self.running_var + self.eps) + if self.gamma is not None and self.beta is not None: + y = y * self.gamma + self.beta + + if y.dim() == 2: + y = y.unsqueeze(0) + return y + + +def GetGlobalVal(local_val): + """Get the global sum of a local value across all ranks.""" + global_val = torch.tensor([local_val]).cuda() + dist.all_reduce(global_val, op=dist.ReduceOp.SUM) + return global_val.item() diff --git a/experiments/OGB-LSC/mag240m/DGraph_MAG240M.py b/experiments/OGB-LSC/mag240m/DGraph_MAG240M.py new file mode 100644 index 0000000..9b6313a --- /dev/null +++ b/experiments/OGB-LSC/mag240m/DGraph_MAG240M.py @@ -0,0 +1,476 @@ +# Copyright (c) 2014-2025, Lawrence Livermore National Security, LLC. +# Produced at the Lawrence Livermore National Laboratory. +# Written by the LBANN Research Team (B. Van Essen, et al.) listed in +# the CONTRIBUTORS file. See the top-level LICENSE file for details. +# +# LLNL-CODE-697807. +# All rights reserved. +# +# This file is part of LBANN: Livermore Big Artificial Neural Network +# Toolkit. For details, see http://software.llnl.gov/LBANN or +# https://github.com/LBANN and https://github.com/LLNL/LBANN. +# +# SPDX-License-Identifier: (Apache-2.0) +from ogb.lsc import MAG240MDataset +import torch +from typing import Optional, Tuple +from torch_sparse import SparseTensor +import numpy as np +from tqdm import tqdm +import os.path as osp +from DGraph.Communicator import Communicator +from DGraph.distributed.nccl import NCCLGraphCommPlan, COO_to_NCCLCommPlan + + +def get_col_slice(x, start_row_idx, end_row_idx, start_col_idx, end_col_idx): + """Obtained from: + https://github.com/snap-stanford/ogb/blob/master/examples/lsc/mag240m/rgnn.py + """ + outs = [] + chunk = 100000 + for i in tqdm(range(start_row_idx, end_row_idx, chunk)): + j = min(i + chunk, end_row_idx) + outs.append(x[i:j, start_col_idx:end_col_idx].copy()) + return np.concatenate(outs, axis=0) + + +def save_col_slice( + x_src, x_dst, start_row_idx, end_row_idx, start_col_idx, end_col_idx +): + """Obtained from: + https://github.com/snap-stanford/ogb/blob/master/examples/lsc/mag240m/rgnn.py + """ + assert x_src.shape[0] == end_row_idx - start_row_idx + assert x_src.shape[1] == end_col_idx - start_col_idx + chunk, offset = 100000, start_row_idx + for i in tqdm(range(0, end_row_idx - start_row_idx, chunk)): + j = min(i + chunk, end_row_idx - start_row_idx) + x_dst[offset + i : offset + j, start_col_idx:end_col_idx] = x_src[i:j] + + +def get_rank_mappings(num_nodes, world_size, rank): + nodes_per_rank = num_nodes // world_size + print(f"Rank {rank}: nodes_per_rank = {nodes_per_rank}") + # Don't use uint8 if world_size > 256 + # Doing this to save memory + if world_size > 256: + raise ValueError("world_size > 256 not supported yet") + rank_mappings = torch.zeros(num_nodes, dtype=torch.uint8) + for r in range(world_size): + start = r * nodes_per_rank + end = (r + 1) * nodes_per_rank if r != world_size - 1 else num_nodes + rank_mappings[start:end] = r + return rank_mappings + + +def edge_mapping_from_vertex_mapping(edge_index, src_rank_mappings, dst_rank_mappings): + # directed edges, so edge_index[0] -> edge_index[1] + src_indices = edge_index[0] + dest_indices = edge_index[1] + # We put the edge on the rank where the destination vertex is located + # Since heterogeneous graphs have different rank mappings for different + # vertex types. + src_data_mappings = src_rank_mappings[src_indices] + dest_data_mappings = dst_rank_mappings[dest_indices] + return (src_data_mappings, dest_data_mappings) + + +def get_edge_mappings(src_indices, dst_indices, rank_mappings): + edge_mappings = torch.zeros_like(src_indices) + # The edges are mapped to the rank of the destination node + # Because that is the accumulation rank + edge_mappings = rank_mappings[dst_indices] + return edge_mappings + + +def _generate_features_from_paper_features( + out: np.memmap, + num_nodes: int, + num_papers: int, + paper_feat: np.ndarray, + edge_index: np.ndarray, + num_features: int, +): + + row, col = torch.from_numpy(edge_index) + adj = SparseTensor( + row=row, col=col, sparse_sizes=(num_nodes, num_papers), is_sorted=True + ) + + dim_chunk_size = 64 + + for i in tqdm(range(0, num_features, dim_chunk_size)): + j = min(i + dim_chunk_size, num_features) + inputs = get_col_slice( + paper_feat, + start_row_idx=0, + end_row_idx=num_papers, + start_col_idx=i, + end_col_idx=j, + ) + inputs = torch.from_numpy(inputs) + out_ = adj.matmul(inputs, reduce="mean").numpy() # type: ignore + del inputs + save_col_slice( + x_src=out_, + x_dst=out, + start_row_idx=0, + end_row_idx=num_nodes, + start_col_idx=i, + end_col_idx=j, + ) + del out_ + out.flush() + + +class DGraph_MAG240M: + + # data_dir must be the location where all ranks can access + def __init__( + self, + comm: Communicator, + data_dir: str = "data/MAG240M", + paper_rank_mappings: Optional[torch.Tensor] = None, + author_rank_mappings: Optional[torch.Tensor] = None, + institution_rank_mappings: Optional[torch.Tensor] = None, + ): + self.rank = comm.get_rank() + self.world_size = comm.get_world_size() + self.comm = comm + self.dataset = MAG240MDataset(root=data_dir) + self.num_papers = self.dataset.num_papers + self.num_authors = self.dataset.num_authors + self.num_institutions = self.dataset.num_institutions + # self.num_classes = self.dataset.num_classes + self.paper_rank_mappings = ( + paper_rank_mappings + if paper_rank_mappings is not None + else get_rank_mappings(self.num_papers, self.world_size, self.rank) + ) + self.author_rank_mappings = ( + author_rank_mappings + if author_rank_mappings is not None + else get_rank_mappings(self.num_authors, self.world_size, self.rank) + ) + self.institution_rank_mappings = ( + institution_rank_mappings + if institution_rank_mappings is not None + else get_rank_mappings(self.num_institutions, self.world_size, self.rank) + ) + + # authors -> paper + self.write_mappings = get_edge_mappings( + torch.from_numpy(self.dataset.edge_index("author", "paper")[0]), + torch.from_numpy(self.dataset.edge_index("author", "paper")[1]), + self.paper_rank_mappings, + ) + + # author -> institution + self.write_mappings_author_institution = get_edge_mappings( + torch.from_numpy(self.dataset.edge_index("author", "institution")[0]), + torch.from_numpy(self.dataset.edge_index("author", "institution")[1]), + self.institution_rank_mappings, + ) + + self.train_mask = self.dataset.get_idx_split("train") + self.val_mask = self.dataset.get_idx_split("valid") + self.test_mask = self.dataset.get_idx_split("test-dev") + + local_papers_mask = self.paper_rank_mappings == self.rank + local_authors_mask = self.author_rank_mappings == self.rank + local_institutions_mask = self.institution_rank_mappings == self.rank + self.num_local_papers = int(local_papers_mask.sum()) + + self.generate_feature_data() + + self.paper_features = torch.from_numpy( + self.dataset.paper_feat[local_papers_mask] + ) + path = self.dataset.dir + self.author_features = torch.from_numpy( + np.memmap( + filename=path + "/author_feat.npy", + mode="r", + dtype=np.float16, + shape=(self.num_authors, self.num_features), + )[local_authors_mask] + ) + self.institution_features = torch.from_numpy( + np.memmap( + filename=path + "/institution_feat.npy", + mode="r", + dtype=np.float16, + shape=(self.num_institutions, self.num_features), + )[local_institutions_mask] + ) + self.y = torch.from_numpy(self.dataset.paper_label) + + self.paper_2_paper_edges = torch.from_numpy( + self.dataset.edge_index("paper", "cites", "paper") + ) + ( + paper_2_paper_src_data_mappings, + paper_2_paper_dest_data_mappings, + ) = edge_mapping_from_vertex_mapping( + edge_index=self.paper_2_paper_edges, + src_rank_mappings=self.paper_rank_mappings, + dst_rank_mappings=self.paper_rank_mappings, + ) + self.paper_src_data_mappings = paper_2_paper_src_data_mappings + self.paper_dest_data_mappings = paper_2_paper_dest_data_mappings + + self.author_2_paper_edges = torch.from_numpy( + self.dataset.edge_index("author", "writes", "paper") + ) + ( + author_2_paper_src_data_mappings, + author_2_paper_dest_data_mappings, + ) = edge_mapping_from_vertex_mapping( + edge_index=self.author_2_paper_edges, + src_rank_mappings=self.author_rank_mappings, + dst_rank_mappings=self.paper_rank_mappings, + ) + self.author_2_paper_src_data_mappings = author_2_paper_src_data_mappings + self.author_2_paper_dest_data_mappings = author_2_paper_dest_data_mappings + + self.author_2_institution_edges = torch.from_numpy( + self.dataset.edge_index("author", "institution") + ) + ( + author_2_institution_src_data_mappings, + author_2_institution_dest_data_mappings, + ) = edge_mapping_from_vertex_mapping( + edge_index=self.author_2_institution_edges, + src_rank_mappings=self.author_rank_mappings, + dst_rank_mappings=self.institution_rank_mappings, + ) + + self.author_2_institution_src_data_mappings = ( + author_2_institution_src_data_mappings + ) + self.author_2_institution_dest_data_mappings = ( + author_2_institution_dest_data_mappings + ) + + @property + def num_features(self) -> int: + # 768 + return self.dataset.num_paper_features + + @property + def num_classes(self) -> int: + # 153 + return self.dataset.num_classes + + @property + def num_relations(self) -> int: + # paper -> paper + # paper -> author + # author -> paper + # author -> institution + # institution -> author + return 5 + + def generate_feature_data(self): + dataset = self.dataset + # This function emulates the author and institute features generation steps here + # https://github.com/snap-stanford/ogb/blob/61e9784ca76edeaa6e259ba0f836099608ff0586/examples/lsc/mag240m/rgnn.py#L82 + + # Generate author features + # Mag240M author features are generated from paper features + num_authors = dataset.num_authors + num_papers = dataset.num_papers + path = dataset.dir + paper_feat = dataset.paper_feat + + # Only one rank must do this work + if self.rank == 0: + if not osp.exists(path + "/author_feat.npy"): + print("Generating author features") + author_feat = np.memmap( + filename=path + "/author_feat.npy", + mode="w+", + dtype=np.float16, + shape=(num_authors, self.num_features), + ) + _generate_features_from_paper_features( + out=author_feat, + num_nodes=num_authors, + num_papers=num_papers, + paper_feat=paper_feat, + edge_index=dataset.edge_index("author", "paper"), + num_features=self.num_features, + ) + + if not osp.exists(path + "/institution_feat.npy"): + print("Generating institution features") + # Generate institution features + num_institutions = dataset.num_institutions + institution_feat = np.memmap( + filename=path + "/institution_feat.npy", + mode="w+", + dtype=np.float16, + shape=(num_institutions, self.num_features), + ) + _generate_features_from_paper_features( + out=institution_feat, + num_nodes=num_authors, + num_papers=num_institutions, + paper_feat=paper_feat, + edge_index=dataset.edge_index("author", "institution"), + num_features=self.num_features, + ) + self.comm.barrier() + + # Make sure all ranks can see the generated files + if not osp.exists(path + "/author_feat.npy"): + raise FileNotFoundError("author_feat.npy not found") + if not osp.exists(path + "/institution_feat.npy"): + raise FileNotFoundError("institution_feat.npy not found") + self.comm.barrier() + + print("Data processing complete") + + # Same as synthetic? + def get_vertex_rank_mask(self, mask_type: str) -> Tuple[torch.Tensor, torch.Tensor]: + if mask_type == "train": + global_int_mask = self.train_mask + elif mask_type == "val": + global_int_mask = self.val_mask + elif mask_type == "test": + global_int_mask = self.test_mask + else: + raise ValueError(f"Invalid mask type: {mask_type}") + + # Get the ranks of the vertices + # paper_vertex_rank_mapping -> vector of size num_papers, + # where each entry is the location / rank of the vertex + paper_rank_mappings = self.paper_rank_mappings.to(global_int_mask.device) + vertex_ranks = paper_rank_mappings[global_int_mask] + # vertex_ranks is location of the vertices in the global_int_mask + vertex_ranks_mask = vertex_ranks == self.rank + return global_int_mask, vertex_ranks_mask + + # Same as synthetic? + def get_mask(self, mask_type: str) -> torch.Tensor: + + global_int_mask, vertex_ranks_mask = self.get_vertex_rank_mask(mask_type) + local_int_mask = global_int_mask[vertex_ranks_mask] + local_int_mask = local_int_mask % self.num_local_papers + return local_int_mask + + # Same as synthetic? + def get_target(self, _type: str) -> torch.Tensor: + global_int_mask, vertex_ranks_mask = self.get_vertex_rank_mask(_type) + + global_training_targets = self.y[:, global_int_mask.squeeze(0)] + local_training_targets = global_training_targets[vertex_ranks_mask] + + return local_training_targets + + def __len__(self): + return 0 + + # Same as synthetic? + def add_batch_dimension(self): + """Add a batch dimension to all tensors. This is particularly useful + because we only have one graph and DGraph is built to handle batches of graphs. + We want to do this here because this allows us to avoid copying the data + and requiring a data loader. + """ + self.paper_features = self.paper_features.unsqueeze(0) + self.author_features = self.author_features.unsqueeze(0) + self.institution_features = self.institution_features.unsqueeze(0) + self.y = self.y.unsqueeze(0) + self.train_mask = self.train_mask.unsqueeze(0) + self.val_mask = self.val_mask.unsqueeze(0) + self.test_mask = self.test_mask.unsqueeze(0) + self.paper_2_paper_edges = self.paper_2_paper_edges.unsqueeze(0) + self.author_2_paper_edges = self.author_2_paper_edges.unsqueeze(0) + self.author_2_institution_edges = self.author_2_institution_edges.unsqueeze(0) + + return self + + # Same as synthetic? + def to(self, device): + """Move the dataset tensors to the specified device. + We want to do this here because this allows us to avoid + copying the data when the different individual tensors are + accessed. + """ + self.paper_features = self.paper_features.to(device, dtype=torch.float32) + self.author_features = self.author_features.to(device, dtype=torch.float32) + self.institution_features = self.institution_features.to( + device, dtype=torch.float32 + ) + self.y = self.y.to(device) + self.train_mask = self.train_mask.to(device) + self.val_mask = self.val_mask.to(device) + self.test_mask = self.test_mask.to(device) + self.paper_2_paper_edges = self.paper_2_paper_edges.to(device) + self.author_2_paper_edges = self.author_2_paper_edges.to(device) + self.author_2_institution_edges = self.author_2_institution_edges.to(device) + + return self + + def __getitem__(self, idx): + # There are 5 relations: + # paper -> paper + # paper -> author + # author -> paper + # author -> institution + # institution -> author + edge_index = [ + self.paper_2_paper_edges, + self.author_2_paper_edges, + self.author_2_paper_edges.flip(self.author_2_paper_edges.dim() - 2), + self.author_2_institution_edges, + self.author_2_institution_edges.flip( + self.author_2_institution_edges.dim() - 2 + ), + ] + # Locations of the edges + rank_mappings = [ + [self.paper_src_data_mappings, self.paper_dest_data_mappings], + [ + self.author_2_paper_src_data_mappings, + self.author_2_paper_dest_data_mappings, + ], + [ + self.author_2_paper_dest_data_mappings, + self.author_2_paper_src_data_mappings, + ], + [ + self.author_2_institution_src_data_mappings, + self.author_2_institution_dest_data_mappings, + ], + [ + self.author_2_institution_dest_data_mappings, + self.author_2_institution_src_data_mappings, + ], + ] + edge_type = [(0, 0), (1, 0), (0, 1), (1, 2), (2, 1)] + features = [ + self.paper_features, + self.author_features, + self.institution_features, + ] + return (features, edge_index, edge_type, rank_mappings) + + +if __name__ == "__main__": + import fire + + def main(data_dir: str = "data/MAG240M"): + rank = 0 + world_size = 64 + # Python is so weird haha + COMM = type( + "dummy_comm", + (object,), + {"get_rank": lambda self: rank, "get_world_size": lambda self: world_size}, + ) + comm = COMM() + dgraph = DGraph_MAG240M(comm, data_dir=data_dir) + + fire.Fire(main) diff --git a/experiments/OGB-LSC/mag240m/README.md b/experiments/OGB-LSC/mag240m/README.md new file mode 100644 index 0000000..5be12e9 --- /dev/null +++ b/experiments/OGB-LSC/mag240m/README.md @@ -0,0 +1,29 @@ +# Processing OGB-LSC MAG240M Dataset + +This directory contains the code to preprocess and load the OGB-LSC MAG240M dataset to use with DGraph. + +## Prerequisites + +Make sure you have the following packages installed: +- `torch` +- `torch_geometric` +- `ogb` +- `torch_sparse` +- `numpy` +- `tqdm` +- `fire` + +## Preprocessing the dataset +The MAG240M dataset is a fairly large graph dataset and requires some preprocessing before it can be used with DGraph, and takes a while to process. The following script processes the dataset and saves the processed data in a directory. + +```bash +python DGraph_MAG240M.py --data_dir +``` + +Make sure to replace `` with the path where you want to store the processed data. The script will download the dataset if it is not already present in the specified directory. The processed data will be saved in the same directory. + +The processing machine requires at least `128GB` of RAM to process the dataset. + + + + diff --git a/experiments/OGB-LSC/main.py b/experiments/OGB-LSC/main.py new file mode 100644 index 0000000..35deab3 --- /dev/null +++ b/experiments/OGB-LSC/main.py @@ -0,0 +1,120 @@ +# Copyright (c) 2014-2025, Lawrence Livermore National Security, LLC. +# Produced at the Lawrence Livermore National Laboratory. +# Written by the LBANN Research Team (B. Van Essen, et al.) listed in +# the CONTRIBUTORS file. See the top-level LICENSE file for details. +# +# LLNL-CODE-697807. +# All rights reserved. +# +# This file is part of LBANN: Livermore Big Artificial Neural Network +# Toolkit. For details, see http://software.llnl.gov/LBANN or +# https://github.com/LBANN and https://github.com/LLNL/LBANN. +# +# SPDX-License-Identifier: (Apache-2.0) +import fire +import torch +from functools import partial +import os.path as osp +import DGraph.Communicator as Comm +from Trainer import Trainer +from config import ModelConfig +import torch.distributed as dist + + +def main( + comm_type: str = "nccl", + dataset: str = "synthetic", + num_papers: int = 2048, + num_authors: int = 512, + num_institutions: int = 16, + paper_rank_mapping_file: str = "", + author_rank_mapping_file: str = "", + institution_rank_mapping_file: str = "", + data_dir: str = "mag240m/data/MAG240M", +): + """Main function to run DGraph experiments on OGB-LSC datasets. + + Args: + comm_type (str): Type of communicator to use. Options are 'nccl' and + 'nvshmem'. Default is 'nccl'. + dataset (str): Dataset to use. Options are 'synthetic' and 'mag240m'. + Default is 'synthetic'. + num_papers (int): Number of paper nodes to use in the synthetic dataset. + Default is 2048. + num_authors (int): Number of author nodes to use in the synthetic dataset. + Default is 512. + num_institutions (int): Number of institution nodes to use in the synthetic + dataset. Default is 16. + paper_rank_mapping_file (str): Path to the paper rank mapping file for + mag240m dataset. Default is ''. + author_rank_mapping_file (str): Path to the author rank mapping file for + mag240m dataset. Default is not set. + institution_rank_mapping_file (str): Path to the institution rank mapping + file for mag240m dataset. Default is not set. + data_dir (str): Path to the mag240m dataset directory. Default is + 'mag240m/data/MAG240M'. + """ + assert dataset in ["synthetic", "mag240m"] + if dataset == "synthetic": + from synthetic.synthetic_dataset import HeterogeneousDataset as Dataset + + graph_dataset = partial( + Dataset, + num_papers=num_papers, + num_authors=num_authors, + num_institutions=num_institutions, + num_features=ModelConfig().num_features, + num_classes=ModelConfig().num_classes, + ) + + elif dataset == "mag240m": + from mag240m.DGraph_MAG240M import DGraph_MAG240M as Dataset + + paper_rank_mapping = None + if len(paper_rank_mapping_file) > 0: + assert osp.exists(paper_rank_mapping_file) + paper_rank_mapping = torch.load(paper_rank_mapping_file, weights_only=False) + + author_rank_mapping = None + if len(author_rank_mapping_file) > 0: + assert osp.exists(author_rank_mapping_file) + author_rank_mapping = torch.load(author_rank_mapping_file, weights_only=False) + + institution_rank_mapping = None + if len(institution_rank_mapping_file) > 0: + assert osp.exists(institution_rank_mapping_file) + institution_rank_mapping = torch.load( + institution_rank_mapping_file, weights_only=False + ) + + graph_dataset = partial( + Dataset, + paper_rank_mappings=paper_rank_mapping, + author_rank_mappings=author_rank_mapping, + institution_rank_mappings=institution_rank_mapping, + data_dir=data_dir, + ) + else: + raise ValueError(f"Invalid dataset: {dataset}") + + assert comm_type in ["nccl", "nvshmem"] + comm = Comm.Communicator.init_process_group(comm_type) + + comm.barrier() + print(f"Running with {comm.get_world_size()} ranks. Rank: {comm.get_rank()}") + + graph_dataset = graph_dataset(comm=comm) + + trainer = Trainer(graph_dataset, comm) + trainer.prepare_data() + trainer.train() + comm.destroy() + + if dist.is_initialized(): + dist.destroy_process_group() + + return 0 + + +if __name__ == "__main__": + fire.Fire(main) diff --git a/experiments/OGB-LSC/synthetic/synthetic_dataset.py b/experiments/OGB-LSC/synthetic/synthetic_dataset.py new file mode 100644 index 0000000..05cb399 --- /dev/null +++ b/experiments/OGB-LSC/synthetic/synthetic_dataset.py @@ -0,0 +1,393 @@ +# Copyright (c) 2014-2025, Lawrence Livermore National Security, LLC. +# Produced at the Lawrence Livermore National Laboratory. +# Written by the LBANN Research Team (B. Van Essen, et al.) listed in +# the CONTRIBUTORS file. See the top-level LICENSE file for details. +# +# LLNL-CODE-697807. +# All rights reserved. +# +# This file is part of LBANN: Livermore Big Artificial Neural Network +# Toolkit. For details, see http://software.llnl.gov/LBANN or +# https://github.com/LBANN and https://github.com/LLNL/LBANN. +# +# SPDX-License-Identifier: (Apache-2.0) +from DGraph.Communicator import Communicator +import torch +from typing import Tuple + +torch.random.manual_seed(0) + + +def _generate_paper_2_paper_edges(num_papers): + # Average degree of a paper is ~11 + num_edges = num_papers * 11 + coo_list = torch.randint( + low=0, high=num_papers, size=(2, num_edges), dtype=torch.long + ) + coo_list = torch.unique(coo_list, dim=1) + transpose = coo_list.flip(0) + coo_list = torch.cat([coo_list, transpose], dim=1) + coo_list = torch.sort(coo_list, dim=1).values + return coo_list + + +def _generate_author_2_paper_edges(num_authors, num_papers): + # Average number of authors per paper is ~3.5 + num_edges = int(num_authors * 3.5) + dest_papers = torch.randint( + low=0, high=num_papers, size=(1, num_edges), dtype=torch.long + ) + src_authors = torch.randint( + low=0, high=num_authors, size=(1, num_edges), dtype=torch.long + ) + coo_list = torch.cat([src_authors, dest_papers], dim=0) + coo_list = torch.unique(coo_list, dim=1) + return coo_list + + +def _generate_author_2_institution_edges(num_authors, num_institutions): + # Average number of institutions per author is ~0.35 + num_edges = int(num_authors * 0.35) + dest_num_institutions = torch.randint( + low=0, high=num_institutions, size=(1, num_edges), dtype=torch.long + ) + src_authors = torch.randint( + low=0, high=num_authors, size=(1, num_edges), dtype=torch.long + ) + coo_list = torch.cat([src_authors, dest_num_institutions], dim=0) + coo_list = torch.unique(coo_list, dim=1) + return coo_list + + +def _get_rank_mappings(num_vertices, world_size, rank): + vertices_per_rank = num_vertices // world_size + rank_mappings = torch.zeros(num_vertices, dtype=torch.long) + vertices_cur_rank = 0 + for r in range(world_size): + start = r * vertices_per_rank + end = (r + 1) * vertices_per_rank if r != world_size - 1 else num_vertices + rank_mappings[start:end] = r + if r == rank: + vertices_cur_rank = end - start + return rank_mappings, vertices_cur_rank + + +def edge_mapping_from_vertex_mapping(edge_index, src_rank_mappings, dst_rank_mappings): + # directed edges, so edge_index[0] -> edge_index[1] + src_indices = edge_index[0] + dest_indices = edge_index[1] + # We put the edge on the rank where the destination vertex is located + # Since heterogeneous graphs have different rank mappings for different + # vertex types. + src_data_mappings = src_rank_mappings[src_indices] + dest_data_mappings = dst_rank_mappings[dest_indices] + return (src_data_mappings, dest_data_mappings) + + +class HeterogeneousDataset: + def __init__( + self, + num_papers, + num_authors, + num_institutions, + num_features, + num_classes, + comm: Communicator, + ): + self.num_papers = num_papers + self.num_authors = num_authors + self.num_institutions = num_institutions + self._num_classes = num_classes + self._num_features = num_features + self._num_relations = 5 + self.comm = comm + self.rank = comm.get_rank() + self.world_size = comm.get_world_size() + self.rank = comm.get_rank() + self.paper_vertex_rank_mapping, self.num_paper_vertices = _get_rank_mappings( + num_vertices=num_papers, world_size=self.world_size, rank=self.rank + ) + self.author_vertex_rank_mapping, self.num_author_vertices = _get_rank_mappings( + num_vertices=num_authors, world_size=self.world_size, rank=self.rank + ) + self.institution_vertex_rank_mapping, self.num_institution_vertices = ( + _get_rank_mappings( + num_vertices=num_institutions, + world_size=self.world_size, + rank=self.rank, + ) + ) + _vertices = torch.randperm(num_papers) + self.train_mask = _vertices[: int(0.7 * num_papers)] + self.val_mask = _vertices[int(0.7 * num_papers) : int(0.85 * num_papers)] + self.test_mask = _vertices[int(0.85 * num_papers) :] + self.y = torch.randint( + low=0, high=self.num_classes, size=(num_papers,), dtype=torch.long + ) + + self.paper_2_paper_edges = _generate_paper_2_paper_edges(num_papers) + + ( + paper_2_paper_src_data_mappings, + paper_2_paper_dest_data_mappings, + ) = edge_mapping_from_vertex_mapping( + edge_index=self.paper_2_paper_edges, + src_rank_mappings=self.paper_vertex_rank_mapping, + dst_rank_mappings=self.paper_vertex_rank_mapping, + ) + + self.paper_src_data_mappings = paper_2_paper_src_data_mappings + self.paper_dest_data_mappings = paper_2_paper_dest_data_mappings + + self.author_2_paper_edges = _generate_author_2_paper_edges( + num_authors, num_papers + ) + + ( + author_2_paper_src_data_mappings, + author_2_paper_dest_data_mappings, + ) = edge_mapping_from_vertex_mapping( + edge_index=self.author_2_paper_edges, + src_rank_mappings=self.author_vertex_rank_mapping, + dst_rank_mappings=self.paper_vertex_rank_mapping, + ) + self.author_2_paper_src_data_mappings = author_2_paper_src_data_mappings + self.author_2_paper_dest_data_mappings = author_2_paper_dest_data_mappings + + self.author_2_institution_edges = _generate_author_2_institution_edges( + num_authors, num_institutions + ) + + ( + author_2_institution_src_data_mappings, + author_2_institution_dest_data_mappings, + ) = edge_mapping_from_vertex_mapping( + edge_index=self.author_2_institution_edges, + src_rank_mappings=self.author_vertex_rank_mapping, + dst_rank_mappings=self.institution_vertex_rank_mapping, + ) + + self.author_2_institution_src_data_mappings = ( + author_2_institution_src_data_mappings + ) + self.author_2_institution_dest_data_mappings = ( + author_2_institution_dest_data_mappings + ) + + paper_vertices_cur_rank = int( + (self.paper_vertex_rank_mapping == self.rank).sum() + ) + author_vertices_cur_rank = int( + (self.author_vertex_rank_mapping == self.rank).sum() + ) + institution_vertices_cur_rank = int( + (self.institution_vertex_rank_mapping == self.rank).sum() + ) + self.paper_vertices_cur_rank = paper_vertices_cur_rank + + self.paper_features = torch.randn( + (paper_vertices_cur_rank, num_features), dtype=torch.float32 + ) + self.author_features = torch.randn( + (author_vertices_cur_rank, num_features), dtype=torch.float32 + ) + self.institution_features = torch.randn( + (institution_vertices_cur_rank, num_features), dtype=torch.float32 + ) + + @property + def num_features(self) -> int: + return self._num_features + + @property + def num_classes(self) -> int: + return self._num_classes + + @property + def num_relations(self) -> int: + return self._num_relations + + # def get_validation_mask(self): + # # Only papers are classified + # validation_vertices_mappings = self.paper_vertex_rank_mapping[self.val_mask] + # validation_vertices_mappings = validation_vertices_mappings.to( + # self.val_mask.device + # ) + # num_validation_vertices = (validation_vertices_mappings == self.rank).sum() + # if num_validation_vertices > 0: + # return ( + # self.val_mask[validation_vertices_mappings == self.rank] + # % self.paper_vertices_cur_rank + # ) + # else: + # return torch.tensor([], dtype=torch.long) + + # def get_test_mask(self): + # # Only papers are classified + + # paper_vertices = self.paper_vertex_rank_mapping == self.rank + # paper_vertices = paper_vertices.to(self.test_mask.device) + # num_test_vertices = (paper_vertices[self.test_mask] == self.rank).sum() + # if num_test_vertices > 0: + # return ( + # self.test_mask[paper_vertices[self.test_mask] == self.rank] + # % self.paper_vertices_cur_rank + # ) + # else: + # return torch.tensor([], dtype=torch.long) + + # def get_train_mask(self): + # # Only papers are classified + # paper_vertices = self.paper_vertex_rank_mapping == self.rank + # paper_vertices = paper_vertices.to(self.train_mask.device) + # num_train_vertices = (paper_vertices[self.train_mask] == self.rank).sum() + # if num_train_vertices > 0: + # return ( + # self.train_mask[paper_vertices[self.train_mask] == self.rank] + # % self.paper_vertices_cur_rank + # ) + # else: + # return torch.tensor([], dtype=torch.long) + + def get_vertex_rank_mask(self, mask_type: str) -> Tuple[torch.Tensor, torch.Tensor]: + if mask_type == "train": + global_int_mask = self.train_mask + elif mask_type == "val": + global_int_mask = self.val_mask + elif mask_type == "test": + global_int_mask = self.test_mask + else: + raise ValueError(f"Invalid mask type: {mask_type}") + + # Get the ranks of the vertices + # paper_vertex_rank_mapping -> vector of size num_papers, + # where each entry is the location / rank of the vertex + paper_vertex_rank_mapping = self.paper_vertex_rank_mapping.to( + global_int_mask.device + ) + vertex_ranks = paper_vertex_rank_mapping[global_int_mask] + # vertex_ranks is location of the vertices in the global_int_mask + vertex_ranks_mask = vertex_ranks == self.rank + return global_int_mask, vertex_ranks_mask + + def get_mask(self, mask_type: str) -> torch.Tensor: + + global_int_mask, vertex_ranks_mask = self.get_vertex_rank_mask(mask_type) + local_int_mask = global_int_mask[vertex_ranks_mask] + local_int_mask = local_int_mask % self.paper_vertices_cur_rank + return local_int_mask + + def get_target(self, _type: str) -> torch.Tensor: + global_int_mask, vertex_ranks_mask = self.get_vertex_rank_mask(_type) + + global_training_targets = self.y[:, global_int_mask.squeeze(0)] + local_training_targets = global_training_targets[vertex_ranks_mask] + + return local_training_targets + + def __len__(self): + return 0 + + def add_batch_dimension(self): + """Add a batch dimension to all tensors. This is particularly useful + because we only have one graph and DGraph is built to handle batches of graphs. + We want to do this here because this allows us to avoid copying the data + and requiring a data loader. + """ + self.paper_features = self.paper_features.unsqueeze(0) + self.author_features = self.author_features.unsqueeze(0) + self.institution_features = self.institution_features.unsqueeze(0) + self.y = self.y.unsqueeze(0) + self.train_mask = self.train_mask.unsqueeze(0) + self.val_mask = self.val_mask.unsqueeze(0) + self.test_mask = self.test_mask.unsqueeze(0) + self.paper_2_paper_edges = self.paper_2_paper_edges.unsqueeze(0) + self.author_2_paper_edges = self.author_2_paper_edges.unsqueeze(0) + self.author_2_institution_edges = self.author_2_institution_edges.unsqueeze(0) + + return self + + def to(self, device): + """Move the dataset tensors to the specified device. + We want to do this here because this allows us to avoid + copying the data when the different individual tensors are + accessed. + """ + self.paper_features = self.paper_features.to(device) + self.author_features = self.author_features.to(device) + self.institution_features = self.institution_features.to(device) + self.y = self.y.to(device) + self.train_mask = self.train_mask.to(device) + self.val_mask = self.val_mask.to(device) + self.test_mask = self.test_mask.to(device) + self.paper_2_paper_edges = self.paper_2_paper_edges.to(device) + self.author_2_paper_edges = self.author_2_paper_edges.to(device) + self.author_2_institution_edges = self.author_2_institution_edges.to(device) + + return self + + def __getitem__(self, idx): + # There are 5 relations: + # paper -> paper + # paper -> author + # author -> paper + # author -> institution + # institution -> author + + edge_index = [ + self.paper_2_paper_edges, + self.author_2_paper_edges, + self.author_2_paper_edges.flip(self.author_2_paper_edges.dim() - 2), + self.author_2_institution_edges, + self.author_2_institution_edges.flip( + self.author_2_institution_edges.dim() - 2 + ), + ] + # Locations of the edges + rank_mappings = [ + [self.paper_src_data_mappings, self.paper_dest_data_mappings], + [ + self.author_2_paper_src_data_mappings, + self.author_2_paper_dest_data_mappings, + ], + [ + self.author_2_paper_dest_data_mappings, + self.author_2_paper_src_data_mappings, + ], + [ + self.author_2_institution_src_data_mappings, + self.author_2_institution_dest_data_mappings, + ], + [ + self.author_2_institution_dest_data_mappings, + self.author_2_institution_src_data_mappings, + ], + ] + edge_type = [(0, 0), (1, 0), (0, 1), (1, 2), (2, 1)] + features = [ + self.paper_features, + self.author_features, + self.institution_features, + ] + return (features, edge_index, edge_type, rank_mappings) + + +if __name__ == "__main__": + rank = 0 + world_size = 16 + COMM = type( + "dummy_comm", + (object,), + {"get_rank": lambda self: rank, "get_world_size": lambda self: world_size}, + ) + comm = COMM() + + dataset = HeterogeneousDataset( + num_papers=512, + num_authors=128, + num_institutions=32, + num_features=16, + num_classes=4, + comm=comm, + ) + print(dataset[0]) diff --git a/tests/test_local_kernels.py b/tests/test_local_kernels.py index 1544644..f3a44ac 100644 --- a/tests/test_local_kernels.py +++ b/tests/test_local_kernels.py @@ -67,3 +67,84 @@ def test_optimized_local_gather(): assert torch.allclose( out_tensor.cpu(), out_tensor_gt ), "Optimized local gather failed" + + +def test_optimized_scatter_gaher(): + try: + from torch_local import local_masked_scatter_gather + except ImportError as e: + pytest.fail(f"Failed to import local_masked_scatter_gather: {e}") + + num_src_rows = 8 + num_out_rows = 8 + bs = 1 + num_features = 4 + src_tensor = torch.randn(bs, num_src_rows, num_features) + src_indices = torch.tensor([0, 3, 2, 1]) + dst_indices = torch.tensor([1, 3, 5, 7]) + + out_tensor_gt = torch.zeros(bs, num_out_rows, num_features) + + for i in range(bs): + for j in range(len(src_indices)): + out_tensor_gt[i, dst_indices[j]] = src_tensor[i, src_indices[j]] + out_tensor_gt = out_tensor_gt.view(bs, num_out_rows, num_features) + out_tensor = torch.zeros_like(out_tensor_gt) + out_tensor = out_tensor.cuda() + src_tensor = src_tensor.cuda() + src_indices = src_indices.cuda().long() + dst_indices = dst_indices.cuda().long() + local_masked_scatter_gather( + src_tensor, + src_indices, + dst_indices, + out_tensor, + bs, + num_src_rows, + num_features, + num_out_rows, + ) + assert torch.allclose( + out_tensor.cpu(), out_tensor_gt + ), "Optimized local scatter-gather failed" + + +def test_optimized_scatter_add_gather(): + try: + from torch_local import local_masked_scatter_add_gather + except ImportError as e: + pytest.fail(f"Failed to import local_masked_scatter_add_gather: {e}") + + num_src_rows = 8 + num_out_rows = 8 + bs = 1 + num_features = 4 + src_tensor = torch.randn(bs, num_src_rows, num_features) + src_indices = torch.tensor([0, 3, 2, 1, 3]) + dst_indices = torch.tensor([1, 3, 5, 7, 3]) + + out_tensor_gt = torch.zeros(bs, num_out_rows, num_features) + + for i in range(bs): + for j in range(len(src_indices)): + out_tensor_gt[i, dst_indices[j]] += src_tensor[i, src_indices[j]] + + out_tensor_gt = out_tensor_gt.view(bs, num_out_rows, num_features) + out_tensor = torch.zeros_like(out_tensor_gt) + out_tensor = out_tensor.cuda() + src_tensor = src_tensor.cuda() + src_indices = src_indices.cuda().long() + dst_indices = dst_indices.cuda().long() + local_masked_scatter_add_gather( + src_tensor, + src_indices, + dst_indices, + out_tensor, + bs, + num_src_rows, + num_features, + num_out_rows, + ) + assert torch.allclose( + out_tensor.cpu(), out_tensor_gt + ), "Optimized local scatter-add-gather failed"