Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 62 additions & 3 deletions ctlearn/tools/predict_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
vstack,
join,
setdiff,
unique,
)

from ctapipe.containers import (
Expand Down Expand Up @@ -112,7 +113,7 @@ class PredictCTLearnModel(Tool):
load_cameradirection_model_from : pathlib.Path
Path to a Keras model file (Keras3) or directory (Keras2) for the regression
of the primary particle arrival direction based on camera coordinate offsets.
load_cameradirection_model_from : pathlib.Path
load_skydirection_model_from : pathlib.Path
Path to a Keras model file (Keras3) or directory (Keras2) for the regression
of the primary particle arrival direction based on spherical coordinate offsets.
output_path : pathlib.Path
Expand Down Expand Up @@ -854,6 +855,28 @@ def _create_nan_table(self, nonexample_identifiers, columns, shapes):
)
return nan_table

def deduplicate_first_valid(
self,
table: Table,
keys=('obs_id', 'event_id'),
valid_col='CTLearn_is_valid',
):
"""
Return a deduplicated Astropy Table.

For each group defined by `keys`, keep the first row where
`valid_col` is True. If none are valid, keep the first row.
"""

t = table.copy()

t.sort(
list(keys) + [valid_col],
reverse=[False] * len(keys) + [True]
)

return unique(t, keys=list(keys), keep='first')

def _store_pointing(self, all_identifiers):
"""
Store the telescope pointing table from to the output file.
Expand Down Expand Up @@ -1253,6 +1276,12 @@ def start(self):
classification_subarray_table[f"{self.prefix}_telescopes"] = (
reco_telescopes
)
# Deduplicate the subarray classification table to have only one entry per event
classification_subarray_table = super().deduplicate_first_valid(
table=classification_subarray_table,
keys=SUBARRAY_EVENT_KEYS,
valid_col=f"{self.prefix}_is_valid",
)
# Sort the subarray classification table
classification_subarray_table.sort(SUBARRAY_EVENT_KEYS)
# Save the prediction to the output file
Expand Down Expand Up @@ -1381,6 +1410,12 @@ def start(self):
energy_subarray_table[f"{self.prefix}_telescopes"] = (
reco_telescopes
)
# Deduplicate the subarray classification table to have only one entry per event
energy_subarray_table = super().deduplicate_first_valid(
table=energy_subarray_table,
keys=SUBARRAY_EVENT_KEYS,
valid_col=f"{self.prefix}_is_valid",
)
# Sort the subarray energy table
energy_subarray_table.sort(SUBARRAY_EVENT_KEYS)
# Save the prediction to the output file
Expand Down Expand Up @@ -1537,6 +1572,12 @@ def start(self):
direction_subarray_table[f"{self.prefix}_telescopes"] = (
reco_telescopes
)
# Deduplicate the subarray classification table to have only one entry per event
direction_subarray_table = super().deduplicate_first_valid(
table=direction_subarray_table,
keys=SUBARRAY_EVENT_KEYS,
valid_col=f"{self.prefix}_is_valid",
)
# Sort the subarray geometry table
direction_subarray_table.sort(SUBARRAY_EVENT_KEYS)
# Save the prediction to the output file
Expand Down Expand Up @@ -1717,7 +1758,7 @@ def start(self):
self.log.info("Starting the prediction...")
classification_feature_vectors = None
if self.load_type_model_from is not None:
# Predict the energy of the primary particle
# Predict the classification of the primary particle
classification_table, classification_feature_vectors = (
super()._predict_classification(example_identifiers)
)
Expand All @@ -1730,7 +1771,7 @@ def start(self):
shapes=[(len(nonexample_identifiers),)],
)
classification_table = vstack([classification_table, nan_table])
# Add is_valid column to the energy table
# Add is_valid column to the classification table
classification_table.add_column(
~np.isnan(
classification_table[f"{self.prefix}_tel_prediction"].data,
Expand All @@ -1745,6 +1786,12 @@ def start(self):
classification_table.rename_column(
f"{self.prefix}_tel_is_valid", f"{self.prefix}_is_valid"
)
# Deduplicate the subarray classification table to have only one entry per event
classification_table = super().deduplicate_first_valid(
table=classification_table,
keys=SUBARRAY_EVENT_KEYS,
valid_col=f"{self.prefix}_is_valid",
)
classification_table.sort(SUBARRAY_EVENT_KEYS)
# Add the default values and meta data to the table
add_defaults_and_meta(
Expand Down Expand Up @@ -1793,6 +1840,12 @@ def start(self):
energy_table.rename_column(
f"{self.prefix}_tel_is_valid", f"{self.prefix}_is_valid"
)
# Deduplicate the subarray energy table to have only one entry per event
energy_table = super().deduplicate_first_valid(
table=energy_table,
keys=SUBARRAY_EVENT_KEYS,
valid_col=f"{self.prefix}_is_valid",
)
energy_table.sort(SUBARRAY_EVENT_KEYS)
# Add the default values and meta data to the table
add_defaults_and_meta(
Expand Down Expand Up @@ -1845,6 +1898,12 @@ def start(self):
~np.isnan(direction_table[f"{self.prefix}_alt"].data, dtype=bool),
name=f"{self.prefix}_is_valid",
)
# Deduplicate the subarray direction table to have only one entry per event
direction_table = super().deduplicate_first_valid(
table=direction_table,
keys=SUBARRAY_EVENT_KEYS,
valid_col=f"{self.prefix}_is_valid",
)
direction_table.sort(SUBARRAY_EVENT_KEYS)
# Add the default values and meta data to the table
add_defaults_and_meta(
Expand Down
Loading