9 changes: 8 additions & 1 deletion mlpf/model/inference.py
@@ -32,8 +32,13 @@
from .logger import _logger
from .utils import unpack_predictions, unpack_target

# import habana if available
try:
import habana_frameworks.torch.core as htcore
except ImportError:
pass

def predict_one_batch(conv_type, model, i, batch, rank, jetdef, jet_ptcut, jet_match_dr, outpath, dir_name, sample):
def predict_one_batch(conv_type, model, i, batch, rank, jetdef, jet_ptcut, jet_match_dr, outpath, dir_name, sample, habana=False):

# skip prediction if output exists
outfile = f"{outpath}/preds{dir_name}/{sample}/pred_{rank}_{i}.parquet"
@@ -43,6 +48,8 @@ def predict_one_batch(conv_type, model, i, batch, rank, jetdef, jet_ptcut, jet_m
# run model on batch
batch = batch.to(rank)
ypred = model(batch.X, batch.mask)
if habana:
htcore.mark_step()

# convert all outputs to float32 in case running in float16 or bfloat16
ypred = tuple([y.to(torch.float32) for y in ypred])
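Note on the change above: in Habana's lazy execution mode, operations are queued into a graph and only dispatched to the device when htcore.mark_step() is called, which is why the call sits immediately after the forward pass. A minimal sketch of the same pattern, assuming a model with the (X, mask) interface used here and a batch object that supports .to():

import torch

try:
    import habana_frameworks.torch.core as htcore
except ImportError:
    htcore = None

def forward_once(model, batch, habana=False):
    # move inputs to the HPU when requested, otherwise leave them where they are
    if habana:
        batch = batch.to("hpu")
    ypred = model(batch.X, batch.mask)
    if habana and htcore is not None:
        htcore.mark_step()  # flush the lazily accumulated graph to the device
    # cast outputs to float32 so downstream code sees a single dtype,
    # mirroring the conversion done in predict_one_batch above
    return tuple(y.to(torch.float32) for y in ypred)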
59 changes: 40 additions & 19 deletions mlpf/model/training.py
@@ -41,6 +41,12 @@
from mlpf.model.plots import validation_plots
from mlpf.optimizers.lamb import Lamb

# import habana if available
try:
import habana_frameworks.torch.core as htcore
import habana_frameworks.torch.hpu as torch_device
except ImportError:
import torch.cuda as torch_device

def configure_model_trainable(model: MLPF, trainable: Union[str, List[str]], is_training: bool):
"""Set only the given layers as trainable in the model"""
@@ -66,22 +72,26 @@ def configure_model_trainable(model: MLPF, trainable: Union[str, List[str]], is_
model.eval()


def model_step(batch, model, loss_fn):
def model_step(batch, model, loss_fn, habana=False):
ypred_raw = model(batch.X, batch.mask)
if habana:
htcore.mark_step()
ypred = unpack_predictions(ypred_raw)
ytarget = unpack_target(batch.ytarget, model)
loss_opt, losses_detached = loss_fn(ytarget, ypred, batch)
return loss_opt, losses_detached, ypred_raw, ypred, ytarget


def optimizer_step(model, loss_opt, optimizer, lr_schedule, scaler):
def optimizer_step(model, loss_opt, optimizer, lr_schedule, scaler, habana=False):
# Clear gradients
for param in model.parameters():
param.grad = None

# Backward pass and optimization
scaler.scale(loss_opt).backward()
if habana: htcore.mark_step()
scaler.step(optimizer)
if habana: htcore.mark_step()
scaler.update()
if lr_schedule:
# ReduceLROnPlateau scheduler should only be updated after each full epoch
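The two mark_step() calls added to optimizer_step follow the usual lazy-mode recipe for Gaudi: mark once after the backward pass and once after the optimizer update, so each graph segment executes before the next one is built. A condensed sketch of one such step, assuming a GradScaler-style scaler and a simplified loss call (the real code goes through model_step and mlpf_loss):

def training_step(model, batch, loss_fn, optimizer, scaler, habana=False):
    if habana:
        import habana_frameworks.torch.core as htcore
    # clear gradients without allocating zero tensors
    for param in model.parameters():
        param.grad = None
    loss = loss_fn(model(batch.X, batch.mask), batch)
    scaler.scale(loss).backward()
    if habana:
        htcore.mark_step()  # run the accumulated forward/backward graph
    scaler.step(optimizer)
    if habana:
        htcore.mark_step()  # run the optimizer update
    scaler.update()
    return loss.detach()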
@@ -133,6 +143,7 @@ def train_epoch(
device_type="cuda",
dtype=torch.float32,
scaler=None,
habana=False,
):
"""Run one training epoch

@@ -164,7 +175,7 @@
iterator = tqdm.tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch} train loop on rank={rank}")

for itrain, batch in iterator:
batch = batch.to(rank, non_blocking=True)
batch = batch.to("hpu" if habana else rank, non_blocking=True)

with torch.autocast(device_type=device_type, dtype=dtype, enabled=device_type == "cuda"):
loss_opt, loss, _, _, _ = model_step(batch, model, mlpf_loss)
@@ -270,7 +281,7 @@ def eval_epoch(
)

# Save validation plots for first batch
if (rank == 0 or rank == "cpu") and ival == 0:
if (rank == 0 or rank == "cpu" or rank == "hpu") and ival == 0:
validation_plots(batch, ypred_raw, ytarget, ypred, tensorboard_writer, epoch, outdir)

# Accumulate losses
@@ -328,6 +339,7 @@ def train_all_epochs(
comet_step_freq=None,
val_freq=None,
checkpoint_dir: str = "",
habana=False,
):
"""Main training loop that handles all epochs and validation

@@ -369,7 +381,12 @@
tensorboard_writer_train = None
tensorboard_writer_valid = None

device_type = "cuda" if isinstance(rank, int) else "cpu"
if habana:
device_type = "hpu"
elif isinstance(rank, int):
device_type = "cuda"
else:
device_type = "cpu"
t0_initial = time.time()

# Early stopping setup
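The new branch above simply widens the rank-based device_type guess with an explicit HPU case; factored out, it would look roughly like the helper below (the function name is illustrative, not part of the PR):

from typing import Union

def resolve_device_type(rank: Union[int, str], habana: bool = False) -> str:
    if habana:
        return "hpu"
    if isinstance(rank, int):  # DDP ranks are integers and map onto CUDA devices
        return "cuda"
    return "cpu"               # the single-process fallback uses rank == "cpu"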
@@ -397,6 +414,7 @@
device_type=device_type,
dtype=dtype,
scaler=scaler,
habana=habana,
)
train_time = time.time() - epoch_start_time

@@ -564,7 +582,7 @@ def train_all_epochs(
tensorboard_writer_valid.close()


def run_test(rank, world_size, config, outdir, model, sample, testdir_name, dtype):
def run_test(rank, world_size, config, outdir, model, sample, testdir_name, dtype, habana=False):
batch_size = config["gpu_batch_multiplier"]
version = config["test_dataset"][sample]["version"]

@@ -606,7 +624,8 @@ def run_test(rank, world_size, config, outdir, model, sample, testdir_name, dtyp
os.system(f"mkdir -p {outdir}/preds{testdir_name}/{sample}")

_logger.info(f"Running predictions on {sample}")
torch.cuda.empty_cache()
if not habana:
torch.cuda.empty_cache()

# FIXME: import this from a central place
if config["dataset"] == "clic":
@@ -640,7 +659,7 @@ def run_test(rank, world_size, config, outdir, model, sample, testdir_name, dtyp
dist.barrier() # block until all workers finished executing run_predictions()


def run(rank, world_size, config, outdir, logfile):
def run(rank, world_size, config, outdir, logfile, habana=False):
if (rank == 0) or (rank == "cpu"): # keep writing the logs
_configLogger("mlpf", filename=logfile)

@@ -674,7 +693,7 @@ def run(rank, world_size, config, outdir, logfile):

# load a pre-trained checkpoint (continue an aborted training or fine-tune)
if config["load"]:
model = MLPF(**model_kwargs).to(torch.device(rank))
model = MLPF(**model_kwargs).to(torch.device("hpu" if habana else rank))
optimizer = get_optimizer(model, config)

checkpoint = torch.load(config["load"], map_location=torch.device(rank))
@@ -720,7 +739,7 @@ def run(rank, world_size, config, outdir, logfile):
model = MLPF(**model_kwargs)
optimizer = get_optimizer(model, config)

model.to(rank)
model.to("hpu" if habana else rank)
model.compile()
configure_model_trainable(model, config["model"]["trainable"], True)

@@ -800,6 +819,7 @@ def run(rank, world_size, config, outdir, logfile):
comet_step_freq=config["comet_step_freq"],
val_freq=config["val_freq"],
checkpoint_dir=str(checkpoint_dir),
habana=habana,
)

checkpoint = torch.load(f"{checkpoint_dir}/best_weights.pth", map_location=torch.device(rank))
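For reference, reloading the best checkpoint onto the active device follows the standard torch.load/map_location pattern; a rough sketch, where the "model_state_dict" key and the helper name are assumptions rather than this repo's exact layout:

import torch

def load_best_weights(model, checkpoint_dir, rank, habana=False):
    device = torch.device("hpu" if habana else rank)
    checkpoint = torch.load(f"{checkpoint_dir}/best_weights.pth", map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    return model.to(device)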
@@ -812,7 +832,7 @@

if config["test"]:
for sample in config["enabled_test_datasets"]:
run_test(rank, world_size, config, outdir, model, sample, testdir_name, dtype)
run_test(rank, world_size, config, outdir, model, sample, testdir_name, dtype, habana)

# make plots only on rank 0
if (rank == 0) or (rank == "cpu"):
@@ -862,7 +882,7 @@ def override_config(config: dict, args):


# Run either single GPU or single-node multi-GPU using pytorch DDP
def device_agnostic_run(config, world_size, outdir):
def device_agnostic_run(config, world_size, outdir, habana=False):
if config["train"]:
logfile = f"{outdir}/train.log"
else:
@@ -871,25 +891,26 @@ def device_agnostic_run(config, world_size, outdir):

if config["gpus"]:
assert (
world_size <= torch.cuda.device_count()
), f"--gpus is too high (specified {world_size} gpus but only {torch.cuda.device_count()} gpus are available)"
world_size <= torch_device.device_count()
), f"--gpus is too high (specified {world_size} gpus but only {torch_device.device_count()} gpus are available)"

torch.cuda.empty_cache()
if not habana:
torch.cuda.empty_cache()
if world_size > 1:
_logger.info(f"Will use torch.nn.parallel.DistributedDataParallel() and {world_size} gpus", color="purple")
for rank in range(world_size):
_logger.info(torch.cuda.get_device_name(rank), color="purple")
_logger.info(torch_device.get_device_name(rank), color="purple")

mp.spawn(
run,
args=(world_size, config, outdir, logfile),
args=(world_size, config, outdir, logfile, habana),
nprocs=world_size,
join=True,
)
elif world_size == 1:
rank = 0
_logger.info(f"Will use single-gpu: {torch.cuda.get_device_name(rank)}", color="purple")
run(rank, world_size, config, outdir, logfile)
_logger.info(f"Will use single-gpu: {torch_device.get_device_name(rank)}", color="purple")
run(rank, world_size, config, outdir, logfile, habana)

else:
rank = "cpu"
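One detail worth noting about the mp.spawn call above: torch.multiprocessing.spawn invokes the target as fn(i, *args) with i the process index, so run() receives its rank implicitly and the tuple (world_size, config, outdir, logfile, habana) fills the remaining parameters. A minimal standalone illustration, with placeholder values:

import torch.multiprocessing as mp

def run(rank, world_size, config, outdir, logfile, habana=False):
    print(f"rank={rank} world_size={world_size} habana={habana}")

if __name__ == "__main__":
    world_size = 2
    mp.spawn(
        run,
        args=(world_size, {"train": True}, "./out", "train.log", True),  # rank is prepended by spawn
        nprocs=world_size,
        join=True,
    )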
17 changes: 13 additions & 4 deletions mlpf/pipeline.py
@@ -25,6 +25,12 @@
from mlpf.model.PFDataset import SHARING_STRATEGY
from utils import create_experiment_dir

# import habana if available
try:
import habana_frameworks.torch.core as htcore
except ImportError:
pass


def get_parser():
"""Create and return the ArgumentParser object."""
@@ -114,6 +120,9 @@ def get_parser():
parser_hpo.add_argument("--raytune-num-samples", type=int, help="Number of samples to draw from the search space")
parser_hpo.add_argument("--comet", action="store_true", help="Use comet.ml logging")

# option for habana training
parser_train.add_argument("--habana", action="store_true", default=None, help="use Habana Gaudi device")
parser_test.add_argument("--habana", action="store_true", default=None, help="use Habana Gaudi device")
return parser
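The flag wiring is plain argparse: a store_true option registered on both the train and test sub-parsers, so args.habana is True when passed and otherwise keeps the default (None here, presumably so override_config can tell "not given" apart from "disabled"). A self-contained sketch of that behaviour:

import argparse

parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers(dest="command")
parser_train = subparsers.add_parser("train")
parser_test = subparsers.add_parser("test")
for p in (parser_train, parser_test):
    p.add_argument("--habana", action="store_true", default=None, help="use Habana Gaudi device")

args = parser.parse_args(["train", "--habana"])
print(args.habana)  # True; it stays None when the flag is omitted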


@@ -173,10 +182,10 @@ def main():
"samples": {"cms_pf_ttbar": config[ds]["cms"]["physical_pu"]["samples"]["cms_pf_ttbar"]},
}
}
# load only the last config split
config[ds]["cms"]["physical_pu"]["samples"]["cms_pf_ttbar"]["splits"] = ["10"]
# load only the first config split
config[ds]["cms"]["physical_pu"]["samples"]["cms_pf_ttbar"]["splits"] = ["1"]
config["test_dataset"] = {"cms_pf_ttbar": config["test_dataset"]["cms_pf_ttbar"]}
config["test_dataset"]["cms_pf_ttbar"]["splits"] = ["10"]
config["test_dataset"]["cms_pf_ttbar"]["splits"] = ["1"]

# override loaded config with values from command line args
config = override_config(config, args)
@@ -201,7 +210,7 @@ def main():
run_ray_training(config, args, experiment_dir)
elif args.command in ["train", "test"]:
world_size = args.gpus if args.gpus > 0 else 1
device_agnostic_run(config, world_size, experiment_dir)
device_agnostic_run(config, world_size, experiment_dir, args.habana)


if __name__ == "__main__":
62 changes: 4 additions & 58 deletions parameters/pytorch/pyg-cms.yaml
@@ -9,7 +9,7 @@ gpus: 1
gpu_batch_multiplier: 1
load:
finetune:
num_epochs: 10
num_epochs: 2
patience: 20
lr: 0.0002
lr_schedule: cosinedecay # constant, cosinedecay, onecycle
@@ -102,25 +102,7 @@ train_dataset:
samples:
cms_pf_ttbar:
version: 2.7.1
splits: [1,2,3,4,5,6,7,8,9,10]
cms_pf_qcd:
version: 2.7.1
splits: [1,2,3,4,5,6,7,8,9,10]
cms_pf_ztt:
version: 2.7.1
splits: [1,2,3,4,5,6,7,8,9,10]
physical_nopu:
batch_size: 16
samples:
cms_pf_ttbar_nopu:
version: 2.7.1
splits: [1,2,3,4,5,6,7,8,9,10]
cms_pf_qcd_nopu:
version: 2.7.1
splits: [1,2,3,4,5,6,7,8,9,10]
cms_pf_ztt_nopu:
version: 2.7.1
splits: [1,2,3,4,5,6,7,8,9,10]
splits: [1]

valid_dataset:
cms:
@@ -129,45 +111,9 @@ valid_dataset:
samples:
cms_pf_ttbar:
version: 2.7.1
splits: [1,2,3,4,5,6,7,8,9,10]
cms_pf_qcd:
version: 2.7.1
splits: [1,2,3,4,5,6,7,8,9,10]
cms_pf_ztt:
version: 2.7.1
splits: [1,2,3,4,5,6,7,8,9,10]
physical_nopu:
batch_size: 16
samples:
cms_pf_ttbar_nopu:
version: 2.7.1
splits: [1,2,3,4,5,6,7,8,9,10]
cms_pf_qcd_nopu:
version: 2.7.1
splits: [1,2,3,4,5,6,7,8,9,10]
cms_pf_ztt_nopu:
version: 2.7.1
splits: [1,2,3,4,5,6,7,8,9,10]
splits: [1]

test_dataset:
cms_pf_ttbar:
version: 2.7.1
splits: [1,2,3,4,5,6,7,8,9,10]
cms_pf_qcd:
version: 2.7.1
splits: [1,2,3,4,5,6,7,8,9,10]
# cms_pf_qcd13p6:
# version: 2.7.1
# splits: [1,2,3,4,5,6,7,8,9,10]
cms_pf_ztt:
version: 2.7.1
splits: [1,2,3,4,5,6,7,8,9,10]
cms_pf_ttbar_nopu:
version: 2.7.1
splits: [1,2,3,4,5,6,7,8,9,10]
cms_pf_qcd_nopu:
version: 2.7.1
splits: [1,2,3,4,5,6,7,8,9,10]
cms_pf_ztt_nopu:
version: 2.7.1
splits: [1,2,3,4,5,6,7,8,9,10]
splits: [1]
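The YAML trim above reduces the run to a quick smoke test: two epochs and only split 1 of cms_pf_ttbar for train, valid and test. A small sketch of how the trimmed fields could be checked after loading, assuming PyYAML and the nesting visible in the file and in pipeline.py:

import yaml

with open("parameters/pytorch/pyg-cms.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg["num_epochs"])  # 2 after this change
print(cfg["train_dataset"]["cms"]["physical_pu"]["samples"]["cms_pf_ttbar"]["splits"])  # [1]
print(cfg["test_dataset"]["cms_pf_ttbar"]["splits"])  # [1]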
11 changes: 0 additions & 11 deletions requirements.txt
@@ -6,18 +6,14 @@ boost_histogram
click
comet-ml
fastjet
fsspec
jupyter
jupyter-book
keras
keras-tuner
matplotlib
mlcroissant
mplhep
networkx
notebook
numba
numpy
onnx
onnxruntime
pandas
@@ -32,14 +28,7 @@ scikit-optimize
scipy
seaborn
setGPU
tensorflow
tensorflow-datasets
tf2onnx
torch
torch_runstats
torchaudio
torchvision
tqdm
uproot
vector
zenodo_get