From 040e0e0db07480baaa6427a30c394c8b9974223f Mon Sep 17 00:00:00 2001 From: SeqIO Team Date: Thu, 4 Sep 2025 14:13:56 -0700 Subject: [PATCH] Allows ("targets", "predictions", "inputs") positional arguments for metric_fns PiperOrigin-RevId: 803170566 --- seqio/dataset_providers.py | 10 ++++++---- seqio/dataset_providers_test.py | 7 ++++--- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/seqio/dataset_providers.py b/seqio/dataset_providers.py index e7a1a7ff..f0d6cf3a 100644 --- a/seqio/dataset_providers.py +++ b/seqio/dataset_providers.py @@ -1227,12 +1227,14 @@ def _all_metric_fns( score_fns.append(metric_fn) elif pos_args == ("targets", "predictions", "aux_values"): predict_with_aux_fns.append(metric_fn) + elif pos_args == ("targets", "predictions", "inputs"): + predict_fns.append(metric_fn) else: raise ValueError( - "Metric functions must have positional arguments matching either " - "('targets', 'scores'), ('targets', 'predictions') or " - "('targets', 'predictions', 'aux_values'). " - f"Got: {pos_args}" + "Metric functions must have positional arguments matching either" + " ('targets', 'scores'), ('targets', 'predictions'), ('targets'," + " 'predictions', 'aux_values') or ('targets', 'predictions'," + f" 'inputs').Got: {pos_args}" ) return predict_fns, score_fns, predict_with_aux_fns diff --git a/seqio/dataset_providers_test.py b/seqio/dataset_providers_test.py index bdf56a3d..0b8fcfe8 100644 --- a/seqio/dataset_providers_test.py +++ b/seqio/dataset_providers_test.py @@ -132,9 +132,10 @@ def extra_arg_metric_fn(targets, predictions, extra_param): return {} expected_error_message_prefix = ( - "Metric functions must have positional arguments matching either " - "('targets', 'scores'), ('targets', 'predictions') or ('targets', " - "'predictions', 'aux_values'). Got: " + "Metric functions must have positional arguments matching either" + " ('targets', 'scores'), ('targets', 'predictions'), ('targets'," + " 'predictions', 'aux_values') or ('targets', 'predictions', 'inputs')." + " Got: " ) with self.assertRaisesWithLiteralMatch(