
/dev/shm fills up completely and causes bus error #816

@adeepak7

Description


Hello to the repository maintainers!

I hope you are doing well.

I am writing to ask for your help with the following problem:

When I try to train the Graph Neural Network on the full dataset (3.3 TB), my /dev/shm (currently sized at 1.9 TB) fills up completely and the program exits with a bus error (SIGBUS).

Is there any way around this? I have already tried the following optimisations:

  1. Increased /dev/shm size

    1.1. Remounted /dev/shm with a larger size (mount -o remount,size=2T /dev/shm).
    1.2. Ensured the container also had enough shared memory with --shm-size.

  2. Checked actual RAM usage

    2.1. Verified that /dev/shm is a tmpfs backed by RAM, so files written there consume RAM and are only swapped out under memory pressure.

  3. Tried redirecting temporary storage
    3.1. Set TMPDIR=/workspace/repository/tmp_ipc to offload from /dev/shm.
    3.2. But GraphLearn / PyTorch NCCL still defaults to /dev/shm.

  4. Freed up stuck shared memory segments
    4.1. Used rm -f /dev/shm/torch_* /dev/shm/nccl* /dev/shm/pymp-*.
    4.2. Checked which processes were holding files in /dev/shm with fuser and /proc/<pid>/fd (a small watcher sketch follows this list).

  5. PyTorch / NCCL / GraphLearn Optimisations
    5.1. Disabled NCCL shared memory and Infiniband (for single GPU case):
    NCCL_SHM_DISABLE=1 NCCL_IB_DISABLE=1

  6. CPU threading optimisations
    6.1. Limited threads to avoid oversubscription:
    OMP_NUM_THREADS=1 MKL_NUM_THREADS=1

  7. Checked GPU visibility
    7.1. Used CUDA_VISIBLE_DEVICES=0 for single GPU training.
    7.2. Verified NVML inside/outside container.

  8. Probed NeighborLoader before full training
    8.1. Added debugging to ensure dataset builds and first batch loads.

  9. Adjusted fan_out
    9.1. Controlled neighborhood expansion (--fan_out 10,5,3 → smaller values reduce memory / CPU pressure).

  10. World size & rank simplification
    10.1. For single GPU: set world_size=1, rank=0 to avoid multiprocessing overhead.

  11. Avoided deadlocks during multiprocessing
    11.1. Tried running training loop with single process instead of PyTorch DDP when only one GPU is used.

  12. Preserved environment by committing container
    12.1. Saved container state to avoid losing installed tools / libraries.

  13. Restarted container cleanly
    13.1. Used docker start -ai instead of re-creating from scratch.

  14. Ensured GPU access inside container
    14.1. Used the --gpus flag when starting the Docker container.
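
To make that growth visible, something like the watcher below can be run alongside training (a minimal sketch, assuming a Linux host where /dev/shm is mounted as tmpfs; the file name shm_watch.py is only illustrative):

# shm_watch.py -- print /dev/shm usage and its largest files every 30 s,
# to see whether torch_*/nccl*/pymp-* segments or other buffers are growing.
import os, time

def shm_usage(path="/dev/shm"):
    st = os.statvfs(path)
    total = st.f_blocks * st.f_frsize
    free = st.f_bavail * st.f_frsize          # space available to unprivileged users
    return total - free, total

def largest_files(path="/dev/shm", n=5):
    entries = []
    for e in os.scandir(path):
        try:
            if e.is_file(follow_symlinks=False):
                entries.append((e.stat().st_size, e.name))
        except OSError:
            pass                               # file may vanish between scandir and stat
    return sorted(entries, reverse=True)[:n]

if __name__ == "__main__":
    while True:
        used, total = shm_usage()
        print(f"/dev/shm: {used / 2**30:.1f} / {total / 2**30:.1f} GiB used")
        for size, name in largest_files():
            print(f"  {size / 2**30:6.2f} GiB  {name}")
        time.sleep(30)

Watching this during the first few hundred steps shows whether the torch_* / nccl* segments from step 4 above are the ones growing, or whether the sampler itself is allocating the shared memory.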

Here is the code:

# train_stable.py
import argparse, datetime, os, time, numpy as np, os.path as osp, sklearn.metrics, tqdm, torch, warnings
# --- Safe IPC & tempdir (before any spawn/imports) ---
os.environ.setdefault("TMPDIR", "/workspace/repository/tmp_ipc")
import tempfile; tempfile.tempdir = os.environ["TMPDIR"]
import torch.multiprocessing as mp
try: mp.set_start_method("spawn", force=True)
except RuntimeError: pass
torch.multiprocessing.set_sharing_strategy("file_system")  # avoid /dev/shm
# ------------------------------------------------------

import torch.distributed as dist, graphlearn_torch as glt
import mlperf_logging.mllog.constants as mllog_constants
from torch.nn.parallel import DistributedDataParallel
from dataset import IGBHeteroDataset
from mlperf_logging_utils import get_mlperf_logger, submission_info
from utilities import create_ckpt_folder
from rgnn import RGNN

warnings.filterwarnings("ignore")
mllogger = get_mlperf_logger(path=osp.dirname(osp.abspath(__file__)))

def safe_barrier(ws:int):
    if ws > 1: dist.barrier()

def evaluate(model, dataloader, device, rank, world_size, epoch_num):
    if rank == 0:
        mllogger.start(key=mllog_constants.EVAL_START, metadata={mllog_constants.EPOCH_NUM: epoch_num})
    preds, labels = [], []
    with torch.no_grad():
        for batch in dataloader:
            bs = batch['paper'].batch_size
            out = model({k: v.to(device).to(torch.float32) for k, v in batch.x_dict.items()},
                        batch.edge_index_dict)[:bs]
            labels.append(batch['paper'].y[:bs].cpu().numpy())
            preds.append(out.argmax(1).cpu().numpy())
        preds = np.concatenate(preds); labels = np.concatenate(labels)
        acc = sklearn.metrics.accuracy_score(labels, preds)
        if torch.cuda.is_available(): torch.cuda.synchronize()
        safe_barrier(world_size)
        acc_t = torch.tensor(acc, device=device)
        if world_size > 1:
            dist.all_reduce(acc_t, op=dist.ReduceOp.SUM)
            global_acc = acc_t.item() / world_size
        else:
            global_acc = acc
        if rank == 0:
            mllogger.event(key=mllog_constants.EVAL_ACCURACY, value=global_acc,
                           metadata={mllog_constants.EPOCH_NUM: epoch_num})
            mllogger.end(key=mllog_constants.EVAL_STOP, metadata={mllog_constants.EPOCH_NUM: epoch_num})
        return acc, global_acc

def run_training_proc(rank, world_size,
    hidden_channels, num_classes, num_layers, model_type, num_heads, fan_out,
    epochs, train_batch_size, val_batch_size, lr, random_seed,
    dataset, train_idx, val_idx, with_gpu, validation_acc, validation_frac_within_epoch,
    evaluate_on_epoch_end, checkpoint_on_epoch_end, ckpt_steps, ckpt_path,
    use_gpu_sampler):

    if rank == 0 and ckpt_steps > 0:
        ckpt_dir = create_ckpt_folder(base_dir=osp.dirname(osp.abspath(__file__)))

    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '23456')
    if world_size > 1:
        dist.init_process_group('nccl', rank=rank, world_size=world_size)

    torch.cuda.set_device(rank)
    glt.utils.common.seed_everything(random_seed)
    device = torch.device(rank)

    print(f'Rank {rank} init graphlearn_torch NeighborLoader...')
    # Even split → equal steps when multi-GPU
    train_idx = torch.chunk(train_idx, world_size, dim=0)[rank]
    val_idx   = torch.chunk(val_idx,   world_size, dim=0)[rank]

    # Sampler/device policy: CPU sampler by default (safer), toggle to GPU with env
    loader_device = device if use_gpu_sampler else 'cpu'
    drop_last_train = world_size > 1

    train_loader = glt.loader.NeighborLoader(
        data=dataset,
        num_neighbors=[int(f) for f in fan_out.split(',')],
        input_nodes=('paper', train_idx),
        batch_size=train_batch_size,
        shuffle=True, drop_last=drop_last_train,
        device=loader_device, seed=random_seed, num_workers=0,
    )
    val_loader = glt.loader.NeighborLoader(
        data=dataset,
        num_neighbors=[int(f) for f in fan_out.split(',')],
        input_nodes=('paper', val_idx),
        batch_size=val_batch_size,
        shuffle=True, drop_last=False,
        device=loader_device, seed=random_seed, num_workers=0,
    )

    # Probe (does not consume the iterator used by the loop)
    print(f'Rank {rank} probing first train batch...')
    _ = next(iter(train_loader))
    print(f'Rank {rank} got first train batch: { _["paper"].batch_size } seeds')

    model = RGNN(dataset.get_edge_types(),
                 dataset.node_features['paper'].shape[1],
                 hidden_channels, num_classes,
                 num_layers=num_layers, dropout=0.2,
                 model=model_type, heads=num_heads,
                 node_type='paper').to(device)

    ckpt = None
    if ckpt_path:
        try:
            ckpt = torch.load(ckpt_path, map_location=device)
            model.load_state_dict(ckpt['model_state_dict'])
        except FileNotFoundError:
            print(f"[rank{rank}] ckpt not found: {ckpt_path}")

    if world_size > 1:
        model = DistributedDataParallel(model, device_ids=[device.index] if with_gpu else None,
                                        find_unused_parameters=True)

    ps = sum(p.nelement()*p.element_size() for p in model.parameters())
    bs = sum(b.nelement()*b.element_size() for b in model.buffers())
    print('model size: {:.3f}MB'.format((ps+bs)/1024**2))

    loss_fcn = torch.nn.CrossEntropyLoss().to(device)
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    if ckpt_path and ckpt is not None:
        try: opt.load_state_dict(ckpt['optimizer_state_dict'])  # noqa
        except Exception: pass

    batch_num = (len(train_idx)//train_batch_size) if drop_last_train else ((len(train_idx)+train_batch_size-1)//train_batch_size)
    validation_freq = max(1, int(batch_num * validation_frac_within_epoch))

    is_success = False; epoch_num = 0; t0 = time.time()
    for epoch in tqdm.tqdm(range(epochs)):
        if rank == 0:
            mllogger.start(key=mllog_constants.EPOCH_START, metadata={mllog_constants.EPOCH_NUM: epoch})
        model.train(); total_loss=0.0; train_acc=0.0; idx=0; gpu_mem=0.0; epoch_start=time.time()

        # --- Per-batch timing to reveal stalls ---
        last = time.time()

        for batch in train_loader:
            fetch_s = last; fetch_e = time.time()              # time between batches
            idx += 1
            bs = batch['paper'].batch_size
            # forward/backward
            fw_s = time.time()
            out = model({k: v.to(device).to(torch.float32) for k, v in batch.x_dict.items()},
                        batch.edge_index_dict)[:bs]
            y = batch['paper'].y[:bs]
            loss = loss_fcn(out, y); opt.zero_grad(); loss.backward(); opt.step()
            fw_e = time.time()

            total_loss += loss.item()
            train_acc += sklearn.metrics.accuracy_score(y.cpu().numpy(), out.argmax(1).detach().cpu().numpy())*100
            if torch.cuda.is_available(): gpu_mem += torch.cuda.max_memory_allocated()/1_000_000

            # Print first few batches timing (helps you see if fetch is slow)
            if idx <= 5 and rank == 0:
                print(f"[epoch{epoch} step{idx}] fetch {fetch_e-fetch_s:.2f}s | fw/bw {fw_e-fw_s:.2f}s | bs={bs}")

            last = time.time()

            if ckpt_steps > 0 and idx % ckpt_steps == 0:
                if torch.cuda.is_available(): torch.cuda.synchronize()
                safe_barrier(world_size)
                if rank == 0:
                    epoch_num = round((epoch + idx / max(1,batch_num)), 2)
                    glt.utils.common.save_ckpt(idx + epoch * batch_num,
                        create_ckpt_folder(osp.dirname(osp.abspath(__file__))),
                        model.module if world_size > 1 else model, opt, epoch_num)
                safe_barrier(world_size)

            if idx % validation_freq == 0:
                if torch.cuda.is_available(): torch.cuda.synchronize()
                safe_barrier(world_size)
                epoch_num = round((epoch + idx / max(1,batch_num)), 2)
                model.eval()
                _, global_acc = evaluate(model, val_loader, device, rank, world_size, epoch_num)
                if validation_acc is not None and global_acc >= validation_acc:
                    is_success = True; model.train(); break
                model.train()

        if torch.cuda.is_available(): torch.cuda.synchronize()
        safe_barrier(world_size)
        if rank == 0:
            mllogger.end(key=mllog_constants.EPOCH_STOP, metadata={mllog_constants.EPOCH_NUM: epoch})

        if checkpoint_on_epoch_end:
            if rank == 0:
                epoch_num = epoch + 1
                glt.utils.common.save_ckpt(idx + epoch * batch_num,
                    create_ckpt_folder(osp.dirname(osp.abspath(__file__))),
                    model.module if world_size > 1 else model, opt, epoch_num)
            safe_barrier(world_size)

        if evaluate_on_epoch_end and not is_success:
            epoch_num = epoch + 1
            model.eval()
            rank_val_acc, global_acc = evaluate(model, val_loader, device, rank, world_size, epoch_num)
            if validation_acc is not None and global_acc >= validation_acc: is_success = True
            train_acc /= max(1, idx); gpu_mem /= max(1, idx)
            tqdm.tqdm.write("Rank{:02d} | Epoch {:03d} | Loss {:.4f} | Train Acc {:.2f} | Val Acc {:.2f} | Time {} | GPU {:.1f} MB".format(
                rank, epoch, total_loss, train_acc, rank_val_acc*100,
                str(datetime.timedelta(seconds=int(time.time()-epoch_start))), gpu_mem))

        if is_success: break

    if rank == 0:
        status = mllog_constants.SUCCESS if is_success else mllog_constants.ABORTED
        mllogger.end(key=mllog_constants.RUN_STOP, metadata={mllog_constants.STATUS: status, mllog_constants.EPOCH_NUM: epoch_num})
    print("Total time taken", str(datetime.timedelta(seconds=int(time.time()-t0))))

if __name__ == '__main__':
    mllogger.event(key=mllog_constants.CACHE_CLEAR, value=True)
    mllogger.start(key=mllog_constants.INIT_START)

    parser = argparse.ArgumentParser()
    root = osp.join(osp.dirname(osp.dirname(osp.dirname(osp.realpath(__file__)))), 'data', 'igbh')
    glt.utils.ensure_dir(root)
    parser.add_argument('--path', type=str, default=root)
    parser.add_argument('--dataset_size', type=str, default='full', choices=['tiny','small','medium','large','full'])
    parser.add_argument('--num_classes', type=int, default=2983, choices=[19,2983])
    parser.add_argument('--in_memory', type=int, default=1, choices=[0,1])
    parser.add_argument('--model', type=str, default='rgat', choices=['rgat','rsage'])
    parser.add_argument('--fan_out', type=str, default='15,10,5')
    parser.add_argument('--train_batch_size', type=int, default=1024)
    parser.add_argument('--val_batch_size', type=int, default=1024)
    parser.add_argument('--hidden_channels', type=int, default=512)
    parser.add_argument('--learning_rate', type=float, default=0.001)
    parser.add_argument('--epochs', type=int, default=2)
    parser.add_argument('--num_layers', type=int, default=3)
    parser.add_argument('--num_heads', type=int, default=4)
    parser.add_argument('--random_seed', type=int, default=42)
    parser.add_argument("--cpu_mode", action="store_true")
    parser.add_argument("--edge_dir", type=str, default='in')
    parser.add_argument('--layout', type=str, default='COO', choices=['COO','CSC','CSR'])
    parser.add_argument("--pin_feature", action="store_true")
    parser.add_argument("--use_fp16", action="store_true")
    parser.add_argument("--validation_frac_within_epoch", type=float, default=0.05)
    parser.add_argument("--validation_acc", type=float, default=0.72)
    parser.add_argument("--evaluate_on_epoch_end", action="store_true")
    parser.add_argument("--checkpoint_on_epoch_end", action="store_true")
    parser.add_argument('--ckpt_steps', type=int, default=-1)
    parser.add_argument('--ckpt_path', type=str, default=None)
    args = parser.parse_args()

    args.with_gpu = (not args.cpu_mode) and torch.cuda.is_available()
    assert args.layout in ['COO','CSC','CSR']

    glt.utils.common.seed_everything(args.random_seed)
    world_size = torch.cuda.device_count()

    submission_info(mllogger, mllog_constants.GNN, 'reference_implementation')
    mllogger.event(key=mllog_constants.GLOBAL_BATCH_SIZE, value=world_size*args.train_batch_size)
    mllogger.event(key=mllog_constants.GRADIENT_ACCUMULATION_STEPS, value=1)
    mllogger.event(key=mllog_constants.OPT_NAME, value='adam')
    mllogger.event(key=mllog_constants.OPT_BASE_LR, value=args.learning_rate)
    mllogger.event(key=mllog_constants.SEED, value=args.random_seed)
    mllogger.end(key=mllog_constants.INIT_STOP)
    mllogger.start(key=mllog_constants.RUN_START)

    # Build dataset (original placement)
    igbh = IGBHeteroDataset(args.path, args.dataset_size, args.in_memory,
                            args.num_classes==2983, True, args.layout, args.use_fp16)
    ds = glt.data.Dataset(edge_dir=args.edge_dir)
    ds.init_node_features(node_feature_data=igbh.feat_dict, with_gpu=args.with_gpu and args.pin_feature)
    # Default to CPU graph mode (safer); toggle GPU ZERO_COPY via env
    use_gpu_sampler = os.getenv("USE_GPU_SAMPLER","0") == "1"
    graph_mode = 'ZERO_COPY' if (args.with_gpu and use_gpu_sampler) else 'CPU'
    ds.init_graph(edge_index=igbh.edge_dict, layout=args.layout, graph_mode=graph_mode)
    ds.init_node_labels(node_label_data={'paper': igbh.label})

    # Indices: RAM by default (avoid /dev/shm). Toggle POSIX SHM via env if needed.
    if os.getenv("USE_POSIX_SHM_INDICES","0") == "1":
        train_idx = igbh.train_idx.clone().share_memory_()
        val_idx   = igbh.val_idx.clone().share_memory_()
        print("Using POSIX SHM for indices (USE_POSIX_SHM_INDICES=1)")
    else:
        train_idx = igbh.train_idx.clone()
        val_idx   = igbh.val_idx.clone()
        print("Using regular RAM for indices")

    mllogger.event(key=mllog_constants.TRAIN_SAMPLES, value=train_idx.size(0))
    mllogger.event(key=mllog_constants.EVAL_SAMPLES,  value=val_idx.size(0))

    print('--- Launching training processes ...\n')
    if world_size == 1:
        run_training_proc(0,1, args.hidden_channels,args.num_classes,args.num_layers,args.model,args.num_heads,
                          args.fan_out,args.epochs,args.train_batch_size,args.val_batch_size,args.learning_rate,
                          args.random_seed, ds, train_idx, val_idx, args.with_gpu, args.validation_acc,
                          args.validation_frac_within_epoch, args.evaluate_on_epoch_end,
                          args.checkpoint_on_epoch_end, args.ckpt_steps, args.ckpt_path,
                          use_gpu_sampler)
    else:
        mp.spawn(run_training_proc,
                 args=(world_size,args.hidden_channels,args.num_classes,args.num_layers,args.model,args.num_heads,
                       args.fan_out,args.epochs,args.train_batch_size,args.val_batch_size,args.learning_rate,
                       args.random_seed, ds, train_idx, val_idx, args.with_gpu, args.validation_acc,
                       args.validation_frac_within_epoch, args.evaluate_on_epoch_end,
                       args.checkpoint_on_epoch_end, args.ckpt_steps, args.ckpt_path,
                       use_gpu_sampler),
                 nprocs=world_size, join=True)
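
As a quick sanity check before launching the full run, the IPC-related overrides from the top of the script can be verified in isolation (a minimal sketch; the file name is illustrative and the overrides simply mirror the script's header):

# check_ipc_settings.py -- illustrative standalone check that the TMPDIR and
# sharing-strategy overrides used in train_stable.py actually take effect.
import os, tempfile

os.environ.setdefault("TMPDIR", "/workspace/repository/tmp_ipc")
tempfile.tempdir = os.environ["TMPDIR"]

import torch.multiprocessing as mp
mp.set_sharing_strategy("file_system")

print("TMPDIR                 :", os.environ.get("TMPDIR"))
print("tempfile.gettempdir()  :", tempfile.gettempdir())
print("sharing strategy       :", mp.get_sharing_strategy())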

But this code gets stuck even in the first epoch when I run it on 1 node with 1 GPU, using the following command:

CUDA_VISIBLE_DEVICES=0 \
NCCL_SHM_DISABLE=1 \
NCCL_IB_DISABLE=1 \
OMP_NUM_THREADS=1 MKL_NUM_THREADS=1 \
python -u train_rgnn_multi_gpu.py \
  --in_memory 0 \
  --train_batch_size 16 \
  --val_batch_size 16 \
  --fan_out 2,2,2 \
  --model rgat \
  --dataset_size full \
  --layout CSC \
  --use_fp16 \
  --path /workspace/repository/graphlearn-for-pytorch/data/igbh/
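
When it hangs, it would also help to see exactly where each process is blocked. One way (a minimal sketch using the standard-library faulthandler module; this hook is not in the script above and is Unix-only) is to register a signal handler near the top of train_stable.py:

# Hypothetical debugging hook (not part of the script above): dump the Python
# stack of every thread to stderr whenever the process receives SIGUSR1.
import faulthandler, signal, sys

faulthandler.register(signal.SIGUSR1, file=sys.stderr, all_threads=True)

With that in place, running kill -USR1 <pid> from another shell prints the stack of the hung process, which should show whether it is stuck in the NeighborLoader fetch, in an NCCL call, or somewhere else.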

Any help is deeply appreciated.

Thanks
