diff --git a/bsmetadata/metadata_processors.py b/bsmetadata/metadata_processors.py
index f4a3f121..169cead1 100644
--- a/bsmetadata/metadata_processors.py
+++ b/bsmetadata/metadata_processors.py
@@ -92,7 +92,13 @@ class HtmlProcessor(MetadataProcessor):
     def process_local(self, metadata_attrs: Dict[str, Any]) -> Optional[Tuple[str, str]]:
         # We represent a html tag `T` by enclosing the corresponding text span with "<T>" and "</T>".
         # Example: An apple is <b>an edible fruit</b>.
-        return f"<{metadata_attrs['value']}>", f"</{metadata_attrs['value']}>"
+        attributes = " ".join(
+            f'{attr}:"{value}"'
+            for attr, value in zip(metadata_attrs["value"]["attrs"]["attr"], metadata_attrs["value"]["attrs"]["value"])
+        )
+        if attributes:
+            attributes = " " + attributes
+        return f"<{metadata_attrs['value']['tag']}{attributes}>", f"</{metadata_attrs['value']['tag']}>"
 
 
 class UrlProcessor(MetadataProcessor):
diff --git a/bsmetadata/metadata_utils.py b/bsmetadata/metadata_utils.py
index 7268102a..4136aca9 100644
--- a/bsmetadata/metadata_utils.py
+++ b/bsmetadata/metadata_utils.py
@@ -15,6 +15,7 @@
 """
 import random
 from collections import defaultdict
+from dataclasses import dataclass, field
 from typing import Any, Dict, List, Tuple
 
 from transformers import PreTrainedTokenizerFast
@@ -27,12 +28,10 @@ def add_metadata_and_chunk_examples(
     examples: Dict[str, List], tokenizer: PreTrainedTokenizerFast, cfg: DataConfig
 ) -> Dict[str, List]:
     """Adds metadata to the provided input examples, encodes them and groups them in chunks of size `cfg.max_seq_len`.
-
     Args:
         examples: The examples to process, with required keys "text" and "metadata".
         tokenizer: The pretrained tokenizer to use.
         cfg: The config to use for adding metadata and chunking.
-
     Returns:
         A new (potentially larger) collection of examples with keys "input_ids", "attention_mask" and "metadata_mask", where:
             - the input ids are a list of token ids corresponding to the input text with metadata;
@@ -100,11 +99,9 @@ def is_metadata(idx: int) -> bool:
 
 def create_global_metadata_prefix(example: Dict[str, Any], cfg: DataConfig) -> str:
     """Creates a prefix containing all global metadata information (including URLs, timestamps, etc).
-
     Args:
         example: The example to create a global metadata prefix for.
         cfg: The data config to use.
-
     Returns:
         A string containing the global metadata prefix.
     """
@@ -122,19 +119,25 @@ def create_global_metadata_prefix(example: Dict[str, Any], cfg: DataConfig) -> s
     return cfg.metadata_sep.join(sorted_metadata) + cfg.global_metadata_sep if sorted_metadata else ""
 
 
+@dataclass
+class MetadataIdxStorage:
+    start_idx_tag_with_content: dict = field(default_factory=(lambda: defaultdict(list)))
+    end_idx_tag_with_content: dict = field(default_factory=(lambda: defaultdict(list)))
+    start_idx_tag_without_content: dict = field(default_factory=(lambda: defaultdict(list)))
+    end_idx_tag_without_content: dict = field(default_factory=(lambda: defaultdict(list)))
+
+
 def add_local_metadata_to_text(example: Dict[str, Any], cfg: DataConfig) -> Tuple[str, List[bool]]:
     """Adds local metadata (such as HTML tags and entity names) to the given input text.
-
     Args:
         example: The example for which local metadata should be added.
         cfg: The data config to use.
-
     Returns:
         A tuple of two elements, where:
             - the first element is the text with metadata;
             - the second element is a boolean mask where `mask[i]` is set iff `text[i]` is some kind of metadata.
     """
-    metadata_start_texts, metadata_end_texts = defaultdict(list), defaultdict(list)
+    metadata_idx_storage = MetadataIdxStorage()
 
     # Filter and sort all metadata so that they are processed in the requested order.
     filtered_metadata = [md for md in example["metadata"] if md["type"] == "local" and md["key"] in cfg.metadata_list]
@@ -152,27 +155,58 @@ def add_local_metadata_to_text(example: Dict[str, Any], cfg: DataConfig) -> Tupl
         char_start_idx = metadata.get("char_start_idx", -1)
         char_end_idx = metadata.get("char_end_idx", -1)
 
-        metadata_start_texts[char_start_idx].insert(0, start_text)
-        metadata_end_texts[char_end_idx].append(end_text)
+        if char_start_idx == char_end_idx:
+            metadata_idx_storage.start_idx_tag_without_content[char_start_idx].insert(0, start_text)
+            metadata_idx_storage.end_idx_tag_without_content[char_end_idx].append(end_text)
+        else:
+            metadata_idx_storage.start_idx_tag_with_content[char_start_idx].insert(0, start_text)
+            metadata_idx_storage.end_idx_tag_with_content[char_end_idx].append(end_text)
 
     # Build the final text with local metadata and the corresponding mask.
     text_with_local_metadata = []
     metadata_mask = []
 
+    def _add_metadata_to_text(metadata_text_list, text_with_local_metadata, metadata_mask):
+        for metadata_text in metadata_text_list:
+            text_with_local_metadata.append(metadata_text)
+            metadata_mask += [True] * len(metadata_text)
+
     for idx, char in enumerate(example["text"]):
+        if idx in metadata_idx_storage.end_idx_tag_with_content:
+            metadata_text_list = metadata_idx_storage.end_idx_tag_with_content[idx]
+            _add_metadata_to_text(metadata_text_list, text_with_local_metadata, metadata_mask)
+
+        if idx in metadata_idx_storage.start_idx_tag_without_content:
+            metadata_text_list = metadata_idx_storage.start_idx_tag_without_content[idx]
+            _add_metadata_to_text(metadata_text_list, text_with_local_metadata, metadata_mask)
+
+        if idx in metadata_idx_storage.end_idx_tag_without_content:
+            metadata_text_list = metadata_idx_storage.end_idx_tag_without_content[idx]
+            _add_metadata_to_text(metadata_text_list, text_with_local_metadata, metadata_mask)
 
-        if idx in metadata_start_texts:
-            for start_text in metadata_start_texts[idx]:
-                text_with_local_metadata.append(start_text)
-                metadata_mask += [True] * len(start_text)
+        if idx in metadata_idx_storage.start_idx_tag_with_content:
+            metadata_text_list = metadata_idx_storage.start_idx_tag_with_content[idx]
+            _add_metadata_to_text(metadata_text_list, text_with_local_metadata, metadata_mask)
 
         text_with_local_metadata.append(char)
         metadata_mask += [False]
 
-        if idx + 1 in metadata_end_texts:
-            for end_text in metadata_end_texts[idx + 1]:
-                text_with_local_metadata.append(end_text)
-                metadata_mask += [True] * len(end_text)
+    idx += 1
+    if idx in metadata_idx_storage.end_idx_tag_with_content:
+        metadata_text_list = metadata_idx_storage.end_idx_tag_with_content[idx]
+        _add_metadata_to_text(metadata_text_list, text_with_local_metadata, metadata_mask)
+
+    if idx in metadata_idx_storage.start_idx_tag_without_content:
+        metadata_text_list = metadata_idx_storage.start_idx_tag_without_content[idx]
+        _add_metadata_to_text(metadata_text_list, text_with_local_metadata, metadata_mask)
+
+    if idx in metadata_idx_storage.end_idx_tag_without_content:
+        metadata_text_list = metadata_idx_storage.end_idx_tag_without_content[idx]
+        _add_metadata_to_text(metadata_text_list, text_with_local_metadata, metadata_mask)
+
+    if idx in metadata_idx_storage.start_idx_tag_with_content:
+        metadata_text_list = metadata_idx_storage.start_idx_tag_with_content[idx]
+        _add_metadata_to_text(metadata_text_list, text_with_local_metadata, metadata_mask)
 
     return "".join(text_with_local_metadata), metadata_mask
 
diff --git a/tests/test_metadata_utils.py b/tests/test_metadata_utils.py
index c54c2826..73f91da6 100644
--- a/tests/test_metadata_utils.py
+++ b/tests/test_metadata_utils.py
@@ -5,7 +5,7 @@
 from transformers import GPT2TokenizerFast
 
 from bsmetadata.input_pipeline import DataConfig
-from bsmetadata.metadata_processors import PROCESSORS, MetadataProcessor
+from bsmetadata.metadata_processors import PROCESSORS, HtmlProcessor, MetadataProcessor
 from bsmetadata.metadata_utils import (
     add_local_metadata_to_text,
     add_metadata_and_chunk_examples,
@@ -57,6 +57,76 @@ def setUp(self) -> None:
                     {"key": "url", "type": "global", "value": "callto:RickAndMorty/Year%202021/"},
                 ],
             },
+            {
+                "id": "0004",
+                "text": "useless text The Walking Dead (season 8)\n",
+                "metadata": [
+                    {
+                        "char_start_idx": 13,
+                        "value": {
+                            "tag": "h1",
+                            "attrs": {"attr": [], "value": []},
+                        },
+                        "char_end_idx": 40,
+                        "key": "html",
+                        "type": "local",
+                    },
+                    {
+                        "char_start_idx": 13,
+                        "value": {
+                            "tag": "div",
+                            "attrs": {"attr": [], "value": []},
+                        },
+                        "char_end_idx": 13,
+                        "key": "html",
+                        "type": "local",
+                    },
+                    {
+                        "char_start_idx": 0,
+                        "value": {"tag": "a", "attrs": {"attr": [], "value": []}},
+                        "char_end_idx": 13,
+                        "key": "html",
+                        "type": "local",
+                    },
+                    {
+                        "char_start_idx": 13,
+                        "value": {
+                            "tag": "div",
+                            "attrs": {"attr": [], "value": []},
+                        },
+                        "char_end_idx": 13,
+                        "key": "html",
+                        "type": "local",
+                    },
+                    {
+                        "char_start_idx": 13,
+                        "value": {
+                            "tag": "a",
+                            "attrs": {"attr": [], "value": []},
+                        },
+                        "char_end_idx": 13,
+                        "key": "html",
+                        "type": "local",
+                    },
+                    {
+                        "char_start_idx": 13,
+                        "value": {
+                            "tag": "div",
+                            "attrs": {"attr": [], "value": []},
+                        },
+                        "char_end_idx": 13,
+                        "key": "html",
+                        "type": "local",
+                    },
+                    {
+                        "char_start_idx": 13,
+                        "value": {"tag": "i", "attrs": {"attr": [], "value": []}},
+                        "char_end_idx": 29,
+                        "key": "html",
+                        "type": "local",
+                    },
+                ],
+            },
         ]
 
     def test_chunks(self):
@@ -133,6 +203,18 @@ def test_add_no_metadata_and_chunk_examples(self):
         for example in mapped_ds:
             self.assertTrue(all(not x for x in example["metadata_mask"]))
 
+    def test_add_html_tags(self):
+        cfg = DataConfig()
+        cfg.metadata_list = ["html"]
+        PROCESSORS["html"] = HtmlProcessor
+
+        text1, mask1 = add_local_metadata_to_text(self.examples[3], cfg)
+        target_text = (
+            "<a>useless text </a><div><a><div><div></div></div></a></div><i><h1>The Walking Dead</i> (season 8)</h1>\n"
+        )
+
+        self.assertEqual(text1, target_text)
+
     def test_add_metadata_and_chunk_examples(self):
         cfg = DataConfig()
         cfg.metadata_list = ["url", "timestamp", "html", "entity"]
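
Note, illustrative only and not part of the patch: the sketch below mirrors the two behaviours introduced above, assuming only the code shown in this diff. The helper name render_html_markers and the short bucket names (end_with, start_without, end_without, start_with) are invented here for demonstration; the real logic lives in HtmlProcessor.process_local and in the four MetadataIdxStorage buckets used by add_local_metadata_to_text.

from collections import defaultdict
from typing import Any, Dict, Tuple


def render_html_markers(metadata_attrs: Dict[str, Any]) -> Tuple[str, str]:
    """Mirror of HtmlProcessor.process_local above: build start/end markers for one HTML tag."""
    # Attributes are rendered as space-separated name:"value" pairs taken from two parallel lists.
    attributes = " ".join(
        f'{attr}:"{value}"'
        for attr, value in zip(metadata_attrs["value"]["attrs"]["attr"], metadata_attrs["value"]["attrs"]["value"])
    )
    if attributes:
        attributes = " " + attributes
    tag = metadata_attrs["value"]["tag"]
    return f"<{tag}{attributes}>", f"</{tag}>"


# A tag without attributes yields bare markers; attributes are appended inside the opening marker.
assert render_html_markers({"value": {"tag": "h1", "attrs": {"attr": [], "value": []}}}) == ("<h1>", "</h1>")
assert render_html_markers({"value": {"tag": "a", "attrs": {"attr": ["class"], "value": ["level1"]}}}) == (
    '<a class:"level1">',
    "</a>",
)

# Emission order at a character index where several tags coincide, mirroring the four
# "if idx in ..." blocks of the rewritten loop: first close spans with content, then open
# and close empty (zero-length) spans, then open spans with content.
end_with = defaultdict(list, {13: ["</a>"]})
start_without = defaultdict(list, {13: ["<div>"]})
end_without = defaultdict(list, {13: ["</div>"]})
start_with = defaultdict(list, {13: ["<h1>"]})

pieces = []
for bucket in (end_with, start_without, end_without, start_with):
    pieces.extend(bucket[13])
assert "".join(pieces) == "</a><div></div><h1>"

Splitting zero-length spans (char_start_idx == char_end_idx) into their own buckets is what lets an empty <div></div> sit between the closing tag of the preceding span and the opening tag of the following one, which the new test_add_html_tags exercises.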