diff --git a/bsmetadata/metadata_processors.py b/bsmetadata/metadata_processors.py
index f4a3f121..169cead1 100644
--- a/bsmetadata/metadata_processors.py
+++ b/bsmetadata/metadata_processors.py
@@ -92,7 +92,13 @@ class HtmlProcessor(MetadataProcessor):
def process_local(self, metadata_attrs: Dict[str, Any]) -> Optional[Tuple[str, str]]:
         # We represent a html tag `T` by enclosing the corresponding text span with "<T>" and "</T>".
         # Example: An apple is an <b>edible</b> fruit.
-        return f"<{metadata_attrs['value']}>", f"</{metadata_attrs['value']}>"
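+        # Render the tag's attributes as space-separated attr:"value" pairs (empty string if the tag has no attributes).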
+ attributes = " ".join(
+ f'{attr}:"{value}"'
+ for attr, value in zip(metadata_attrs["value"]["attrs"]["attr"], metadata_attrs["value"]["attrs"]["value"])
+ )
+ if attributes:
+ attributes = " " + attributes
+        return f"<{metadata_attrs['value']['tag']}{attributes}>", f"</{metadata_attrs['value']['tag']}>"
class UrlProcessor(MetadataProcessor):
diff --git a/bsmetadata/metadata_utils.py b/bsmetadata/metadata_utils.py
index 7268102a..4136aca9 100644
--- a/bsmetadata/metadata_utils.py
+++ b/bsmetadata/metadata_utils.py
@@ -15,6 +15,7 @@
"""
import random
from collections import defaultdict
+from dataclasses import dataclass, field
from typing import Any, Dict, List, Tuple
from transformers import PreTrainedTokenizerFast
@@ -27,12 +28,10 @@ def add_metadata_and_chunk_examples(
examples: Dict[str, List], tokenizer: PreTrainedTokenizerFast, cfg: DataConfig
) -> Dict[str, List]:
"""Adds metadata to the provided input examples, encodes them and groups them in chunks of size `cfg.max_seq_len`.
-
Args:
examples: The examples to process, with required keys "text" and "metadata".
tokenizer: The pretrained tokenizer to use.
cfg: The config to use for adding metadata and chunking.
-
Returns:
A new (potentially larger) collection of examples with keys "input_ids", "attention_mask" and "metadata_mask", where:
- the input ids are a list of token ids corresponding to the input text with metadata;
@@ -100,11 +99,9 @@ def is_metadata(idx: int) -> bool:
def create_global_metadata_prefix(example: Dict[str, Any], cfg: DataConfig) -> str:
"""Creates a prefix containing all global metadata information (including URLs, timestamps, etc).
-
Args:
example: The example to create a global metadata prefix for.
cfg: The data config to use.
-
Returns:
A string containing the global metadata prefix.
"""
@@ -122,19 +119,25 @@ def create_global_metadata_prefix(example: Dict[str, Any], cfg: DataConfig) -> s
return cfg.metadata_sep.join(sorted_metadata) + cfg.global_metadata_sep if sorted_metadata else ""
+@dataclass
+class MetadataIdxStorage:
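+    """Maps character indices to the metadata start/end strings to insert there,
+    split by whether the tag spans any text (with content) or not (without content)."""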
+ start_idx_tag_with_content: dict = field(default_factory=(lambda: defaultdict(list)))
+ end_idx_tag_with_content: dict = field(default_factory=(lambda: defaultdict(list)))
+ start_idx_tag_without_content: dict = field(default_factory=(lambda: defaultdict(list)))
+ end_idx_tag_without_content: dict = field(default_factory=(lambda: defaultdict(list)))
+
+
def add_local_metadata_to_text(example: Dict[str, Any], cfg: DataConfig) -> Tuple[str, List[bool]]:
"""Adds local metadata (such as HTML tags and entity names) to the given input text.
-
Args:
example: The example for which local metadata should be added.
cfg: The data config to use.
-
Returns:
A tuple of two elements, where:
- the first element is the text with metadata;
- the second element is a boolean mask where `mask[i]` is set iff `text[i]` is some kind of metadata.
"""
- metadata_start_texts, metadata_end_texts = defaultdict(list), defaultdict(list)
+ metadata_idx_storage = MetadataIdxStorage()
# Filter and sort all metadata so that they are processed in the requested order.
filtered_metadata = [md for md in example["metadata"] if md["type"] == "local" and md["key"] in cfg.metadata_list]
@@ -152,27 +155,58 @@ def add_local_metadata_to_text(example: Dict[str, Any], cfg: DataConfig) -> Tupl
char_start_idx = metadata.get("char_start_idx", -1)
char_end_idx = metadata.get("char_end_idx", -1)
- metadata_start_texts[char_start_idx].insert(0, start_text)
- metadata_end_texts[char_end_idx].append(end_text)
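+        # Tags that span no characters (char_start_idx == char_end_idx) are stored separately so that their
+        # open/close markers are emitted together, between the closing and opening markers of content-bearing tags.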
+ if char_start_idx == char_end_idx:
+ metadata_idx_storage.start_idx_tag_without_content[char_start_idx].insert(0, start_text)
+ metadata_idx_storage.end_idx_tag_without_content[char_end_idx].append(end_text)
+ else:
+ metadata_idx_storage.start_idx_tag_with_content[char_start_idx].insert(0, start_text)
+ metadata_idx_storage.end_idx_tag_with_content[char_end_idx].append(end_text)
# Build the final text with local metadata and the corresponding mask.
text_with_local_metadata = []
metadata_mask = []
+ def _add_metadata_to_text(metadata_text_list, text_with_local_metadata, metadata_mask):
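+        # Append each metadata string to the output and mark all of its characters as metadata in the mask.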
+ for metadata_text in metadata_text_list:
+ text_with_local_metadata.append(metadata_text)
+ metadata_mask += [True] * len(metadata_text)
+
for idx, char in enumerate(example["text"]):
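+        # At each character position: first close content-bearing tags ending here, then emit tags without
+        # content (their opens, then their closes), and finally open content-bearing tags starting here.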
+ if idx in metadata_idx_storage.end_idx_tag_with_content:
+ metadata_text_list = metadata_idx_storage.end_idx_tag_with_content[idx]
+ _add_metadata_to_text(metadata_text_list, text_with_local_metadata, metadata_mask)
+
+ if idx in metadata_idx_storage.start_idx_tag_without_content:
+ metadata_text_list = metadata_idx_storage.start_idx_tag_without_content[idx]
+ _add_metadata_to_text(metadata_text_list, text_with_local_metadata, metadata_mask)
+
+ if idx in metadata_idx_storage.end_idx_tag_without_content:
+ metadata_text_list = metadata_idx_storage.end_idx_tag_without_content[idx]
+ _add_metadata_to_text(metadata_text_list, text_with_local_metadata, metadata_mask)
- if idx in metadata_start_texts:
- for start_text in metadata_start_texts[idx]:
- text_with_local_metadata.append(start_text)
- metadata_mask += [True] * len(start_text)
+ if idx in metadata_idx_storage.start_idx_tag_with_content:
+ metadata_text_list = metadata_idx_storage.start_idx_tag_with_content[idx]
+ _add_metadata_to_text(metadata_text_list, text_with_local_metadata, metadata_mask)
text_with_local_metadata.append(char)
metadata_mask += [False]
- if idx + 1 in metadata_end_texts:
- for end_text in metadata_end_texts[idx + 1]:
- text_with_local_metadata.append(end_text)
- metadata_mask += [True] * len(end_text)
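+    # After the final character, run the same checks once more for metadata anchored at len(text)
+    # (e.g. tags whose char_end_idx lies just past the last character).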
+ idx += 1
+ if idx in metadata_idx_storage.end_idx_tag_with_content:
+ metadata_text_list = metadata_idx_storage.end_idx_tag_with_content[idx]
+ _add_metadata_to_text(metadata_text_list, text_with_local_metadata, metadata_mask)
+
+ if idx in metadata_idx_storage.start_idx_tag_without_content:
+ metadata_text_list = metadata_idx_storage.start_idx_tag_without_content[idx]
+ _add_metadata_to_text(metadata_text_list, text_with_local_metadata, metadata_mask)
+
+ if idx in metadata_idx_storage.end_idx_tag_without_content:
+ metadata_text_list = metadata_idx_storage.end_idx_tag_without_content[idx]
+ _add_metadata_to_text(metadata_text_list, text_with_local_metadata, metadata_mask)
+
+ if idx in metadata_idx_storage.start_idx_tag_with_content:
+ metadata_text_list = metadata_idx_storage.start_idx_tag_with_content[idx]
+ _add_metadata_to_text(metadata_text_list, text_with_local_metadata, metadata_mask)
return "".join(text_with_local_metadata), metadata_mask
diff --git a/tests/test_metadata_utils.py b/tests/test_metadata_utils.py
index c54c2826..73f91da6 100644
--- a/tests/test_metadata_utils.py
+++ b/tests/test_metadata_utils.py
@@ -5,7 +5,7 @@
from transformers import GPT2TokenizerFast
from bsmetadata.input_pipeline import DataConfig
-from bsmetadata.metadata_processors import PROCESSORS, MetadataProcessor
+from bsmetadata.metadata_processors import PROCESSORS, HtmlProcessor, MetadataProcessor
from bsmetadata.metadata_utils import (
add_local_metadata_to_text,
add_metadata_and_chunk_examples,
@@ -57,6 +57,76 @@ def setUp(self) -> None:
{"key": "url", "type": "global", "value": "callto:RickAndMorty/Year%202021/"},
],
},
+ {
+ "id": "0004",
+ "text": "useless text The Walking Dead (season 8)\n",
+ "metadata": [
+ {
+ "char_start_idx": 13,
+ "value": {
+ "tag": "h1",
+ "attrs": {"attr": [], "value": []},
+ },
+ "char_end_idx": 40,
+ "key": "html",
+ "type": "local",
+ },
+ {
+ "char_start_idx": 13,
+ "value": {
+ "tag": "div",
+ "attrs": {"attr": [], "value": []},
+ },
+ "char_end_idx": 13,
+ "key": "html",
+ "type": "local",
+ },
+ {
+ "char_start_idx": 0,
+ "value": {"tag": "a", "attrs": {"attr": [], "value": []}},
+ "char_end_idx": 13,
+ "key": "html",
+ "type": "local",
+ },
+ {
+ "char_start_idx": 13,
+ "value": {
+ "tag": "div",
+ "attrs": {"attr": [], "value": []},
+ },
+ "char_end_idx": 13,
+ "key": "html",
+ "type": "local",
+ },
+ {
+ "char_start_idx": 13,
+ "value": {
+ "tag": "a",
+ "attrs": {"attr": [], "value": []},
+ },
+ "char_end_idx": 13,
+ "key": "html",
+ "type": "local",
+ },
+ {
+ "char_start_idx": 13,
+ "value": {
+ "tag": "div",
+ "attrs": {"attr": [], "value": []},
+ },
+ "char_end_idx": 13,
+ "key": "html",
+ "type": "local",
+ },
+ {
+ "char_start_idx": 13,
+ "value": {"tag": "i", "attrs": {"attr": [], "value": []}},
+ "char_end_idx": 29,
+ "key": "html",
+ "type": "local",
+ },
+ ],
+ },
]
def test_chunks(self):
@@ -133,6 +203,18 @@ def test_add_no_metadata_and_chunk_examples(self):
for example in mapped_ds:
self.assertTrue(all(not x for x in example["metadata_mask"]))
+ def test_add_html_tags(self):
+ cfg = DataConfig()
+ cfg.metadata_list = ["html"]
+ PROCESSORS["html"] = HtmlProcessor
+
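+        # Example "0004" mixes content-bearing tags (a, h1, i) with several empty tags anchored at index 13.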
+ text1, mask1 = add_local_metadata_to_text(self.examples[3], cfg)
+ target_text = (
+            "<a>useless text </a><div><a><div><div></div></div></a></div><i><h1>The Walking Dead</i> (season 8)</h1>\n"
+ )
+
+ self.assertEqual(text1, target_text)
+
def test_add_metadata_and_chunk_examples(self):
cfg = DataConfig()
cfg.metadata_list = ["url", "timestamp", "html", "entity"]