diff --git a/seqio/dataset_providers.py b/seqio/dataset_providers.py index e7a1a7ff..f0d6cf3a 100644 --- a/seqio/dataset_providers.py +++ b/seqio/dataset_providers.py @@ -1227,12 +1227,14 @@ def _all_metric_fns( score_fns.append(metric_fn) elif pos_args == ("targets", "predictions", "aux_values"): predict_with_aux_fns.append(metric_fn) + elif pos_args == ("targets", "predictions", "inputs"): + predict_fns.append(metric_fn) else: raise ValueError( - "Metric functions must have positional arguments matching either " - "('targets', 'scores'), ('targets', 'predictions') or " - "('targets', 'predictions', 'aux_values'). " - f"Got: {pos_args}" + "Metric functions must have positional arguments matching either" + " ('targets', 'scores'), ('targets', 'predictions'), ('targets'," + " 'predictions', 'aux_values') or ('targets', 'predictions'," + f" 'inputs').Got: {pos_args}" ) return predict_fns, score_fns, predict_with_aux_fns diff --git a/seqio/dataset_providers_test.py b/seqio/dataset_providers_test.py index bdf56a3d..0b8fcfe8 100644 --- a/seqio/dataset_providers_test.py +++ b/seqio/dataset_providers_test.py @@ -132,9 +132,10 @@ def extra_arg_metric_fn(targets, predictions, extra_param): return {} expected_error_message_prefix = ( - "Metric functions must have positional arguments matching either " - "('targets', 'scores'), ('targets', 'predictions') or ('targets', " - "'predictions', 'aux_values'). Got: " + "Metric functions must have positional arguments matching either" + " ('targets', 'scores'), ('targets', 'predictions'), ('targets'," + " 'predictions', 'aux_values') or ('targets', 'predictions', 'inputs')." + " Got: " ) with self.assertRaisesWithLiteralMatch(