Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
4250814
Classification
GernotMaier Dec 26, 2025
d390191
no energies in classification
GernotMaier Dec 26, 2025
78a8107
write signal/background efficiency
GernotMaier Dec 26, 2025
69ac608
apply classification
GernotMaier Dec 27, 2025
3652bff
cleanup
GernotMaier Dec 27, 2025
a74f38b
unit tests
GernotMaier Dec 28, 2025
e75735b
zenith bins
GernotMaier Dec 28, 2025
ab4c762
cleanup
GernotMaier Dec 29, 2025
b50b931
simplification
GernotMaier Dec 29, 2025
1c964b9
write more to joblib file
GernotMaier Dec 29, 2025
5bd6109
unification
GernotMaier Dec 29, 2025
165c9de
simplification
GernotMaier Dec 29, 2025
034caed
using native XGB
GernotMaier Dec 30, 2025
5e3d08b
remove size from training
GernotMaier Dec 30, 2025
b06b2a4
config
GernotMaier Dec 30, 2025
1980c41
consistent naming
GernotMaier Dec 30, 2025
bf10257
ignore docstrings in tests
GernotMaier Dec 30, 2025
1b9d66c
tests
GernotMaier Dec 30, 2025
7a42e37
unit tests
GernotMaier Dec 30, 2025
4695f2f
apply cuts
GernotMaier Dec 31, 2025
45871e4
remove tests
GernotMaier Dec 31, 2025
4e8a760
log message
GernotMaier Dec 31, 2025
72bc845
Merge pull request #15 from Eventdisplay/xgboost-multi-output
GernotMaier Dec 31, 2025
c964219
simplified configuration
GernotMaier Dec 31, 2025
9df8c52
notable simplifications
GernotMaier Jan 1, 2026
7b75433
config module
GernotMaier Jan 1, 2026
91dc97f
cleanup
GernotMaier Jan 1, 2026
a7499ce
Update src/eventdisplay_ml/evaluate.py
GernotMaier Jan 1, 2026
daa5027
Update src/eventdisplay_ml/features.py
GernotMaier Jan 1, 2026
18d3a7e
Update src/eventdisplay_ml/utils.py
GernotMaier Jan 1, 2026
8d07e69
Update src/eventdisplay_ml/data_processing.py
GernotMaier Jan 1, 2026
97a8912
Update src/eventdisplay_ml/features.py
GernotMaier Jan 1, 2026
c70dc53
Update src/eventdisplay_ml/utils.py
GernotMaier Jan 1, 2026
a88994c
disable unit tests
GernotMaier Jan 1, 2026
77561b8
pre-commit
GernotMaier Jan 1, 2026
9667d1d
cleanup
GernotMaier Jan 1, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ jobs:

unit_tests:
runs-on: ubuntu-latest
if: false
strategy:
matrix:
python-version: ["3.13"]
Expand Down
2 changes: 2 additions & 0 deletions docs/changes/13.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
- add classification routines for gamma/hadron separation.
- add pre-training quality cuts.
3 changes: 3 additions & 0 deletions docs/changes/13.maintenance.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
- refactoring code to minimize duplication and improve maintainability.
- unified command line interface for all scripts.
- unit tests are disabled for now due to rapid changes in the codebase.
5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,9 @@ optional-dependencies."tests" = [
urls."bug tracker" = "https://github.com/Eventdisplay/Eventdisplay-ML/issues"
urls."documentation" = "https://github.com/Eventdisplay/Eventdisplay-ML"
urls."repository" = "https://github.com/Eventdisplay/Eventdisplay-ML"
scripts.eventdisplay-ml-apply-xgb-classify = "eventdisplay_ml.scripts.apply_xgb_classify:main"
scripts.eventdisplay-ml-apply-xgb-stereo = "eventdisplay_ml.scripts.apply_xgb_stereo:main"
scripts.eventdisplay-ml-train-xgb-classify = "eventdisplay_ml.scripts.train_xgb_classify:main"
scripts.eventdisplay-ml-train-xgb-stereo = "eventdisplay_ml.scripts.train_xgb_stereo:main"

[tool.setuptools]
Expand Down Expand Up @@ -118,6 +120,9 @@ lint.ignore = [

lint.pydocstyle.convention = "numpy"

[tool.ruff.lint.per-file-ignores]
"tests/**.py" = ["D103"]

[tool.codespell]
ignore-words-list = "chec,arrang,livetime"

Expand Down
174 changes: 174 additions & 0 deletions src/eventdisplay_ml/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
"""Configuration for XGBoost model training."""

import argparse
import logging

import numpy as np

from eventdisplay_ml import utils
from eventdisplay_ml.features import target_features
from eventdisplay_ml.hyper_parameters import (
hyper_parameters,
pre_cuts_classification,
pre_cuts_regression,
)
from eventdisplay_ml.models import load_models

_logger = logging.getLogger(__name__)


def configure_training(analysis_type):
"""Configure model training based on command-line arguments."""
parser = argparse.ArgumentParser(description=(f"Train XGBoost models for {analysis_type}."))

if analysis_type == "stereo_analysis":
parser.add_argument(
"--input_file_list", help=f"List of input mscw files for {analysis_type}."
)
if analysis_type == "classification":
parser.add_argument("--input_signal_file_list", help="List of input signal mscw files.")
parser.add_argument(
"--input_background_file_list", help="List of input background mscw files."
)

parser.add_argument(
"--model_prefix",
required=True,
help=("Path to directory for writing XGBoost models (without n_tel / energy bin suffix)."),
)
parser.add_argument(
"--hyperparameter_config",
help="Path to JSON file with hyperparameter configuration.",
default=None,
type=str,
)
parser.add_argument("--n_tel", type=int, help="Telescope multiplicity (2, 3, or 4).")
parser.add_argument(
"--train_test_fraction",
type=float,
help="Fraction of data for training (e.g., 0.5).",
default=0.5,
)
parser.add_argument(
"--max_events",
type=int,
help="Maximum number of events to process across all files.",
)
parser.add_argument(
"--random_state",
type=int,
help="Random state for train/test split.",
default=None,
)

if analysis_type == "classification":
parser.add_argument(
"--model_parameters",
type=str,
help=("Path to model parameter file (JSON) defining energy and zenith bins."),
)
parser.add_argument(
"--energy_bin_number",
type=int,
help="Energy bin number for selection (optional).",
default=0,
)

model_configs = vars(parser.parse_args())

_logger.info(f"--- XGBoost {analysis_type} training ---")
_logger.info(f"Telescope multiplicity: {model_configs.get('n_tel')}")
_logger.info(f"Model output prefix: {model_configs.get('model_prefix')}")
_logger.info(f"Train vs test fraction: {model_configs['train_test_fraction']}")
_logger.info(f"Max events: {model_configs['max_events']}")
if analysis_type == "classification":
_logger.info(f"Energy bin {model_configs['energy_bin_number']}")

model_configs["models"] = hyper_parameters(
analysis_type, model_configs.get("hyperparameter_config")
)
model_configs["targets"] = target_features(analysis_type)

if analysis_type == "stereo_analysis":
model_configs["pre_cuts"] = pre_cuts_regression(model_configs.get("n_tel"))
elif analysis_type == "classification":
model_parameters = utils.load_model_parameters(
model_configs["model_parameters"], model_configs["energy_bin_number"]
)
model_configs["pre_cuts"] = pre_cuts_classification(
model_configs.get("n_tel"),
e_min=np.power(10.0, model_parameters.get("energy_bins_log10_tev", []).get("E_min")),
e_max=np.power(10.0, model_parameters.get("energy_bins_log10_tev", []).get("E_max")),
)
model_configs["energy_bins_log10_tev"] = model_parameters.get("energy_bins_log10_tev", [])
model_configs["zenith_bins_deg"] = model_parameters.get("zenith_bins_deg", [])

return model_configs


def configure_apply(analysis_type):
"""Configure model application based on command-line arguments."""
parser = argparse.ArgumentParser(description=(f"Apply XGBoost models for {analysis_type}."))

parser.add_argument(
"--input_file",
required=True,
metavar="INPUT.root",
help="Path to input mscw file",
)
parser.add_argument(
"--model_prefix",
required=True,
metavar="MODEL_PREFIX",
help=("Path to directory containing XGBoost models (without n_tel / energy bin suffix)."),
)
parser.add_argument(
"--model_name",
default="xgboost",
help="Model name to load (default: xgboost)",
)
parser.add_argument(
"--output_file",
required=True,
metavar="OUTPUT.root",
help="Output file path for predictions",
)
parser.add_argument(
"--image_selection",
type=str,
default="15",
help=(
"Optional telescope selection. Can be bit-coded (e.g., 14 for telescopes 1,2,3) "
"or comma-separated indices (e.g., '1,2,3'). "
"Keeps events with all selected telescopes or 4-telescope events. "
"Default is 15, which selects all 4 telescopes."
),
)
parser.add_argument(
"--max_events",
type=int,
default=None,
help="Maximum number of events to process (default: all events)",
)
parser.add_argument(
"--chunk_size",
type=int,
default=500000,
help="Number of events to process per chunk (default: 500000)",
)

model_configs = vars(parser.parse_args())

_logger.info(f"--- XGBoost {analysis_type} evaluation ---")
_logger.info(f"Input file: {model_configs.get('input_file')}")
_logger.info(f"Model prefix: {model_configs.get('model_prefix')}")
_logger.info(f"Output file: {model_configs.get('output_file')}")
_logger.info(f"Image selection: {model_configs.get('image_selection')}")

model_configs["models"], par = load_models(
analysis_type, model_configs["model_prefix"], model_configs["model_name"]
)
model_configs["energy_bins_log10_tev"] = par.get("energy_bins_log10_tev", [])
model_configs["zenith_bins_deg"] = par.get("zenith_bins_deg", [])

return model_configs
Loading