diff --git a/seqio/dataset_providers.py b/seqio/dataset_providers.py index 3ae616dd..ac9701d2 100644 --- a/seqio/dataset_providers.py +++ b/seqio/dataset_providers.py @@ -1931,6 +1931,7 @@ def get_task_dataset( try_in_mem_cache: bool = True, ) -> tf.data.Dataset: """.""" + def filter_features(ex): return {k: v for k, v in ex.items() if k in output_feature_keys}