From 078464cc5a9eb0eceba1d2ce1a61ede0eb41b6aa Mon Sep 17 00:00:00 2001 From: "David S. Batista" Date: Tue, 20 May 2025 20:21:08 +0200 Subject: [PATCH 01/41] adding tests --- pyproject.toml | 137 -------------------------- tests/test_evaluator.py | 210 +++++++++++++++++++++++++++++++++++++--- tests/test_loaders.py | 96 ++++++++++++++++++ 3 files changed, 293 insertions(+), 150 deletions(-) delete mode 100644 pyproject.toml diff --git a/pyproject.toml b/pyproject.toml deleted file mode 100644 index 3c92524..0000000 --- a/pyproject.toml +++ /dev/null @@ -1,137 +0,0 @@ -[build-system] -requires = ["setuptools", "setuptools-scm"] -build-backend = "setuptools.build_meta" - -[project] -name = "nervaluate" -version = "0.2.0" -authors = [ - { name="David S. Batista"}, - { name="Matthew Upson"} -] -description = "NER evaluation considering partial match scoring" -readme = "README.md" -requires-python = ">=3.11" -keywords = ["named-entity-recognition", "ner", "evaluation-metrics", "partial-match-scoring", "nlp"] -license = {text = "MIT License"} -classifiers = [ - "Programming Language :: Python :: 3", - "Operating System :: OS Independent" -] - -dependencies = [ - "pandas==2.2.3" -] - -[project.optional-dependencies] -dev = [ - "black==24.3.0", - "coverage==7.2.5", - "gitchangelog", - "mypy==1.3.0", - "pre-commit==3.3.1", - "pylint==2.17.4", - "pytest==7.3.1", - "pytest-cov==4.1.0", -] - -[project.urls] -"Homepage" = "https://github.com/MantisAI/nervaluate" -"Bug Tracker" = "https://github.com/MantisAI/nervaluate/issues" - -[tool.pytest.ini_options] -testpaths = ["tests"] -python_files = ["test_*.py"] -addopts = "--cov=nervaluate --cov-report=term-missing" - -[tool.coverage.run] -source = ["nervaluate"] -omit = ["*__init__*"] - -[tool.coverage.report] -show_missing = true -precision = 2 -sort = "Miss" - -[tool.black] -line-length = 120 -target-version = ["py311"] - -[tool.pylint.messages_control] -disable = [ - "C0111", # missing-docstring - "C0103", # invalid-name - "W0511", # fixme - "W0603", # global-statement - "W1202", # logging-format-interpolation - "W1203", # logging-fstring-interpolation - "E1126", # invalid-sequence-index - "E1137", # invalid-slice-index - "I0011", # bad-option-value - "I0020", # bad-option-value - "R0801", # duplicate-code - "W9020", # bad-option-value -] - -[tool.pylint.format] -max-line-length = 120 - -[tool.pylint.basic] -accept-no-param-doc = true -accept-no-raise-doc = true -accept-no-return-doc = true -accept-no-yields-doc = true -default-docstring-type = "numpy" - -[tool.pylint.master] -load-plugins = ["pylint.extensions.docparams"] -ignore-paths = ["./examples/.*"] - -[tool.flake8] -max-line-length = 120 -extend-ignore = ["E203"] -exclude = [".git", "__pycache__", "build", "dist", "./examples/*"] -max-complexity = 10 -per-file-ignores = ["*/__init__.py: F401"] - -[tool.mypy] -python_version = "3.11" -ignore_missing_imports = true -disallow_any_unimported = true -disallow_untyped_defs = true -warn_redundant_casts = true -warn_unused_ignores = true -warn_unused_configs = true - - -[[tool.mypy.overrides]] -module = "examples.*" -follow_imports = "skip" - -[tool.hatch.envs.dev] -dependencies = [ - "black==24.3.0", - "coverage==7.2.5", - "gitchangelog", - "mypy==1.3.0", - "pre-commit==3.3.1", - "pylint==2.17.4", - "pytest==7.3.1", - "pytest-cov==4.1.0", -] - -[tool.hatch.envs.dev.scripts] -lint = [ - "black -t py311 -l 120 src tests", - "pylint src tests" -] -typing = "mypy src" -test = "pytest" -clean = "rm -rf dist src/nervaluate.egg-info .coverage .mypy_cache 
.pytest_cache" -changelog = "gitchangelog > CHANGELOG.rst" -all = [ - "clean", - "lint", - "typing", - "test" -] diff --git a/tests/test_evaluator.py b/tests/test_evaluator.py index 6c12d4d..dc0ce9c 100644 --- a/tests/test_evaluator.py +++ b/tests/test_evaluator.py @@ -1,7 +1,11 @@ # pylint: disable=too-many-lines import pandas as pd +import pytest +from nervaluate.entities import Entity +from nervaluate.evaluator import Evaluator +from nervaluate.evaluation_strategies import StrictEvaluation, PartialEvaluation, EntityTypeEvaluation -from nervaluate import Evaluator +from nervaluate import Evaluator as OldEvaluator def test_results_to_dataframe(): @@ -9,7 +13,7 @@ def test_results_to_dataframe(): Test the results_to_dataframe method. """ # Setup - evaluator = Evaluator( + evaluator = OldEvaluator( true=[["B-LOC", "I-LOC", "O"], ["B-PER", "O", "O"]], pred=[["B-LOC", "I-LOC", "O"], ["B-PER", "I-PER", "O"]], tags=["LOC", "PER"], @@ -130,7 +134,7 @@ def test_evaluator_simple_case(): {"label": "LOC", "start": 3, "end": 4}, ], ] - evaluator = Evaluator(true, pred, tags=["LOC", "PER"]) + evaluator = OldEvaluator(true, pred, tags=["LOC", "PER"]) results, _, _, _ = evaluator.evaluate() expected = { "strict": { @@ -205,7 +209,7 @@ def test_evaluator_simple_case_filtered_tags(): {"label": "LOC", "start": 3, "end": 4}, ], ] - evaluator = Evaluator(true, pred, tags=["PER", "LOC"]) + evaluator = OldEvaluator(true, pred, tags=["PER", "LOC"]) results, _, _, _ = evaluator.evaluate() expected = { "strict": { @@ -272,7 +276,7 @@ def test_evaluator_extra_classes(): pred = [ [{"label": "FOO", "start": 1, "end": 3}], ] - evaluator = Evaluator(true, pred, tags=["ORG", "FOO"]) + evaluator = OldEvaluator(true, pred, tags=["ORG", "FOO"]) results, _, _, _ = evaluator.evaluate() expected = { "strict": { @@ -339,7 +343,7 @@ def test_evaluator_no_entities_in_prediction(): pred = [ [], ] - evaluator = Evaluator(true, pred, tags=["PER"]) + evaluator = OldEvaluator(true, pred, tags=["PER"]) results, _, _, _ = evaluator.evaluate() expected = { "strict": { @@ -406,7 +410,7 @@ def test_evaluator_compare_results_and_results_agg(): pred = [ [{"label": "PER", "start": 2, "end": 4}], ] - evaluator = Evaluator(true, pred, tags=["PER"]) + evaluator = OldEvaluator(true, pred, tags=["PER"]) results, results_agg, _, _ = evaluator.evaluate() expected = { "strict": { @@ -539,7 +543,7 @@ def test_evaluator_compare_results_and_results_agg_1(): [{"label": "ORG", "start": 2, "end": 4}], [{"label": "MISC", "start": 2, "end": 4}], ] - evaluator = Evaluator(true, pred, tags=["PER", "ORG", "MISC"]) + evaluator = OldEvaluator(true, pred, tags=["PER", "ORG", "MISC"]) results, results_agg, _, _ = evaluator.evaluate() expected = { "strict": { @@ -725,7 +729,7 @@ def test_evaluator_with_extra_keys_in_pred(): {"label": "LOC", "start": 3, "end": 4, "token_start": 0, "token_end": 3}, ], ] - evaluator = Evaluator(true, pred, tags=["LOC", "PER"]) + evaluator = OldEvaluator(true, pred, tags=["LOC", "PER"]) results, _, _, _ = evaluator.evaluate() expected = { "strict": { @@ -799,7 +803,7 @@ def test_evaluator_with_extra_keys_in_true(): {"label": "LOC", "start": 3, "end": 4}, ], ] - evaluator = Evaluator(true, pred, tags=["LOC", "PER"]) + evaluator = OldEvaluator(true, pred, tags=["LOC", "PER"]) results, _, _, _ = evaluator.evaluate() expected = { "strict": { @@ -872,7 +876,7 @@ def test_issue_29(): {"label": "PER", "start": 6, "end": 10}, ] ] - evaluator = Evaluator(true, pred, tags=["PER"]) + evaluator = OldEvaluator(true, pred, tags=["PER"]) results, _, 
_, _ = evaluator.evaluate() expected = { "strict": { @@ -939,7 +943,7 @@ def test_evaluator_compare_results_indices_and_results_agg_indices(): pred = [ [{"label": "PER", "start": 2, "end": 4}], ] - evaluator = Evaluator(true, pred, tags=["PER"]) + evaluator = OldEvaluator(true, pred, tags=["PER"]) _, _, evaluation_indices, evaluation_agg_indices = evaluator.evaluate() expected_evaluation_indices = { "strict": { @@ -1031,7 +1035,7 @@ def test_evaluator_compare_results_indices_and_results_agg_indices_1(): [{"label": "ORG", "start": 2, "end": 4}], [{"label": "MISC", "start": 2, "end": 4}], ] - evaluator = Evaluator(true, pred, tags=["PER", "ORG", "MISC"]) + evaluator = OldEvaluator(true, pred, tags=["PER", "ORG", "MISC"]) _, _, evaluation_indices, evaluation_agg_indices = evaluator.evaluate() expected_evaluation_indices = { @@ -1170,3 +1174,183 @@ def test_evaluator_compare_results_indices_and_results_agg_indices_1(): assert evaluation_indices["ent_type"] == expected_evaluation_indices["ent_type"] assert evaluation_indices["partial"] == expected_evaluation_indices["partial"] assert evaluation_indices["exact"] == expected_evaluation_indices["exact"] + + +@pytest.fixture +def sample_entities(): + return [ + Entity(label="PER", start=0, end=0), + Entity(label="ORG", start=2, end=3), + Entity(label="LOC", start=5, end=5), + ] + + +@pytest.fixture +def sample_predictions(): + return [ + Entity(label="PER", start=0, end=0), # Correct + Entity(label="ORG", start=2, end=2), # Partial + Entity(label="PER", start=5, end=5), # Wrong type + ] + + +def test_strict_evaluation(sample_entities, sample_predictions): + strategy = StrictEvaluation() + result, indices = strategy.evaluate(sample_entities, sample_predictions, ["PER", "ORG", "LOC"]) + + assert result.correct == 1 + assert result.incorrect == 0 + assert result.partial == 0 + assert result.missed == 2 + assert result.spurious == 2 + + +def test_partial_evaluation(sample_entities, sample_predictions): + strategy = PartialEvaluation() + result, indices = strategy.evaluate(sample_entities, sample_predictions, ["PER", "ORG", "LOC"]) + + assert result.correct == 1 + assert result.partial == 1 + assert result.incorrect == 1 + assert result.missed == 1 + assert result.spurious == 0 + + +def test_entity_type_evaluation(sample_entities, sample_predictions): + strategy = EntityTypeEvaluation() + result, indices = strategy.evaluate(sample_entities, sample_predictions, ["PER", "ORG", "LOC"]) + + assert result.correct == 2 + assert result.incorrect == 0 + assert result.partial == 0 + assert result.missed == 1 + assert result.spurious == 1 + + +def test_evaluator_integration(): + # Test with list format + true = [["O", "PER", "O", "ORG", "ORG", "LOC"]] + pred = [["O", "PER", "O", "ORG", "O", "PER"]] + + evaluator = OldEvaluator(true, pred, ["PER", "ORG", "LOC"], loader="list") + results = evaluator.evaluate() + + assert "overall" in results + assert "entities" in results + assert "strict" in results["overall"] + assert "partial" in results["overall"] + assert "ent_type" in results["overall"] + + # Test with CoNLL format + true_conll = "word\tO\nword\tPER\nword\tO\nword\tORG\nword\tORG\nword\tLOC\n\n" + pred_conll = "word\tO\nword\tPER\nword\tO\nword\tORG\nword\tO\nword\tPER\n\n" + + evaluator = OldEvaluator(true_conll, pred_conll, ["PER", "ORG", "LOC"], loader="conll") + results = evaluator.evaluate() + + assert "overall" in results + assert "entities" in results + assert "strict" in results["overall"] + assert "partial" in results["overall"] + assert "ent_type" 
in results["overall"] + + +@pytest.fixture +def sample_data(): + true = [ + [ + Entity(label="PER", start=0, end=0), + Entity(label="ORG", start=2, end=3), + Entity(label="LOC", start=5, end=5) + ], + [ + Entity(label="PER", start=0, end=0), + Entity(label="ORG", start=2, end=2) + ] + ] + + pred = [ + [ + Entity(label="PER", start=0, end=0), # Correct + Entity(label="ORG", start=2, end=2), # Partial + Entity(label="PER", start=5, end=5) # Wrong type + ], + [ + Entity(label="PER", start=0, end=0), # Correct + Entity(label="LOC", start=2, end=2) # Wrong type + ] + ] + + return true, pred + + +def test_evaluator_initialization(sample_data): + """Test evaluator initialization.""" + true, pred = sample_data + evaluator = Evaluator(true, pred, ["PER", "ORG", "LOC"]) + + assert len(evaluator.true) == 2 + assert len(evaluator.pred) == 2 + assert evaluator.tags == ["PER", "ORG", "LOC"] + + +def test_evaluator_evaluation(sample_data): + """Test evaluation process.""" + true, pred = sample_data + evaluator = Evaluator(true, pred, ["PER", "ORG", "LOC"]) + results = evaluator.evaluate() + + # Check that we have results for all strategies + assert "strict" in results + assert "partial" in results + assert "ent_type" in results + + # Check that we have results for overall and each entity type + for strategy in results: + assert "overall" in results[strategy] + assert "PER" in results[strategy] + assert "ORG" in results[strategy] + assert "LOC" in results[strategy] + + +def test_evaluator_dataframe_conversion(sample_data): + """Test conversion of results to DataFrame.""" + true, pred = sample_data + evaluator = Evaluator(true, pred, ["PER", "ORG", "LOC"]) + results = evaluator.evaluate() + df = evaluator.results_to_dataframe() + + assert isinstance(df, pd.DataFrame) + assert len(df) > 0 + assert "strategy" in df.columns + assert "entity_type" in df.columns + assert "precision" in df.columns + assert "recall" in df.columns + assert "f1" in df.columns + + +def test_evaluator_with_empty_inputs(): + """Test evaluator with empty inputs.""" + evaluator = Evaluator([], [], ["PER", "ORG", "LOC"]) + results = evaluator.evaluate() + + for strategy in results: + assert results[strategy]["overall"].correct == 0 + assert results[strategy]["overall"].incorrect == 0 + assert results[strategy]["overall"].partial == 0 + assert results[strategy]["overall"].missed == 0 + assert results[strategy]["overall"].spurious == 0 + + +def test_evaluator_with_invalid_tags(sample_data): + """Test evaluator with invalid tags.""" + true, pred = sample_data + evaluator = Evaluator(true, pred, ["INVALID"]) + results = evaluator.evaluate() + + for strategy in results: + assert results[strategy]["overall"].correct == 0 + assert results[strategy]["overall"].incorrect == 0 + assert results[strategy]["overall"].partial == 0 + assert results[strategy]["overall"].missed == 0 + assert results[strategy]["overall"].spurious == 0 diff --git a/tests/test_loaders.py b/tests/test_loaders.py index 80cc921..968f81a 100644 --- a/tests/test_loaders.py +++ b/tests/test_loaders.py @@ -1,4 +1,7 @@ from nervaluate import Evaluator +import pytest +from nervaluate.entities import Entity +from nervaluate.loaders import ConllLoader, ListLoader, DictLoader def test_loaders_produce_the_same_results(): @@ -56,3 +59,96 @@ def test_loaders_produce_the_same_results(): assert evaluator_prod.pred == evaluator_list.pred == evaluator_conll.pred assert evaluator_prod.true == evaluator_list.true == evaluator_conll.true + + +@pytest.fixture +def conll_data(): + return 
"""PER\t0\t0 +ORG\t2\t3 +LOC\t5\t5 + +PER\t0\t0 +ORG\t2\t2""" + + +@pytest.fixture +def list_data(): + return [["PER", "O", "ORG", "O", "LOC"], ["PER", "O", "ORG"]] + + +@pytest.fixture +def dict_data(): + return [ + [ + {"label": "PER", "start": 0, "end": 0}, + {"label": "ORG", "start": 2, "end": 3}, + {"label": "LOC", "start": 5, "end": 5}, + ], + [{"label": "PER", "start": 0, "end": 0}, {"label": "ORG", "start": 2, "end": 2}], + ] + + +def test_conll_loader(conll_data): + """Test CoNLL format loader.""" + loader = ConllLoader() + entities = loader.load(conll_data) + + assert len(entities) == 2 # Two documents + assert len(entities[0]) == 3 # First document has 3 entities + assert len(entities[1]) == 2 # Second document has 2 entities + + # Check first entity + assert entities[0][0].label == "PER" + assert entities[0][0].start == 0 + assert entities[0][0].end == 0 + + +def test_list_loader(list_data): + """Test list format loader.""" + loader = ListLoader() + entities = loader.load(list_data) + + assert len(entities) == 2 # Two documents + assert len(entities[0]) == 3 # First document has 3 entities + assert len(entities[1]) == 2 # Second document has 2 entities + + # Check first entity + assert entities[0][0].label == "PER" + assert entities[0][0].start == 0 + assert entities[0][0].end == 0 + + +def test_dict_loader(dict_data): + """Test dictionary format loader.""" + loader = DictLoader() + entities = loader.load(dict_data) + + assert len(entities) == 2 # Two documents + assert len(entities[0]) == 3 # First document has 3 entities + assert len(entities[1]) == 2 # Second document has 2 entities + + # Check first entity + assert entities[0][0].label == "PER" + assert entities[0][0].start == 0 + assert entities[0][0].end == 0 + + +def test_loader_with_empty_input(): + """Test loaders with empty input.""" + loaders = [ConllLoader(), ListLoader(), DictLoader()] + + for loader in loaders: + entities = loader.load([]) + assert len(entities) == 0 + + +def test_loader_with_invalid_data(): + """Test loaders with invalid data.""" + with pytest.raises(Exception): + ConllLoader().load("invalid\tdata") + + with pytest.raises(Exception): + ListLoader().load([["invalid"]]) + + with pytest.raises(Exception): + DictLoader().load([[{"invalid": "data"}]]) From 8864dc168dd299caffef2930ca2ff6c70ac13447 Mon Sep 17 00:00:00 2001 From: "David S. 
Batista" Date: Tue, 20 May 2025 20:54:56 +0200 Subject: [PATCH 02/41] fixing loading test_list_loader --- tests/test_loaders.py | 139 +++++++++++++++++++++++++++++++----------- 1 file changed, 102 insertions(+), 37 deletions(-) diff --git a/tests/test_loaders.py b/tests/test_loaders.py index 968f81a..76ff884 100644 --- a/tests/test_loaders.py +++ b/tests/test_loaders.py @@ -61,21 +61,6 @@ def test_loaders_produce_the_same_results(): assert evaluator_prod.true == evaluator_list.true == evaluator_conll.true -@pytest.fixture -def conll_data(): - return """PER\t0\t0 -ORG\t2\t3 -LOC\t5\t5 - -PER\t0\t0 -ORG\t2\t2""" - - -@pytest.fixture -def list_data(): - return [["PER", "O", "ORG", "O", "LOC"], ["PER", "O", "ORG"]] - - @pytest.fixture def dict_data(): return [ @@ -88,34 +73,114 @@ def dict_data(): ] -def test_conll_loader(conll_data): +def test_conll_loader(): """Test CoNLL format loader.""" - loader = ConllLoader() - entities = loader.load(conll_data) - - assert len(entities) == 2 # Two documents - assert len(entities[0]) == 3 # First document has 3 entities - assert len(entities[1]) == 2 # Second document has 2 entities - - # Check first entity - assert entities[0][0].label == "PER" - assert entities[0][0].start == 0 - assert entities[0][0].end == 0 + true_conll = ( + "word\tO\nword\tO\nword\tO\nword\tO\nword\tO\nword\tO\n\n" + "word\tO\nword\tO\nword\tB-ORG\nword\tI-ORG\nword\tO\nword\tO\n\n" + "word\tO\nword\tO\nword\tB-MISC\nword\tI-MISC\nword\tO\nword\tO\n\n" + "word\tB-MISC\nword\tI-MISC\nword\tI-MISC\nword\tI-MISC\nword\tI-MISC\nword\tI-MISC\n" + ) + pred_conll = ( + "word\tO\nword\tO\nword\tB-PER\nword\tI-PER\nword\tO\nword\tO\n\n" + "word\tO\nword\tO\nword\tB-ORG\nword\tI-ORG\nword\tO\nword\tO\n\n" + "word\tO\nword\tO\nword\tB-MISC\nword\tI-MISC\nword\tO\nword\tO\n\n" + "word\tB-MISC\nword\tI-MISC\nword\tI-MISC\nword\tI-MISC\nword\tI-MISC\nword\tI-MISC\n" + ) -def test_list_loader(list_data): + loader = ConllLoader() + true_entities = loader.load(true_conll) + pred_entities = loader.load(pred_conll) + + # Test true entities + assert len(true_entities) == 4 # Four documents + assert len(true_entities[0]) == 0 # First document has no entities (all O tags) + assert len(true_entities[1]) == 1 # Second document has 1 entity (ORG) + assert len(true_entities[2]) == 1 # Third document has 1 entity (MISC) + assert len(true_entities[3]) == 1 # Fourth document has 1 entity (MISC) + + # Check first entity in second document + assert true_entities[1][0].label == "ORG" + assert true_entities[1][0].start == 2 + assert true_entities[1][0].end == 3 + + # Test pred entities + assert len(pred_entities) == 4 # Four documents + assert len(pred_entities[0]) == 1 # First document has 1 entity (PER) + assert len(pred_entities[1]) == 1 # Second document has 1 entity (ORG) + assert len(pred_entities[2]) == 1 # Third document has 1 entity (MISC) + assert len(pred_entities[3]) == 1 # Fourth document has 1 entity (MISC) + + # Check first entity in first document + assert pred_entities[0][0].label == "PER" + assert pred_entities[0][0].start == 2 + assert pred_entities[0][0].end == 3 + + # Test empty document handling + empty_doc = "word\tO\nword\tO\nword\tO\n\n" + empty_entities = loader.load(empty_doc) + assert len(empty_entities) == 1 # One document + assert len(empty_entities[0]) == 0 # Empty list for document with only O tags + + +def test_list_loader(): """Test list format loader.""" - loader = ListLoader() - entities = loader.load(list_data) + true_list = [ + ["O", "O", "O", "O", "O", "O"], + ["O", "O", 
"B-ORG", "I-ORG", "O", "O"], + ["O", "O", "B-MISC", "I-MISC", "O", "O"], + ["B-MISC", "I-MISC", "I-MISC", "I-MISC", "I-MISC", "I-MISC"], + ] - assert len(entities) == 2 # Two documents - assert len(entities[0]) == 3 # First document has 3 entities - assert len(entities[1]) == 2 # Second document has 2 entities + pred_list = [ + ["O", "O", "B-PER", "I-PER", "O", "O"], + ["O", "O", "B-ORG", "I-ORG", "O", "O"], + ["O", "O", "B-MISC", "I-MISC", "O", "O"], + ["B-MISC", "I-MISC", "I-MISC", "I-MISC", "I-MISC", "I-MISC"], + ] - # Check first entity - assert entities[0][0].label == "PER" - assert entities[0][0].start == 0 - assert entities[0][0].end == 0 + loader = ListLoader() + true_entities = loader.load(true_list) + pred_entities = loader.load(pred_list) + + # Test true entities + assert len(true_entities) == 4 # Four documents + assert len(true_entities[0]) == 0 # First document has no entities (all O tags) + assert len(true_entities[1]) == 1 # Second document has 1 entity (ORG) + assert len(true_entities[2]) == 1 # Third document has 1 entity (MISC) + assert len(true_entities[3]) == 1 # Fourth document has 1 entity (MISC) + + # Check no entities in the first document + assert len(true_entities[0]) == 0 + + # Check first entity in second document + assert true_entities[1][0].label == "ORG" + assert true_entities[1][0].start == 2 + assert true_entities[1][0].end == 3 + + # Check only entity in the last document + assert true_entities[3][0].label == "MISC" + assert true_entities[3][0].start == 0 + assert true_entities[3][0].end == 5 + + # Test pred entities + assert len(pred_entities) == 4 # Four documents + assert len(pred_entities[0]) == 1 # First document has 1 entity (PER) + assert len(pred_entities[1]) == 1 # Second document has 1 entity (ORG) + assert len(pred_entities[2]) == 1 # Third document has 1 entity (MISC) + assert len(pred_entities[3]) == 1 # Fourth document has 1 entity (MISC) + + # Check first entity in first document + assert pred_entities[0][0].label == "PER" + assert pred_entities[0][0].start == 2 + assert pred_entities[0][0].end == 3 + + # Test empty document handling + empty_doc = [["O", "O", "O"]] + empty_entities = loader.load(empty_doc) + assert len(empty_entities) == 1 # One document + assert len(empty_entities[0]) == 0 # Empty list for document with only O tags def test_dict_loader(dict_data): From 36b1b93ce383c9d4fc04c787f43df0b645e95813 Mon Sep 17 00:00:00 2001 From: "David S. 
Batista" Date: Tue, 20 May 2025 20:58:24 +0200 Subject: [PATCH 03/41] fixing loading test_dict_loader --- tests/test_loaders.py | 71 ++++++++++++++++++++++++++++++------------- 1 file changed, 50 insertions(+), 21 deletions(-) diff --git a/tests/test_loaders.py b/tests/test_loaders.py index 76ff884..cad4ef7 100644 --- a/tests/test_loaders.py +++ b/tests/test_loaders.py @@ -61,18 +61,6 @@ def test_loaders_produce_the_same_results(): assert evaluator_prod.true == evaluator_list.true == evaluator_conll.true -@pytest.fixture -def dict_data(): - return [ - [ - {"label": "PER", "start": 0, "end": 0}, - {"label": "ORG", "start": 2, "end": 3}, - {"label": "LOC", "start": 5, "end": 5}, - ], - [{"label": "PER", "start": 0, "end": 0}, {"label": "ORG", "start": 2, "end": 2}], - ] - - def test_conll_loader(): """Test CoNLL format loader.""" true_conll = ( @@ -183,19 +171,60 @@ def test_list_loader(): assert len(empty_entities[0]) == 0 # Empty list for document with only O tags -def test_dict_loader(dict_data): +def test_dict_loader(): """Test dictionary format loader.""" + true_prod = [ + [], + [{"label": "ORG", "start": 2, "end": 3}], + [{"label": "MISC", "start": 2, "end": 3}], + [{"label": "MISC", "start": 0, "end": 5}], + ] + + pred_prod = [ + [{"label": "PER", "start": 2, "end": 3}], + [{"label": "ORG", "start": 2, "end": 3}], + [{"label": "MISC", "start": 2, "end": 3}], + [{"label": "MISC", "start": 0, "end": 5}], + ] + loader = DictLoader() - entities = loader.load(dict_data) + true_entities = loader.load(true_prod) + pred_entities = loader.load(pred_prod) - assert len(entities) == 2 # Two documents - assert len(entities[0]) == 3 # First document has 3 entities - assert len(entities[1]) == 2 # Second document has 2 entities + # Test true entities + assert len(true_entities) == 4 # Four documents + assert len(true_entities[0]) == 0 # First document has no entities + assert len(true_entities[1]) == 1 # Second document has 1 entity (ORG) + assert len(true_entities[2]) == 1 # Third document has 1 entity (MISC) + assert len(true_entities[3]) == 1 # Fourth document has 1 entity (MISC) - # Check first entity - assert entities[0][0].label == "PER" - assert entities[0][0].start == 0 - assert entities[0][0].end == 0 + # Check first entity in second document + assert true_entities[1][0].label == "ORG" + assert true_entities[1][0].start == 2 + assert true_entities[1][0].end == 3 + + # Check only entity in the last document + assert true_entities[3][0].label == "MISC" + assert true_entities[3][0].start == 0 + assert true_entities[3][0].end == 5 + + # Test pred entities + assert len(pred_entities) == 4 # Four documents + assert len(pred_entities[0]) == 1 # First document has 1 entity (PER) + assert len(pred_entities[1]) == 1 # Second document has 1 entity (ORG) + assert len(pred_entities[2]) == 1 # Third document has 1 entity (MISC) + assert len(pred_entities[3]) == 1 # Fourth document has 1 entity (MISC) + + # Check first entity in first document + assert pred_entities[0][0].label == "PER" + assert pred_entities[0][0].start == 2 + assert pred_entities[0][0].end == 3 + + # Test empty document handling + empty_doc = [[]] + empty_entities = loader.load(empty_doc) + assert len(empty_entities) == 1 # One document + assert len(empty_entities[0]) == 0 # Empty list for empty document def test_loader_with_empty_input(): From 6edf70ce1efb5483ae3589c9be5cd9fac1749ccd Mon Sep 17 00:00:00 2001 From: "David S. 
Batista" Date: Tue, 20 May 2025 21:05:19 +0200 Subject: [PATCH 04/41] fixing loading test_conll_loader --- tests/test_loaders.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_loaders.py b/tests/test_loaders.py index cad4ef7..91f3fda 100644 --- a/tests/test_loaders.py +++ b/tests/test_loaders.py @@ -1,6 +1,6 @@ -from nervaluate import Evaluator import pytest -from nervaluate.entities import Entity + +from nervaluate import Evaluator from nervaluate.loaders import ConllLoader, ListLoader, DictLoader From 7b222ef203b3b586b9b74d7983183036b6d732f6 Mon Sep 17 00:00:00 2001 From: "David S. Batista" Date: Tue, 20 May 2025 21:09:53 +0200 Subject: [PATCH 05/41] fixing loaders --- tests/test_loaders.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/tests/test_loaders.py b/tests/test_loaders.py index 91f3fda..2c92f23 100644 --- a/tests/test_loaders.py +++ b/tests/test_loaders.py @@ -229,11 +229,20 @@ def test_dict_loader(): def test_loader_with_empty_input(): """Test loaders with empty input.""" - loaders = [ConllLoader(), ListLoader(), DictLoader()] - - for loader in loaders: - entities = loader.load([]) - assert len(entities) == 0 + # Test ConllLoader with empty string + conll_loader = ConllLoader() + entities = conll_loader.load("") + assert len(entities) == 0 + + # Test ListLoader with empty list + list_loader = ListLoader() + entities = list_loader.load([]) + assert len(entities) == 0 + + # Test DictLoader with empty list + dict_loader = DictLoader() + entities = dict_loader.load([]) + assert len(entities) == 0 def test_loader_with_invalid_data(): From 3a5e658542546187d45520a9629958fb1a553ecd Mon Sep 17 00:00:00 2001 From: "David S. Batista" Date: Tue, 20 May 2025 21:30:49 +0200 Subject: [PATCH 06/41] separating new and old evaluator logic --- tests/test_evaluator.py | 211 +++------------------------------------- 1 file changed, 13 insertions(+), 198 deletions(-) diff --git a/tests/test_evaluator.py b/tests/test_evaluator.py index dc0ce9c..7bea536 100644 --- a/tests/test_evaluator.py +++ b/tests/test_evaluator.py @@ -1,11 +1,6 @@ # pylint: disable=too-many-lines import pandas as pd -import pytest -from nervaluate.entities import Entity -from nervaluate.evaluator import Evaluator -from nervaluate.evaluation_strategies import StrictEvaluation, PartialEvaluation, EntityTypeEvaluation - -from nervaluate import Evaluator as OldEvaluator +from nervaluate import Evaluator def test_results_to_dataframe(): @@ -13,7 +8,7 @@ def test_results_to_dataframe(): Test the results_to_dataframe method. 
""" # Setup - evaluator = OldEvaluator( + evaluator = Evaluator( true=[["B-LOC", "I-LOC", "O"], ["B-PER", "O", "O"]], pred=[["B-LOC", "I-LOC", "O"], ["B-PER", "I-PER", "O"]], tags=["LOC", "PER"], @@ -134,7 +129,7 @@ def test_evaluator_simple_case(): {"label": "LOC", "start": 3, "end": 4}, ], ] - evaluator = OldEvaluator(true, pred, tags=["LOC", "PER"]) + evaluator = Evaluator(true, pred, tags=["LOC", "PER"]) results, _, _, _ = evaluator.evaluate() expected = { "strict": { @@ -209,7 +204,7 @@ def test_evaluator_simple_case_filtered_tags(): {"label": "LOC", "start": 3, "end": 4}, ], ] - evaluator = OldEvaluator(true, pred, tags=["PER", "LOC"]) + evaluator = Evaluator(true, pred, tags=["PER", "LOC"]) results, _, _, _ = evaluator.evaluate() expected = { "strict": { @@ -276,7 +271,7 @@ def test_evaluator_extra_classes(): pred = [ [{"label": "FOO", "start": 1, "end": 3}], ] - evaluator = OldEvaluator(true, pred, tags=["ORG", "FOO"]) + evaluator = Evaluator(true, pred, tags=["ORG", "FOO"]) results, _, _, _ = evaluator.evaluate() expected = { "strict": { @@ -343,7 +338,7 @@ def test_evaluator_no_entities_in_prediction(): pred = [ [], ] - evaluator = OldEvaluator(true, pred, tags=["PER"]) + evaluator = Evaluator(true, pred, tags=["PER"]) results, _, _, _ = evaluator.evaluate() expected = { "strict": { @@ -410,7 +405,7 @@ def test_evaluator_compare_results_and_results_agg(): pred = [ [{"label": "PER", "start": 2, "end": 4}], ] - evaluator = OldEvaluator(true, pred, tags=["PER"]) + evaluator = Evaluator(true, pred, tags=["PER"]) results, results_agg, _, _ = evaluator.evaluate() expected = { "strict": { @@ -543,7 +538,7 @@ def test_evaluator_compare_results_and_results_agg_1(): [{"label": "ORG", "start": 2, "end": 4}], [{"label": "MISC", "start": 2, "end": 4}], ] - evaluator = OldEvaluator(true, pred, tags=["PER", "ORG", "MISC"]) + evaluator = Evaluator(true, pred, tags=["PER", "ORG", "MISC"]) results, results_agg, _, _ = evaluator.evaluate() expected = { "strict": { @@ -729,7 +724,7 @@ def test_evaluator_with_extra_keys_in_pred(): {"label": "LOC", "start": 3, "end": 4, "token_start": 0, "token_end": 3}, ], ] - evaluator = OldEvaluator(true, pred, tags=["LOC", "PER"]) + evaluator = Evaluator(true, pred, tags=["LOC", "PER"]) results, _, _, _ = evaluator.evaluate() expected = { "strict": { @@ -803,7 +798,7 @@ def test_evaluator_with_extra_keys_in_true(): {"label": "LOC", "start": 3, "end": 4}, ], ] - evaluator = OldEvaluator(true, pred, tags=["LOC", "PER"]) + evaluator = Evaluator(true, pred, tags=["LOC", "PER"]) results, _, _, _ = evaluator.evaluate() expected = { "strict": { @@ -876,7 +871,7 @@ def test_issue_29(): {"label": "PER", "start": 6, "end": 10}, ] ] - evaluator = OldEvaluator(true, pred, tags=["PER"]) + evaluator = Evaluator(true, pred, tags=["PER"]) results, _, _, _ = evaluator.evaluate() expected = { "strict": { @@ -943,7 +938,7 @@ def test_evaluator_compare_results_indices_and_results_agg_indices(): pred = [ [{"label": "PER", "start": 2, "end": 4}], ] - evaluator = OldEvaluator(true, pred, tags=["PER"]) + evaluator = Evaluator(true, pred, tags=["PER"]) _, _, evaluation_indices, evaluation_agg_indices = evaluator.evaluate() expected_evaluation_indices = { "strict": { @@ -1035,7 +1030,7 @@ def test_evaluator_compare_results_indices_and_results_agg_indices_1(): [{"label": "ORG", "start": 2, "end": 4}], [{"label": "MISC", "start": 2, "end": 4}], ] - evaluator = OldEvaluator(true, pred, tags=["PER", "ORG", "MISC"]) + evaluator = Evaluator(true, pred, tags=["PER", "ORG", "MISC"]) _, _, 
evaluation_indices, evaluation_agg_indices = evaluator.evaluate() expected_evaluation_indices = { @@ -1174,183 +1169,3 @@ def test_evaluator_compare_results_indices_and_results_agg_indices_1(): assert evaluation_indices["ent_type"] == expected_evaluation_indices["ent_type"] assert evaluation_indices["partial"] == expected_evaluation_indices["partial"] assert evaluation_indices["exact"] == expected_evaluation_indices["exact"] - - -@pytest.fixture -def sample_entities(): - return [ - Entity(label="PER", start=0, end=0), - Entity(label="ORG", start=2, end=3), - Entity(label="LOC", start=5, end=5), - ] - - -@pytest.fixture -def sample_predictions(): - return [ - Entity(label="PER", start=0, end=0), # Correct - Entity(label="ORG", start=2, end=2), # Partial - Entity(label="PER", start=5, end=5), # Wrong type - ] - - -def test_strict_evaluation(sample_entities, sample_predictions): - strategy = StrictEvaluation() - result, indices = strategy.evaluate(sample_entities, sample_predictions, ["PER", "ORG", "LOC"]) - - assert result.correct == 1 - assert result.incorrect == 0 - assert result.partial == 0 - assert result.missed == 2 - assert result.spurious == 2 - - -def test_partial_evaluation(sample_entities, sample_predictions): - strategy = PartialEvaluation() - result, indices = strategy.evaluate(sample_entities, sample_predictions, ["PER", "ORG", "LOC"]) - - assert result.correct == 1 - assert result.partial == 1 - assert result.incorrect == 1 - assert result.missed == 1 - assert result.spurious == 0 - - -def test_entity_type_evaluation(sample_entities, sample_predictions): - strategy = EntityTypeEvaluation() - result, indices = strategy.evaluate(sample_entities, sample_predictions, ["PER", "ORG", "LOC"]) - - assert result.correct == 2 - assert result.incorrect == 0 - assert result.partial == 0 - assert result.missed == 1 - assert result.spurious == 1 - - -def test_evaluator_integration(): - # Test with list format - true = [["O", "PER", "O", "ORG", "ORG", "LOC"]] - pred = [["O", "PER", "O", "ORG", "O", "PER"]] - - evaluator = OldEvaluator(true, pred, ["PER", "ORG", "LOC"], loader="list") - results = evaluator.evaluate() - - assert "overall" in results - assert "entities" in results - assert "strict" in results["overall"] - assert "partial" in results["overall"] - assert "ent_type" in results["overall"] - - # Test with CoNLL format - true_conll = "word\tO\nword\tPER\nword\tO\nword\tORG\nword\tORG\nword\tLOC\n\n" - pred_conll = "word\tO\nword\tPER\nword\tO\nword\tORG\nword\tO\nword\tPER\n\n" - - evaluator = OldEvaluator(true_conll, pred_conll, ["PER", "ORG", "LOC"], loader="conll") - results = evaluator.evaluate() - - assert "overall" in results - assert "entities" in results - assert "strict" in results["overall"] - assert "partial" in results["overall"] - assert "ent_type" in results["overall"] - - -@pytest.fixture -def sample_data(): - true = [ - [ - Entity(label="PER", start=0, end=0), - Entity(label="ORG", start=2, end=3), - Entity(label="LOC", start=5, end=5) - ], - [ - Entity(label="PER", start=0, end=0), - Entity(label="ORG", start=2, end=2) - ] - ] - - pred = [ - [ - Entity(label="PER", start=0, end=0), # Correct - Entity(label="ORG", start=2, end=2), # Partial - Entity(label="PER", start=5, end=5) # Wrong type - ], - [ - Entity(label="PER", start=0, end=0), # Correct - Entity(label="LOC", start=2, end=2) # Wrong type - ] - ] - - return true, pred - - -def test_evaluator_initialization(sample_data): - """Test evaluator initialization.""" - true, pred = sample_data - evaluator = 
Evaluator(true, pred, ["PER", "ORG", "LOC"]) - - assert len(evaluator.true) == 2 - assert len(evaluator.pred) == 2 - assert evaluator.tags == ["PER", "ORG", "LOC"] - - -def test_evaluator_evaluation(sample_data): - """Test evaluation process.""" - true, pred = sample_data - evaluator = Evaluator(true, pred, ["PER", "ORG", "LOC"]) - results = evaluator.evaluate() - - # Check that we have results for all strategies - assert "strict" in results - assert "partial" in results - assert "ent_type" in results - - # Check that we have results for overall and each entity type - for strategy in results: - assert "overall" in results[strategy] - assert "PER" in results[strategy] - assert "ORG" in results[strategy] - assert "LOC" in results[strategy] - - -def test_evaluator_dataframe_conversion(sample_data): - """Test conversion of results to DataFrame.""" - true, pred = sample_data - evaluator = Evaluator(true, pred, ["PER", "ORG", "LOC"]) - results = evaluator.evaluate() - df = evaluator.results_to_dataframe() - - assert isinstance(df, pd.DataFrame) - assert len(df) > 0 - assert "strategy" in df.columns - assert "entity_type" in df.columns - assert "precision" in df.columns - assert "recall" in df.columns - assert "f1" in df.columns - - -def test_evaluator_with_empty_inputs(): - """Test evaluator with empty inputs.""" - evaluator = Evaluator([], [], ["PER", "ORG", "LOC"]) - results = evaluator.evaluate() - - for strategy in results: - assert results[strategy]["overall"].correct == 0 - assert results[strategy]["overall"].incorrect == 0 - assert results[strategy]["overall"].partial == 0 - assert results[strategy]["overall"].missed == 0 - assert results[strategy]["overall"].spurious == 0 - - -def test_evaluator_with_invalid_tags(sample_data): - """Test evaluator with invalid tags.""" - true, pred = sample_data - evaluator = Evaluator(true, pred, ["INVALID"]) - results = evaluator.evaluate() - - for strategy in results: - assert results[strategy]["overall"].correct == 0 - assert results[strategy]["overall"].incorrect == 0 - assert results[strategy]["overall"].partial == 0 - assert results[strategy]["overall"].missed == 0 - assert results[strategy]["overall"].spurious == 0 From db39b2d8f0c44b1136583451b8a43ee26ef447d6 Mon Sep 17 00:00:00 2001 From: "David S. 
Batista" Date: Tue, 20 May 2025 21:55:59 +0200 Subject: [PATCH 07/41] adding refactored code --- src/nervaluate/entities.py | 74 ++++++++++ src/nervaluate/evaluation_strategies.py | 143 ++++++++++++++++++ src/nervaluate/evaluator.py | 120 ++++++++++++++++ src/nervaluate/loaders.py | 183 ++++++++++++++++++++++++ tests/test_entities.py | 46 ++++++ tests/test_evaluation_strategies.py | 89 ++++++++++++ tests/test_evaluator_new.py | 175 ++++++++++++++++++++++ 7 files changed, 830 insertions(+) create mode 100644 src/nervaluate/entities.py create mode 100644 src/nervaluate/evaluation_strategies.py create mode 100644 src/nervaluate/evaluator.py create mode 100644 src/nervaluate/loaders.py create mode 100644 tests/test_entities.py create mode 100644 tests/test_evaluation_strategies.py create mode 100644 tests/test_evaluator_new.py diff --git a/src/nervaluate/entities.py b/src/nervaluate/entities.py new file mode 100644 index 0000000..fc7598a --- /dev/null +++ b/src/nervaluate/entities.py @@ -0,0 +1,74 @@ +from dataclasses import dataclass +from typing import List + + +@dataclass +class Entity: + """Represents a named entity with its position and label.""" + + label: str + start: int + end: int + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Entity): + return NotImplemented + return self.label == other.label and self.start == other.start and self.end == other.end + + def __hash__(self) -> int: + return hash((self.label, self.start, self.end)) + + +@dataclass +class EvaluationResult: + """Represents the evaluation metrics for a single entity type or overall.""" + + correct: int = 0 + incorrect: int = 0 + partial: int = 0 + missed: int = 0 + spurious: int = 0 + precision: float = 0.0 + recall: float = 0.0 + f1: float = 0.0 + actual: int = 0 + possible: int = 0 + + def compute_metrics(self, partial_or_type: bool = False) -> None: + """Compute precision, recall and F1 score.""" + self.actual = self.correct + self.incorrect + self.partial + self.spurious + self.possible = self.correct + self.incorrect + self.partial + self.missed + + if partial_or_type: + precision = (self.correct + 0.5 * self.partial) / self.actual if self.actual > 0 else 0 + recall = (self.correct + 0.5 * self.partial) / self.possible if self.possible > 0 else 0 + else: + precision = self.correct / self.actual if self.actual > 0 else 0 + recall = self.correct / self.possible if self.possible > 0 else 0 + + self.precision = precision + self.recall = recall + self.f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0 + + +@dataclass +class EvaluationIndices: + """Represents the indices of entities in different evaluation categories.""" + + correct_indices: List[tuple[int, int]] = None + incorrect_indices: List[tuple[int, int]] = None + partial_indices: List[tuple[int, int]] = None + missed_indices: List[tuple[int, int]] = None + spurious_indices: List[tuple[int, int]] = None + + def __post_init__(self): + if self.correct_indices is None: + self.correct_indices = [] + if self.incorrect_indices is None: + self.incorrect_indices = [] + if self.partial_indices is None: + self.partial_indices = [] + if self.missed_indices is None: + self.missed_indices = [] + if self.spurious_indices is None: + self.spurious_indices = [] diff --git a/src/nervaluate/evaluation_strategies.py b/src/nervaluate/evaluation_strategies.py new file mode 100644 index 0000000..ccb19c9 --- /dev/null +++ b/src/nervaluate/evaluation_strategies.py @@ -0,0 +1,143 @@ +from abc import ABC, abstractmethod +from typing 
import List, Tuple + +from .entities import Entity, EvaluationResult, EvaluationIndices + + +class EvaluationStrategy(ABC): + """Abstract base class for evaluation strategies.""" + + @abstractmethod + def evaluate( + self, true_entities: List[Entity], pred_entities: List[Entity], tags: List[str], instance_index: int = 0 + ) -> Tuple[EvaluationResult, EvaluationIndices]: + """Evaluate the predicted entities against the true entities.""" + + +class StrictEvaluation(EvaluationStrategy): + """Strict evaluation strategy - entities must match exactly.""" + + def evaluate( + self, true_entities: List[Entity], pred_entities: List[Entity], tags: List[str], instance_index: int = 0 + ) -> Tuple[EvaluationResult, EvaluationIndices]: + """ + Evaluate the predicted entities against the true entities using strict matching. + """ + + result = EvaluationResult() + indices = EvaluationIndices() + + for pred_idx, pred in enumerate(pred_entities): + + print(pred) + print(pred.start, pred.end) + print(pred.label) + print(true_entities) + + if pred in true_entities: + result.correct += 1 + indices.correct_indices.append((instance_index, pred_idx)) + else: + result.spurious += 1 + indices.spurious_indices.append((instance_index, pred_idx)) + + for true_idx, true in enumerate(true_entities): + if true not in pred_entities: + result.missed += 1 + indices.missed_indices.append((instance_index, true_idx)) + + result.compute_metrics() + return result, indices + + +class PartialEvaluation(EvaluationStrategy): + """Partial evaluation strategy - allows for partial matches.""" + + def evaluate( + self, true_entities: List[Entity], pred_entities: List[Entity], tags: List[str], instance_index: int = 0 + ) -> Tuple[EvaluationResult, EvaluationIndices]: + result = EvaluationResult() + indices = EvaluationIndices() + matched_true = set() + + for pred_idx, pred in enumerate(pred_entities): + found_match = False + + for true_idx, true in enumerate(true_entities): + if true_idx in matched_true: + continue + + # Check for overlap + if pred.start <= true.end and pred.end >= true.start: + if pred.label == true.label: + if pred.start == true.start and pred.end == true.end: + result.correct += 1 + indices.correct_indices.append((instance_index, pred_idx)) + else: + result.partial += 1 + indices.partial_indices.append((instance_index, pred_idx)) + matched_true.add(true_idx) + found_match = True + break + + result.incorrect += 1 + indices.incorrect_indices.append((instance_index, pred_idx)) + found_match = True + break + + if not found_match: + result.spurious += 1 + indices.spurious_indices.append((instance_index, pred_idx)) + + for true_idx, true in enumerate(true_entities): + if true_idx not in matched_true: + result.missed += 1 + indices.missed_indices.append((instance_index, true_idx)) + + result.compute_metrics(partial_or_type=True) + return result, indices + + +class EntityTypeEvaluation(EvaluationStrategy): + """ + Entity type evaluation strategy - only checks entity types. + + Some overlap between the system tagged entity and the gold annotation is required. 
+ # ToDo: define a minimum overlap threshold - see: https://github.com/MantisAI/nervaluate/pull/83 + """ + + def evaluate( + self, true_entities: List[Entity], pred_entities: List[Entity], tags: List[str], instance_index: int = 0 + ) -> Tuple[EvaluationResult, EvaluationIndices]: + result = EvaluationResult() + indices = EvaluationIndices() + + for pred_idx, pred in enumerate(pred_entities): + found_match = False + found_overlap = False + for true_idx, true in enumerate(true_entities): + + print(f"Checking {pred} against {true}") + + # check for a minimum overlap between the system tagged entity and the gold annotation + if pred.start <= true.end and pred.end >= true.start: + found_overlap = True + + # check if the labels match + if found_overlap and pred.label == true.label: + result.correct += 1 + indices.correct_indices.append((instance_index, pred_idx)) + found_match = True + break + + if not found_match: + result.spurious += 1 + indices.spurious_indices.append((instance_index, pred_idx)) + + for true_idx, true in enumerate(true_entities): + if not any(pred.label == true.label for pred in pred_entities): + result.missed += 1 + indices.missed_indices.append((instance_index, true_idx)) + + result.compute_metrics(partial_or_type=True) + return result, indices diff --git a/src/nervaluate/evaluator.py b/src/nervaluate/evaluator.py new file mode 100644 index 0000000..c697d82 --- /dev/null +++ b/src/nervaluate/evaluator.py @@ -0,0 +1,120 @@ +from typing import List, Dict, Any +import pandas as pd + +from .entities import EvaluationResult +from .evaluation_strategies import EvaluationStrategy, StrictEvaluation, PartialEvaluation, EntityTypeEvaluation +from .loaders import DataLoader, ConllLoader, ListLoader, DictLoader + + +class Evaluator: + """Main evaluator class for NER evaluation.""" + + def __init__(self, true: Any, pred: Any, tags: List[str], loader: str = "default") -> None: + """ + Initialize the evaluator. + + Args: + true: True entities in any supported format + pred: Predicted entities in any supported format + tags: List of valid entity tags + loader: Name of the loader to use + """ + self.tags = tags + self._setup_loaders() + self._load_data(true, pred, loader) + self._setup_evaluation_strategies() + + def _setup_loaders(self) -> None: + """Setup available data loaders.""" + self.loaders: Dict[str, DataLoader] = {"conll": ConllLoader(), "list": ListLoader(), "dict": DictLoader()} + + def _setup_evaluation_strategies(self) -> None: + """Setup evaluation strategies.""" + self.strategies: Dict[str, EvaluationStrategy] = { + "strict": StrictEvaluation(), + "partial": PartialEvaluation(), + "ent_type": EntityTypeEvaluation(), + } + + def _load_data(self, true: Any, pred: Any, loader: str) -> None: + """Load the true and predicted data.""" + if loader == "default": + # Try to infer the loader based on input type + if isinstance(true, str): + loader = "conll" + elif isinstance(true, list) and true and isinstance(true[0], list): + if isinstance(true[0][0], dict): + loader = "dict" + else: + loader = "list" + else: + raise ValueError("Could not infer loader from input type") + + if loader not in self.loaders: + raise ValueError(f"Unknown loader: {loader}") + + self.true = self.loaders[loader].load(true) + self.pred = self.loaders[loader].load(pred) + + if len(self.true) != len(self.pred): + raise ValueError("Number of predicted documents does not equal true") + + def evaluate(self) -> Dict[str, Any]: + """ + Run the evaluation. 
+ + Returns: + Dictionary containing evaluation results for each strategy and entity type + """ + results = {} + entity_results = {tag: {} for tag in self.tags} + + # Evaluate each document + for doc_idx, (true_doc, pred_doc) in enumerate(zip(self.true, self.pred)): + # Filter entities by valid tags + true_doc = [e for e in true_doc if e.label in self.tags] + pred_doc = [e for e in pred_doc if e.label in self.tags] + + # Evaluate with each strategy + for strategy_name, strategy in self.strategies.items(): + result, _ = strategy.evaluate(true_doc, pred_doc, self.tags, doc_idx) + + # Update overall results + if strategy_name not in results: + results[strategy_name] = result + else: + self._merge_results(results[strategy_name], result) + + # Update entity-specific results + for tag in self.tags: + if tag not in entity_results: + entity_results[tag] = {} + if strategy_name not in entity_results[tag]: + entity_results[tag][strategy_name] = result + else: + self._merge_results(entity_results[tag][strategy_name], result) + + return {"overall": results, "entities": entity_results} + + def _merge_results(self, target: EvaluationResult, source: EvaluationResult) -> None: + """Merge two evaluation results.""" + target.correct += source.correct + target.incorrect += source.incorrect + target.partial += source.partial + target.missed += source.missed + target.spurious += source.spurious + target.compute_metrics() + + def results_to_dataframe(self) -> pd.DataFrame: + """Convert results to a pandas DataFrame.""" + results = self.evaluate() + + # Flatten the results structure + flat_results = {} + for category, category_results in results.items(): + for strategy, strategy_results in category_results.items(): + for metric, value in strategy_results.__dict__.items(): + key = f"{category}.{strategy}.{metric}" + flat_results[key] = value + + return pd.DataFrame([flat_results]) diff --git a/src/nervaluate/loaders.py b/src/nervaluate/loaders.py new file mode 100644 index 0000000..28d70be --- /dev/null +++ b/src/nervaluate/loaders.py @@ -0,0 +1,183 @@ +from abc import ABC, abstractmethod +from typing import List, Dict, Any + +from .entities import Entity + + +class DataLoader(ABC): + """Abstract base class for data loaders.""" + + @abstractmethod + def load(self, data: Any) -> List[List[Entity]]: + """Load data into a list of entity lists.""" + + +class ConllLoader(DataLoader): + """Loader for CoNLL format data.""" + + def load(self, data: str) -> List[List[Entity]]: + """Load CoNLL format data into a list of Entity lists.""" + if not isinstance(data, str): + raise ValueError("ConllLoader expects string input") + + if not data: + return [] + + result = [] + # Strip trailing whitespace and newlines to avoid empty documents + documents = data.rstrip().split("\n\n") + + for doc in documents: + if not doc.strip(): + result.append([]) + continue + + current_doc = [] + start_offset = None + end_offset = None + ent_type = None + has_entities = False + + for offset, line in enumerate(doc.split("\n")): + if not line.strip(): + continue + + parts = line.split("\t") + if len(parts) < 2: + raise ValueError(f"Invalid CoNLL format: line '{line}' does not contain a tab separator") + + token_tag = parts[1] + + if token_tag == "O": + if ent_type is not None and start_offset is not None: + end_offset = offset - 1 + current_doc.append(Entity(label=ent_type, start=start_offset, end=end_offset)) + start_offset = None + end_offset = None + ent_type = None + + elif ent_type is None: + if not (token_tag.startswith("B-") or 
token_tag.startswith("I-")): + raise ValueError(f"Invalid tag format: {token_tag}") + ent_type = token_tag[2:] # Remove B- or I- prefix + start_offset = offset + has_entities = True + + elif ent_type != token_tag[2:] or (ent_type == token_tag[2:] and token_tag[:1] == "B"): + end_offset = offset - 1 + current_doc.append(Entity(label=ent_type, start=start_offset, end=end_offset)) + + # start of a new entity + if not (token_tag.startswith("B-") or token_tag.startswith("I-")): + raise ValueError(f"Invalid tag format: {token_tag}") + ent_type = token_tag[2:] + start_offset = offset + end_offset = None + has_entities = True + + # Catches an entity that goes up until the last token + if ent_type is not None and start_offset is not None and end_offset is None: + current_doc.append(Entity(label=ent_type, start=start_offset, end=len(doc.split("\n")) - 1)) + has_entities = True + + result.append(current_doc if has_entities else []) + + return result + + +class ListLoader(DataLoader): + """Loader for list format data.""" + + def load(self, data: List[List[str]]) -> List[List[Entity]]: + """Load list format data into a list of entity lists.""" + if not isinstance(data, list): + raise ValueError("ListLoader expects list input") + + if not data: + return [] + + result = [] + + for doc in data: + if not isinstance(doc, list): + raise ValueError("Each document must be a list of tags") + + current_doc = [] + start_offset = None + end_offset = None + ent_type = None + + for offset, token_tag in enumerate(doc): + if not isinstance(token_tag, str): + raise ValueError(f"Invalid tag type: {type(token_tag)}") + + if token_tag == "O": + if ent_type is not None and start_offset is not None: + end_offset = offset - 1 + current_doc.append(Entity(label=ent_type, start=start_offset, end=end_offset)) + start_offset = None + end_offset = None + ent_type = None + + elif ent_type is None: + if not (token_tag.startswith("B-") or token_tag.startswith("I-")): + raise ValueError(f"Invalid tag format: {token_tag}") + ent_type = token_tag[2:] # Remove B- or I- prefix + start_offset = offset + + elif ent_type != token_tag[2:] or (ent_type == token_tag[2:] and token_tag[:1] == "B"): + end_offset = offset - 1 + current_doc.append(Entity(label=ent_type, start=start_offset, end=end_offset)) + + # start of a new entity + if not (token_tag.startswith("B-") or token_tag.startswith("I-")): + raise ValueError(f"Invalid tag format: {token_tag}") + ent_type = token_tag[2:] + start_offset = offset + end_offset = None + + # Catches an entity that goes up until the last token + if ent_type is not None and start_offset is not None and end_offset is None: + current_doc.append(Entity(label=ent_type, start=start_offset, end=len(doc) - 1)) + + result.append(current_doc) + + return result + + +class DictLoader(DataLoader): + """Loader for dictionary format data.""" + + def load(self, data: List[List[Dict[str, Any]]]) -> List[List[Entity]]: + """Load dictionary format data into a list of entity lists.""" + if not isinstance(data, list): + raise ValueError("DictLoader expects list input") + + if not data: + return [] + + result = [] + + for doc in data: + if not isinstance(doc, list): + raise ValueError("Each document must be a list of entity dictionaries") + + current_doc = [] + for entity in doc: + if not isinstance(entity, dict): + raise ValueError(f"Invalid entity type: {type(entity)}") + + required_keys = {"label", "start", "end"} + if not all(key in entity for key in required_keys): + raise ValueError(f"Entity missing required keys: 
{required_keys}") + + if not isinstance(entity["label"], str): + raise ValueError("Entity label must be a string") + + if not isinstance(entity["start"], int) or not isinstance(entity["end"], int): + raise ValueError("Entity start and end must be integers") + + current_doc.append(Entity(label=entity["label"], start=entity["start"], end=entity["end"])) + result.append(current_doc) + + return result diff --git a/tests/test_entities.py b/tests/test_entities.py new file mode 100644 index 0000000..31b962a --- /dev/null +++ b/tests/test_entities.py @@ -0,0 +1,46 @@ +from nervaluate.entities import Entity, EvaluationResult + + +def test_entity_equality(): + """Test Entity equality comparison.""" + entity1 = Entity(label="PER", start=0, end=1) + entity2 = Entity(label="PER", start=0, end=1) + entity3 = Entity(label="ORG", start=0, end=1) + + assert entity1 == entity2 + assert entity1 != entity3 + assert entity1 != "not an entity" + + +def test_entity_hash(): + """Test Entity hashing.""" + entity1 = Entity(label="PER", start=0, end=1) + entity2 = Entity(label="PER", start=0, end=1) + entity3 = Entity(label="ORG", start=0, end=1) + + assert hash(entity1) == hash(entity2) + assert hash(entity1) != hash(entity3) + + +def test_evaluation_result_compute_metrics(): + """Test computation of evaluation metrics.""" + result = EvaluationResult(correct=5, incorrect=2, partial=1, missed=1, spurious=1) + + # Test strict metrics + result.compute_metrics(partial_or_type=False) + assert result.precision == 5 / 9 # 5/(5+2+1+1) + assert result.recall == 5 / (5 + 2 + 1 + 1) + + # Test partial metrics + result.compute_metrics(partial_or_type=True) + assert result.precision == 5.5 / 9 # (5+0.5*1)/(5+2+1+1) + assert result.recall == (5 + 0.5 * 1) / (5 + 2 + 1 + 1) + + +def test_evaluation_result_zero_cases(): + """Test evaluation metrics with zero values.""" + result = EvaluationResult() + result.compute_metrics() + assert result.precision == 0 + assert result.recall == 0 + assert result.f1 == 0 diff --git a/tests/test_evaluation_strategies.py b/tests/test_evaluation_strategies.py new file mode 100644 index 0000000..88cb548 --- /dev/null +++ b/tests/test_evaluation_strategies.py @@ -0,0 +1,89 @@ +import pytest +from nervaluate.entities import Entity +from nervaluate.evaluation_strategies import StrictEvaluation, PartialEvaluation, EntityTypeEvaluation + + +@pytest.fixture +def sample_entities(): + return [ + Entity(label="PER", start=0, end=0), + Entity(label="ORG", start=2, end=3), + Entity(label="LOC", start=5, end=5), + ] + + +@pytest.fixture +def sample_predictions(): + return [ + Entity(label="PER", start=0, end=0), # Correct + Entity(label="ORG", start=2, end=2), # Partial + Entity(label="PER", start=5, end=5), # Wrong type + ] + + +# pylint: disable=redefined-outer-name +def test_strict_evaluation(sample_entities, sample_predictions): + """Test strict evaluation strategy.""" + strategy = StrictEvaluation() + result, indices = strategy.evaluate(sample_entities, sample_predictions, ["PER", "ORG", "LOC"]) + + assert result.correct == 1 + assert result.incorrect == 0 + assert result.partial == 0 + assert result.missed == 2 + assert result.spurious == 2 + + assert len(indices.correct_indices) == 1 + assert len(indices.missed_indices) == 2 + assert len(indices.spurious_indices) == 2 + + +# pylint: disable=redefined-outer-name +def test_partial_evaluation(sample_entities, sample_predictions): + """Test partial evaluation strategy.""" + strategy = PartialEvaluation() + result, indices = strategy.evaluate(sample_entities, 
sample_predictions, ["PER", "ORG", "LOC"]) + + assert result.correct == 1 + assert result.partial == 1 + assert result.incorrect == 1 + assert result.missed == 1 + assert result.spurious == 0 + + assert len(indices.correct_indices) == 1 + assert len(indices.partial_indices) == 1 + assert len(indices.incorrect_indices) == 1 + assert len(indices.missed_indices) == 1 + + +# pylint: disable=redefined-outer-name +def test_entity_type_evaluation(sample_entities, sample_predictions): + """Test entity type evaluation strategy.""" + strategy = EntityTypeEvaluation() + result, indices = strategy.evaluate(sample_entities, sample_predictions, ["PER", "ORG", "LOC"]) + + assert result.correct == 2 + assert result.incorrect == 0 + assert result.partial == 0 + assert result.missed == 1 + assert result.spurious == 1 + + assert len(indices.correct_indices) == 2 + assert len(indices.missed_indices) == 1 + assert len(indices.spurious_indices) == 1 + + +def test_evaluation_with_empty_inputs(): + """Test evaluation with empty inputs.""" + strategy = StrictEvaluation() + result, indices = strategy.evaluate([], [], ["PER", "ORG", "LOC"]) + + assert result.correct == 0 + assert result.incorrect == 0 + assert result.partial == 0 + assert result.missed == 0 + assert result.spurious == 0 + + assert len(indices.correct_indices) == 0 + assert len(indices.missed_indices) == 0 + assert len(indices.spurious_indices) == 0 diff --git a/tests/test_evaluator_new.py b/tests/test_evaluator_new.py new file mode 100644 index 0000000..8ff7f04 --- /dev/null +++ b/tests/test_evaluator_new.py @@ -0,0 +1,175 @@ +import pandas as pd +import pytest +from nervaluate.entities import Entity +from nervaluate.evaluator import Evaluator +from nervaluate.evaluation_strategies import StrictEvaluation, PartialEvaluation, EntityTypeEvaluation + + +@pytest.fixture +def sample_entities(): + return [ + Entity(label="PER", start=0, end=0), + Entity(label="ORG", start=2, end=3), + Entity(label="LOC", start=5, end=5), + ] + + +@pytest.fixture +def sample_predictions(): + return [ + Entity(label="PER", start=0, end=0), # Correct + Entity(label="ORG", start=2, end=2), # Partial + Entity(label="PER", start=5, end=5), # Wrong type + ] + + +def test_strict_evaluation(sample_entities, sample_predictions): + strategy = StrictEvaluation() + result, indices = strategy.evaluate(sample_entities, sample_predictions, ["PER", "ORG", "LOC"]) + + assert result.correct == 1 + assert result.incorrect == 0 + assert result.partial == 0 + assert result.missed == 2 + assert result.spurious == 2 + + +def test_partial_evaluation(sample_entities, sample_predictions): + strategy = PartialEvaluation() + result, indices = strategy.evaluate(sample_entities, sample_predictions, ["PER", "ORG", "LOC"]) + + assert result.correct == 1 + assert result.partial == 1 + assert result.incorrect == 1 + assert result.missed == 1 + assert result.spurious == 0 + + +def test_entity_type_evaluation(sample_entities, sample_predictions): + strategy = EntityTypeEvaluation() + result, indices = strategy.evaluate(sample_entities, sample_predictions, ["PER", "ORG", "LOC"]) + + assert result.correct == 2 + assert result.incorrect == 0 + assert result.partial == 0 + assert result.missed == 1 + assert result.spurious == 1 + + +def test_evaluator_integration(): + # Test with list format + true = [["O", "PER", "O", "ORG", "ORG", "LOC"]] + pred = [["O", "PER", "O", "ORG", "O", "PER"]] + + evaluator = OldEvaluator(true, pred, ["PER", "ORG", "LOC"], loader="list") + results = evaluator.evaluate() + + assert 
"overall" in results + assert "entities" in results + assert "strict" in results["overall"] + assert "partial" in results["overall"] + assert "ent_type" in results["overall"] + + # Test with CoNLL format + true_conll = "word\tO\nword\tPER\nword\tO\nword\tORG\nword\tORG\nword\tLOC\n\n" + pred_conll = "word\tO\nword\tPER\nword\tO\nword\tORG\nword\tO\nword\tPER\n\n" + + evaluator = OldEvaluator(true_conll, pred_conll, ["PER", "ORG", "LOC"], loader="conll") + results = evaluator.evaluate() + + assert "overall" in results + assert "entities" in results + assert "strict" in results["overall"] + assert "partial" in results["overall"] + assert "ent_type" in results["overall"] + + +@pytest.fixture +def sample_data(): + true = [ + [Entity(label="PER", start=0, end=0), Entity(label="ORG", start=2, end=3), Entity(label="LOC", start=5, end=5)], + [Entity(label="PER", start=0, end=0), Entity(label="ORG", start=2, end=2)], + ] + + pred = [ + [ + Entity(label="PER", start=0, end=0), # Correct + Entity(label="ORG", start=2, end=2), # Partial + Entity(label="PER", start=5, end=5), # Wrong type + ], + [Entity(label="PER", start=0, end=0), Entity(label="LOC", start=2, end=2)], # Correct # Wrong type + ] + + return true, pred + + +def test_evaluator_initialization(sample_data): + """Test evaluator initialization.""" + true, pred = sample_data + evaluator = Evaluator(true, pred, ["PER", "ORG", "LOC"]) + + assert len(evaluator.true) == 2 + assert len(evaluator.pred) == 2 + assert evaluator.tags == ["PER", "ORG", "LOC"] + + +def test_evaluator_evaluation(sample_data): + """Test evaluation process.""" + true, pred = sample_data + evaluator = Evaluator(true, pred, ["PER", "ORG", "LOC"]) + results = evaluator.evaluate() + + # Check that we have results for all strategies + assert "strict" in results + assert "partial" in results + assert "ent_type" in results + + # Check that we have results for overall and each entity type + for strategy in results: + assert "overall" in results[strategy] + assert "PER" in results[strategy] + assert "ORG" in results[strategy] + assert "LOC" in results[strategy] + + +def test_evaluator_dataframe_conversion(sample_data): + """Test conversion of results to DataFrame.""" + true, pred = sample_data + evaluator = Evaluator(true, pred, ["PER", "ORG", "LOC"]) + results = evaluator.evaluate() + df = evaluator.results_to_dataframe() + + assert isinstance(df, pd.DataFrame) + assert len(df) > 0 + assert "strategy" in df.columns + assert "entity_type" in df.columns + assert "precision" in df.columns + assert "recall" in df.columns + assert "f1" in df.columns + + +def test_evaluator_with_empty_inputs(): + """Test evaluator with empty inputs.""" + evaluator = Evaluator([], [], ["PER", "ORG", "LOC"]) + results = evaluator.evaluate() + + for strategy in results: + assert results[strategy]["overall"].correct == 0 + assert results[strategy]["overall"].incorrect == 0 + assert results[strategy]["overall"].partial == 0 + assert results[strategy]["overall"].missed == 0 + assert results[strategy]["overall"].spurious == 0 + + +def test_evaluator_with_invalid_tags(sample_data): + """Test evaluator with invalid tags.""" + true, pred = sample_data + evaluator = Evaluator(true, pred, ["INVALID"]) + results = evaluator.evaluate() + + for strategy in results: + assert results[strategy]["overall"].correct == 0 + assert results[strategy]["overall"].incorrect == 0 + assert results[strategy]["overall"].partial == 0 + assert results[strategy]["overall"].missed == 0 + assert results[strategy]["overall"].spurious == 0 
From 93a39075a3886b4a7db9b3afdb732e1229854871 Mon Sep 17 00:00:00 2001 From: "David S. Batista" Date: Tue, 20 May 2025 22:44:23 +0200 Subject: [PATCH 08/41] fixing all tests --- src/nervaluate/evaluation_strategies.py | 8 - tests/test_evaluator_new.py | 215 ++++++++++++------------ 2 files changed, 103 insertions(+), 120 deletions(-) diff --git a/src/nervaluate/evaluation_strategies.py b/src/nervaluate/evaluation_strategies.py index ccb19c9..9f6195b 100644 --- a/src/nervaluate/evaluation_strategies.py +++ b/src/nervaluate/evaluation_strategies.py @@ -28,12 +28,6 @@ def evaluate( indices = EvaluationIndices() for pred_idx, pred in enumerate(pred_entities): - - print(pred) - print(pred.start, pred.end) - print(pred.label) - print(true_entities) - if pred in true_entities: result.correct += 1 indices.correct_indices.append((instance_index, pred_idx)) @@ -117,8 +111,6 @@ def evaluate( found_overlap = False for true_idx, true in enumerate(true_entities): - print(f"Checking {pred} against {true}") - # check for a minimum overlap between the system tagged entity and the gold annotation if pred.start <= true.end and pred.end >= true.start: found_overlap = True diff --git a/tests/test_evaluator_new.py b/tests/test_evaluator_new.py index 8ff7f04..32794cc 100644 --- a/tests/test_evaluator_new.py +++ b/tests/test_evaluator_new.py @@ -1,103 +1,121 @@ -import pandas as pd import pytest -from nervaluate.entities import Entity from nervaluate.evaluator import Evaluator -from nervaluate.evaluation_strategies import StrictEvaluation, PartialEvaluation, EntityTypeEvaluation @pytest.fixture def sample_entities(): return [ - Entity(label="PER", start=0, end=0), - Entity(label="ORG", start=2, end=3), - Entity(label="LOC", start=5, end=5), + ["O", "B-PER", "O", "B-ORG", "I-ORG", "B-LOC"], ] @pytest.fixture def sample_predictions(): return [ - Entity(label="PER", start=0, end=0), # Correct - Entity(label="ORG", start=2, end=2), # Partial - Entity(label="PER", start=5, end=5), # Wrong type + ["O", "B-PER", "O", "B-ORG", "O", "B-PER"], ] def test_strict_evaluation(sample_entities, sample_predictions): - strategy = StrictEvaluation() - result, indices = strategy.evaluate(sample_entities, sample_predictions, ["PER", "ORG", "LOC"]) + evaluator = Evaluator(sample_entities, sample_predictions, ["PER", "ORG", "LOC"], loader="list") + results = evaluator.evaluate() - assert result.correct == 1 - assert result.incorrect == 0 - assert result.partial == 0 - assert result.missed == 2 - assert result.spurious == 2 + # Test overall results + assert results["overall"]["strict"].correct == 1 + assert results["overall"]["strict"].incorrect == 0 + assert results["overall"]["strict"].partial == 0 + assert results["overall"]["strict"].missed == 2 + assert results["overall"]["strict"].spurious == 2 + assert results["overall"]["strict"].precision == 0.3333333333333333 + assert results["overall"]["strict"].recall == 0.3333333333333333 + assert results["overall"]["strict"].f1 == 0.3333333333333333 + assert results["overall"]["strict"].actual == 3 + assert results["overall"]["strict"].possible == 3 + + # Test entity-specific results + for entity in ["PER", "ORG", "LOC"]: + assert results["entities"][entity]["strict"].correct == 1 + assert results["entities"][entity]["strict"].incorrect == 0 + assert results["entities"][entity]["strict"].partial == 0 + assert results["entities"][entity]["strict"].missed == 2 + assert results["entities"][entity]["strict"].spurious == 2 + assert results["entities"][entity]["strict"].precision == 
0.3333333333333333 + assert results["entities"][entity]["strict"].recall == 0.3333333333333333 + assert results["entities"][entity]["strict"].f1 == 0.3333333333333333 + assert results["entities"][entity]["strict"].actual == 3 + assert results["entities"][entity]["strict"].possible == 3 def test_partial_evaluation(sample_entities, sample_predictions): - strategy = PartialEvaluation() - result, indices = strategy.evaluate(sample_entities, sample_predictions, ["PER", "ORG", "LOC"]) - - assert result.correct == 1 - assert result.partial == 1 - assert result.incorrect == 1 - assert result.missed == 1 - assert result.spurious == 0 - - -def test_entity_type_evaluation(sample_entities, sample_predictions): - strategy = EntityTypeEvaluation() - result, indices = strategy.evaluate(sample_entities, sample_predictions, ["PER", "ORG", "LOC"]) - - assert result.correct == 2 - assert result.incorrect == 0 - assert result.partial == 0 - assert result.missed == 1 - assert result.spurious == 1 - - -def test_evaluator_integration(): - # Test with list format - true = [["O", "PER", "O", "ORG", "ORG", "LOC"]] - pred = [["O", "PER", "O", "ORG", "O", "PER"]] - - evaluator = OldEvaluator(true, pred, ["PER", "ORG", "LOC"], loader="list") + evaluator = Evaluator(sample_entities, sample_predictions, ["PER", "ORG", "LOC"], loader="list") results = evaluator.evaluate() - assert "overall" in results - assert "entities" in results - assert "strict" in results["overall"] - assert "partial" in results["overall"] - assert "ent_type" in results["overall"] + # Test overall results + assert results["overall"]["partial"].correct == 1 + assert results["overall"]["partial"].partial == 1 + assert results["overall"]["partial"].incorrect == 1 + assert results["overall"]["partial"].missed == 1 + assert results["overall"]["partial"].spurious == 0 + assert results["overall"]["partial"].precision == 0.5 + assert results["overall"]["partial"].recall == 0.375 + assert results["overall"]["partial"].f1 == 0.42857142857142855 + assert results["overall"]["partial"].actual == 3 + assert results["overall"]["partial"].possible == 4 + + # Test entity-specific results + for entity in ["PER", "ORG", "LOC"]: + assert results["entities"][entity]["partial"].correct == 1 + assert results["entities"][entity]["partial"].partial == 1 + assert results["entities"][entity]["partial"].incorrect == 1 + assert results["entities"][entity]["partial"].missed == 1 + assert results["entities"][entity]["partial"].spurious == 0 + assert results["entities"][entity]["partial"].precision == 0.5 + assert results["entities"][entity]["partial"].recall == 0.375 + assert results["entities"][entity]["partial"].f1 == 0.42857142857142855 + assert results["entities"][entity]["partial"].actual == 3 + assert results["entities"][entity]["partial"].possible == 4 - # Test with CoNLL format - true_conll = "word\tO\nword\tPER\nword\tO\nword\tORG\nword\tORG\nword\tLOC\n\n" - pred_conll = "word\tO\nword\tPER\nword\tO\nword\tORG\nword\tO\nword\tPER\n\n" - evaluator = OldEvaluator(true_conll, pred_conll, ["PER", "ORG", "LOC"], loader="conll") +def test_entity_type_evaluation(sample_entities, sample_predictions): + evaluator = Evaluator(sample_entities, sample_predictions, ["PER", "ORG", "LOC"], loader="list") results = evaluator.evaluate() - assert "overall" in results - assert "entities" in results - assert "strict" in results["overall"] - assert "partial" in results["overall"] - assert "ent_type" in results["overall"] + # Test overall results + assert 
results["overall"]["ent_type"].correct == 2 + assert results["overall"]["ent_type"].incorrect == 0 + assert results["overall"]["ent_type"].partial == 0 + assert results["overall"]["ent_type"].missed == 1 + assert results["overall"]["ent_type"].spurious == 1 + assert results["overall"]["ent_type"].precision == 0.6666666666666666 + assert results["overall"]["ent_type"].recall == 0.6666666666666666 + assert results["overall"]["ent_type"].f1 == 0.6666666666666666 + assert results["overall"]["ent_type"].actual == 3 + assert results["overall"]["ent_type"].possible == 3 + + # Test entity-specific results + for entity in ["PER", "ORG", "LOC"]: + assert results["entities"][entity]["ent_type"].correct == 2 + assert results["entities"][entity]["ent_type"].incorrect == 0 + assert results["entities"][entity]["ent_type"].partial == 0 + assert results["entities"][entity]["ent_type"].missed == 1 + assert results["entities"][entity]["ent_type"].spurious == 1 + assert results["entities"][entity]["ent_type"].precision == 0.6666666666666666 + assert results["entities"][entity]["ent_type"].recall == 0.6666666666666666 + assert results["entities"][entity]["ent_type"].f1 == 0.6666666666666666 + assert results["entities"][entity]["ent_type"].actual == 3 + assert results["entities"][entity]["ent_type"].possible == 3 @pytest.fixture def sample_data(): true = [ - [Entity(label="PER", start=0, end=0), Entity(label="ORG", start=2, end=3), Entity(label="LOC", start=5, end=5)], - [Entity(label="PER", start=0, end=0), Entity(label="ORG", start=2, end=2)], + ["O", "B-PER", "O", "B-ORG", "I-ORG", "B-LOC"], + ["O", "B-PER", "O", "B-ORG"], ] pred = [ - [ - Entity(label="PER", start=0, end=0), # Correct - Entity(label="ORG", start=2, end=2), # Partial - Entity(label="PER", start=5, end=5), # Wrong type - ], - [Entity(label="PER", start=0, end=0), Entity(label="LOC", start=2, end=2)], # Correct # Wrong type + ["O", "B-PER", "O", "B-ORG", "O", "B-PER"], + ["O", "B-PER", "O", "B-LOC"], ] return true, pred @@ -106,7 +124,7 @@ def sample_data(): def test_evaluator_initialization(sample_data): """Test evaluator initialization.""" true, pred = sample_data - evaluator = Evaluator(true, pred, ["PER", "ORG", "LOC"]) + evaluator = Evaluator(true, pred, ["PER", "ORG", "LOC"], loader="list") assert len(evaluator.true) == 2 assert len(evaluator.pred) == 2 @@ -116,60 +134,33 @@ def test_evaluator_initialization(sample_data): def test_evaluator_evaluation(sample_data): """Test evaluation process.""" true, pred = sample_data - evaluator = Evaluator(true, pred, ["PER", "ORG", "LOC"]) + evaluator = Evaluator(true, pred, ["PER", "ORG", "LOC"], loader="list") results = evaluator.evaluate() # Check that we have results for all strategies - assert "strict" in results - assert "partial" in results - assert "ent_type" in results - - # Check that we have results for overall and each entity type - for strategy in results: - assert "overall" in results[strategy] - assert "PER" in results[strategy] - assert "ORG" in results[strategy] - assert "LOC" in results[strategy] - - -def test_evaluator_dataframe_conversion(sample_data): - """Test conversion of results to DataFrame.""" - true, pred = sample_data - evaluator = Evaluator(true, pred, ["PER", "ORG", "LOC"]) - results = evaluator.evaluate() - df = evaluator.results_to_dataframe() - - assert isinstance(df, pd.DataFrame) - assert len(df) > 0 - assert "strategy" in df.columns - assert "entity_type" in df.columns - assert "precision" in df.columns - assert "recall" in df.columns - assert "f1" in df.columns 
- - -def test_evaluator_with_empty_inputs(): - """Test evaluator with empty inputs.""" - evaluator = Evaluator([], [], ["PER", "ORG", "LOC"]) - results = evaluator.evaluate() + assert "overall" in results + assert "entities" in results + assert "strict" in results["overall"] + assert "partial" in results["overall"] + assert "ent_type" in results["overall"] - for strategy in results: - assert results[strategy]["overall"].correct == 0 - assert results[strategy]["overall"].incorrect == 0 - assert results[strategy]["overall"].partial == 0 - assert results[strategy]["overall"].missed == 0 - assert results[strategy]["overall"].spurious == 0 + # Check that we have results for each entity type + for entity in ["PER", "ORG", "LOC"]: + assert entity in results["entities"] + assert "strict" in results["entities"][entity] + assert "partial" in results["entities"][entity] + assert "ent_type" in results["entities"][entity] def test_evaluator_with_invalid_tags(sample_data): """Test evaluator with invalid tags.""" true, pred = sample_data - evaluator = Evaluator(true, pred, ["INVALID"]) + evaluator = Evaluator(true, pred, ["INVALID"], loader="list") results = evaluator.evaluate() - for strategy in results: - assert results[strategy]["overall"].correct == 0 - assert results[strategy]["overall"].incorrect == 0 - assert results[strategy]["overall"].partial == 0 - assert results[strategy]["overall"].missed == 0 - assert results[strategy]["overall"].spurious == 0 + for strategy in ["strict", "partial", "ent_type"]: + assert results["overall"][strategy].correct == 0 + assert results["overall"][strategy].incorrect == 0 + assert results["overall"][strategy].partial == 0 + assert results["overall"][strategy].missed == 0 + assert results["overall"][strategy].spurious == 0 From ab20cefbf6b256cec77f3b8f1a9907e7fccd61c6 Mon Sep 17 00:00:00 2001 From: "David S. 
Batista" Date: Tue, 20 May 2025 22:59:01 +0200 Subject: [PATCH 09/41] type checking --- src/nervaluate/entities.py | 14 +++++++------- src/nervaluate/evaluator.py | 4 ++-- src/nervaluate/loaders.py | 2 +- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/nervaluate/entities.py b/src/nervaluate/entities.py index fc7598a..4561cab 100644 --- a/src/nervaluate/entities.py +++ b/src/nervaluate/entities.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import List +from typing import List, Tuple @dataclass @@ -55,13 +55,13 @@ def compute_metrics(self, partial_or_type: bool = False) -> None: class EvaluationIndices: """Represents the indices of entities in different evaluation categories.""" - correct_indices: List[tuple[int, int]] = None - incorrect_indices: List[tuple[int, int]] = None - partial_indices: List[tuple[int, int]] = None - missed_indices: List[tuple[int, int]] = None - spurious_indices: List[tuple[int, int]] = None + correct_indices: List[Tuple[int, int]] = None # type: ignore + incorrect_indices: List[Tuple[int, int]] = None # type: ignore + partial_indices: List[Tuple[int, int]] = None # type: ignore + missed_indices: List[Tuple[int, int]] = None # type: ignore + spurious_indices: List[Tuple[int, int]] = None # type: ignore - def __post_init__(self): + def __post_init__(self) -> None: if self.correct_indices is None: self.correct_indices = [] if self.incorrect_indices is None: diff --git a/src/nervaluate/evaluator.py b/src/nervaluate/evaluator.py index c697d82..4b7eaff 100644 --- a/src/nervaluate/evaluator.py +++ b/src/nervaluate/evaluator.py @@ -67,7 +67,7 @@ def evaluate(self) -> Dict[str, Any]: Dictionary containing evaluation results for each strategy and entity type """ results = {} - entity_results = {tag: {} for tag in self.tags} + entity_results: Dict[str, Dict[str, EvaluationResult]] = {tag: {} for tag in self.tags} # Evaluate each document for doc_idx, (true_doc, pred_doc) in enumerate(zip(self.true, self.pred)): @@ -105,7 +105,7 @@ def _merge_results(self, target: EvaluationResult, source: EvaluationResult) -> target.spurious += source.spurious target.compute_metrics() - def results_to_dataframe(self) -> pd.DataFrame: + def results_to_dataframe(self) -> Any: """Convert results to a pandas DataFrame.""" results = self.evaluate() diff --git a/src/nervaluate/loaders.py b/src/nervaluate/loaders.py index 28d70be..8cd9ac6 100644 --- a/src/nervaluate/loaders.py +++ b/src/nervaluate/loaders.py @@ -23,7 +23,7 @@ def load(self, data: str) -> List[List[Entity]]: if not data: return [] - result = [] + result: List[List[Entity]] = [] # Strip trailing whitespace and newlines to avoid empty documents documents = data.rstrip().split("\n\n") From 303878629a56fdd3dd7d69d099a9d643ccaf9c1c Mon Sep 17 00:00:00 2001 From: "David S. Batista" Date: Tue, 20 May 2025 23:01:32 +0200 Subject: [PATCH 10/41] adding missed pyproject.toml --- pyproject.toml | 147 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 147 insertions(+) create mode 100644 pyproject.toml diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..2f5b702 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,147 @@ +[build-system] +requires = ["setuptools", "setuptools-scm"] +build-backend = "setuptools.build_meta" + +[project] +name = "nervaluate" +version = "0.2.0" +authors = [ + { name="David S. 
Batista"}, + { name="Matthew Upson"} +] +description = "NER evaluation considering partial match scoring" +readme = "README.md" +requires-python = ">=3.11" +keywords = ["named-entity-recognition", "ner", "evaluation-metrics", "partial-match-scoring", "nlp"] +license = {text = "MIT License"} +classifiers = [ + "Programming Language :: Python :: 3", + "Operating System :: OS Independent" +] + +dependencies = [ + "pandas==2.2.3" +] + +[project.optional-dependencies] +dev = [ + "black==24.3.0", + "coverage==7.2.5", + "gitchangelog", + "mypy==1.3.0", + "pre-commit==3.3.1", + "pylint==2.17.4", + "pytest==7.3.1", + "pytest-cov==4.1.0", +] + +[project.urls] +"Homepage" = "https://github.com/MantisAI/nervaluate" +"Bug Tracker" = "https://github.com/MantisAI/nervaluate/issues" + +[tool.pytest.ini_options] +testpaths = ["tests"] +python_files = ["test_*.py"] +addopts = "--cov=nervaluate --cov-report=term-missing" + +[tool.coverage.run] +source = ["nervaluate"] +omit = ["*__init__*"] + +[tool.coverage.report] +show_missing = true +precision = 2 +sort = "Miss" + +[tool.black] +line-length = 120 +target-version = ["py311"] + +[tool.pylint.messages_control] +disable = [ + "C0111", # missing-docstring + "C0103", # invalid-name + "W0511", # fixme + "W0603", # global-statement + "W1202", # logging-format-interpolation + "W1203", # logging-fstring-interpolation + "E1126", # invalid-sequence-index + "E1137", # invalid-slice-index + "I0011", # bad-option-value + "I0020", # bad-option-value + "R0801", # duplicate-code + "W9020", # bad-option-value + "W0621", # redefined-outer-name +] + +[tool.pylint.'DESIGN'] +max-args = 38 # Default is 5 +max-attributes = 28 # Default is 7 +max-branches = 14 # Default is 12 +max-locals = 45 # Default is 15 +max-module-lines = 2468 # Default is 1000 +max-nested-blocks = 9 # Default is 5 +max-statements = 206 # Default is 50 +min-public-methods = 1 # Allow classes with just one public method + +[tool.pylint.format] +max-line-length = 120 + +[tool.pylint.basic] +accept-no-param-doc = true +accept-no-raise-doc = true +accept-no-return-doc = true +accept-no-yields-doc = true +default-docstring-type = "numpy" + +[tool.pylint.master] +load-plugins = ["pylint.extensions.docparams"] +ignore-paths = ["./examples/.*"] + +[tool.flake8] +max-line-length = 120 +extend-ignore = ["E203"] +exclude = [".git", "__pycache__", "build", "dist", "./examples/*"] +max-complexity = 10 +per-file-ignores = ["*/__init__.py: F401"] + +[tool.mypy] +python_version = "3.11" +ignore_missing_imports = true +disallow_any_unimported = true +disallow_untyped_defs = true +warn_redundant_casts = true +warn_unused_ignores = true +warn_unused_configs = true + +[[tool.mypy.overrides]] +module = "examples.*" +follow_imports = "skip" + +[tool.hatch.envs.dev] +dependencies = [ + "black==24.3.0", + "coverage==7.2.5", + "gitchangelog", + "mypy==1.3.0", + "pre-commit==3.3.1", + "pylint==2.17.4", + "pytest==7.3.1", + "pytest-cov==4.1.0", +] + +[tool.hatch.envs.dev.scripts] +lint = [ + "black -t py311 -l 120 src tests", + "pylint src tests" +] +typing = "mypy src" +test = "pytest" +clean = "rm -rf dist src/nervaluate.egg-info .coverage .mypy_cache .pytest_cache" +changelog = "gitchangelog > CHANGELOG.rst" +all = [ + "clean", + "lint", + "typing", + "test" +] From 21123d9b4d810b71116573113e63662bf2256e7a Mon Sep 17 00:00:00 2001 From: "David S. 
Batista" Date: Tue, 20 May 2025 23:40:26 +0200 Subject: [PATCH 11/41] cleaning up README.MD --- README.md | 167 +++++++++++++++++++++++++++++------------------------- 1 file changed, 90 insertions(+), 77 deletions(-) diff --git a/README.md b/README.md index 31f4b47..b057601 100644 --- a/README.md +++ b/README.md @@ -12,14 +12,15 @@ # nervaluate -`nervaluate` is a python module for evaluating Named Entity Recognition (NER) models as defined in the SemEval 2013 - 9.1 task. +`nervaluate` is a module for evaluating Named Entity Recognition (NER) models as defined in the SemEval 2013 - 9.1 task. The evaluation metrics output by nervaluate go beyond a simple token/tag based schema, and consider different scenarios based on whether all the tokens that belong to a named entity were classified or not, and also whether the correct entity type was assigned. This full problem is described in detail in the [original blog](http://www.davidsbatista.net/blog/2018/05/09/Named_Entity_Evaluation/) -post by [David Batista](https://github.com/davidsbatista), and extends the code in the [original repository](https://github.com/davidsbatista/NER-Evaluation) which accompanied the blog post. +post by [David Batista](https://github.com/davidsbatista), and extends the code in the[original repository](https://github.com/davidsbatista/NER-Evaluation) +which accompanied the blog post. The code draws heavily on: @@ -35,30 +36,30 @@ When comparing the golden standard annotations with the output of a NER system d __I. Surface string and entity type match__ -|Token|Gold|Prediction| -|---|---|---| -|in|O|O| -|New|B-LOC|B-LOC| -|York|I-LOC|I-LOC| -|.|O|O| +| Token | Gold | Prediction | +|-------|-------|------------| +| in | O | O | +| New | B-LOC | B-LOC | +| York | I-LOC | I-LOC | +| . | O | O | __II. System hypothesized an incorrect entity__ -|Token|Gold|Prediction| -|---|---|---| -|an|O|O| -|Awful|O|B-ORG| -|Headache|O|I-ORG| -|in|O|O| +| Token | Gold | Prediction | +|----------|------|------------| +| an | O | O | +| Awful | O | B-ORG | +| Headache | O | I-ORG | +| in | O | O | __III. System misses an entity__ -|Token|Gold|Prediction| -|---|---|---| -|in|O|O| -|Palo|B-LOC|O| -|Alto|I-LOC|O| -|,|O|O| +| Token | Gold | Prediction | +|-------|-------|------------| +| in | O | O | +| Palo | B-LOC | O | +| Alto | I-LOC | O | +| , | O | O | Based on these three scenarios we have a simple classification evaluation that can be measured in terms of false positives, true positives, false negatives and false positives, and subsequently compute precision, recall and @@ -72,65 +73,65 @@ For example: __IV. System assigns the wrong entity type__ -|Token|Gold|Prediction| -|---|---|---| -|I|O|O| -|live|O|O| -|in|O|O| -|Palo|B-LOC|B-ORG| -|Alto|I-LOC|I-ORG| -|,|O|O| +| Token | Gold | Prediction | +|-------|-------|------------| +| I | O | O | +| live | O | O | +| in | O | O | +| Palo | B-LOC | B-ORG | +| Alto | I-LOC | I-ORG | +| , | O | O | __V. System gets the boundaries of the surface string wrong__ -|Token|Gold|Prediction| -|---|---|---| -|Unless|O|B-PER| -|Karl|B-PER|I-PER| -|Smith|I-PER|I-PER| -|resigns|O|O| +| Token | Gold | Prediction | +|---------|-------|------------| +| Unless | O | B-PER | +| Karl | B-PER | I-PER | +| Smith | I-PER | I-PER | +| resigns | O | O | __VI. 
System gets the boundaries and entity type wrong__ -|Token|Gold|Prediction| -|---|---|---| -|Unless|O|B-ORG| -|Karl|B-PER|I-ORG| -|Smith|I-PER|I-ORG| -|resigns|O|O| +| Token | Gold | Prediction | +|---------|-------|------------| +| Unless | O | B-ORG | +| Karl | B-PER | I-ORG | +| Smith | I-PER | I-ORG | +| resigns | O | O | How can we incorporate these described scenarios into evaluation metrics? See the [original blog](http://www.davidsbatista.net/blog/2018/05/09/Named_Entity_Evaluation/) for a great explanation, a summary is included here: We can use the following five metrics to consider difference categories of errors: -|Error type|Explanation| -|---|---| -|Correct (COR)|both are the same| -|Incorrect (INC)|the output of a system and the golden annotation don’t match| -|Partial (PAR)|system and the golden annotation are somewhat “similar” but not the same| -|Missing (MIS)|a golden annotation is not captured by a system| -|Spurious (SPU)|system produces a response which doesn’t exist in the golden annotation| +| Error type | Explanation | +|-----------------|--------------------------------------------------------------------------| +| Correct (COR) | both are the same | +| Incorrect (INC) | the output of a system and the golden annotation don’t match | +| Partial (PAR) | system and the golden annotation are somewhat “similar” but not the same | +| Missing (MIS) | a golden annotation is not captured by a system | +| Spurious (SPU) | system produces a response which doesn’t exist in the golden annotation | These five metrics can be measured in four different ways: -|Evaluation schema|Explanation| -|---|---| -|Strict|exact boundary surface string match and entity type| -|Exact|exact boundary match over the surface string, regardless of the type| -|Partial|partial boundary match over the surface string, regardless of the type| -|Type|some overlap between the system tagged entity and the gold annotation is required| +| Evaluation schema | Explanation | +|-------------------|-----------------------------------------------------------------------------------| +| Strict | exact boundary surface string match and entity type | +| Exact | exact boundary match over the surface string, regardless of the type | +| Partial | partial boundary match over the surface string, regardless of the type | +| Type | some overlap between the system tagged entity and the gold annotation is required | These five errors and four evaluation schema interact in the following ways: -|Scenario|Gold entity|Gold string|Pred entity|Pred string|Type|Partial|Exact|Strict| -|---|---|---|---|---|---|---|---|---| -|III|BRAND|tikosyn| | |MIS|MIS|MIS|MIS| -|II| | |BRAND|healthy|SPU|SPU|SPU|SPU| -|V|DRUG|warfarin|DRUG|of warfarin|COR|PAR|INC|INC| -|IV|DRUG|propranolol|BRAND|propranolol|INC|COR|COR|INC| -|I|DRUG|phenytoin|DRUG|phenytoin|COR|COR|COR|COR| -|VI|GROUP|contraceptives|DRUG|oral contraceptives|INC|PAR|INC|INC| +| Scenario | Gold entity | Gold string | Pred entity | Pred string | Type | Partial | Exact | Strict | +|----------|-------------|----------------|-------------|---------------------|------|---------|-------|--------| +| III | BRAND | tikosyn | | | MIS | MIS | MIS | MIS | +| II | | | BRAND | healthy | SPU | SPU | SPU | SPU | +| V | DRUG | warfarin | DRUG | of warfarin | COR | PAR | INC | INC | +| IV | DRUG | propranolol | BRAND | propranolol | INC | COR | COR | INC | +| I | DRUG | phenytoin | DRUG | phenytoin | COR | COR | COR | COR | +| VI | GROUP | contraceptives | DRUG | oral contraceptives | INC | PAR | INC 
| INC | Then precision/recall/f1-score are calculated for each different evaluation schema. In order to achieve data, two more quantities need to be calculated: @@ -188,12 +189,14 @@ pip install nervaluate ## Example: -The main `Evaluator` class will accept a number of formats: +The main `Evaluator` class will accept the following formats: * [prodi.gy](https://prodi.gy) style lists of spans. * Nested lists containing NER labels. * CoNLL style tab delimited strings. + ### Nested lists ``` + +from nervaluate import Evaluator + true = [ - ['O', 'O', 'B-PER', 'I-PER', 'O'], - ['O', 'B-LOC', 'I-LOC', 'B-LOC', 'I-LOC', 'O'], + ['O', 'B-PER', 'I-PER', 'O', 'O', 'O', 'B-ORG', 'I-ORG'], + ['O', 'B-LOC', 'B-PER', 'I-PER', 'O', 'O', 'B-DATE'], ] pred = [ - ['O', 'O', 'B-PER', 'I-PER', 'O'], - ['O', 'B-LOC', 'I-LOC', 'B-LOC', 'I-LOC', 'O'], + ['O', 'O', 'B-PER', 'I-PER', 'O', 'O', 'B-ORG', 'I-ORG', 'O'], + ['O', 'B-LOC', 'I-LOC', 'O', 'B-PER', 'I-PER', 'O', 'B-DATE', 'I-DATE', 'O'], ] -evaluator = Evaluator(true, pred, tags=['LOC', 'PER'], loader="list") +# Example text for reference: +# "The John Smith who works at Google Inc" +# "In Paris Marie Curie lived in 1895" + +evaluator = Evaluator(true, pred, tags=['PER', 'ORG', 'LOC', 'DATE'], loader="list") results, results_by_tag, result_indices, result_indices_by_tag = evaluator.evaluate() ``` + -## Extending the package to accept more formats +## Contributing to the `nervaluate` package -Additional formats can easily be added to the module by creating a conversion function in `nervaluate/utils.py`, -for example `conll_to_spans()`. This function must return the spans in the prodigy style dicts shown in the prodigy -example above. +### Extending the package to accept more formats -The new function can then be added to the list of loaders in `nervaluate/nervaluate.py`, and can then be selection -with the `loader` argument when instantiating the `Evaluator` class. +Additional formats can easily be added to the module by creating a new loader class in `nervaluate/loaders.py`. The +loader class should inherit from the `DataLoader` base class and implement the `load` method. The `load` method should + return a list of entity lists, where each entity is represented as a dictionary with `label`, `start`, and `end` keys. -A list of formats we intend to include is included in https://github.com/ivyleavedtoadflax/nervaluate/issues/3. +The new loader can then be added to the `_setup_loaders` method in the `Evaluator` class, and can be selected with the + `loader` argument when instantiating the `Evaluator` class. +Here is list of formats we intend to [include](https://github.com/MantisAI/nervaluate/issues/3). -## Contributing to the nervaluate package +### General Contributing -Improvements, adding new features and bug fixes are welcome. If you wish to participate in the development of nervaluate +Improvements, adding new features and bug fixes are welcome. If you wish to participate in the development of `nervaluate` please read the guidelines in the [CONTRIBUTING.md](CONTRIBUTING.md) file. --- From 93b3049fa48c7d8065eb71a6bfbab5490a211362 Mon Sep 17 00:00:00 2001 From: "David S. 
Batista" Date: Wed, 21 May 2025 00:42:41 +0200 Subject: [PATCH 12/41] working on new versions of summary reports --- src/nervaluate/evaluator.py | 30 +++- src/nervaluate/reporting.py | 308 ++++++++++++++++++++++++++++-------- tests/test_evaluator_new.py | 31 ++++ 3 files changed, 298 insertions(+), 71 deletions(-) diff --git a/src/nervaluate/evaluator.py b/src/nervaluate/evaluator.py index 4b7eaff..59e6122 100644 --- a/src/nervaluate/evaluator.py +++ b/src/nervaluate/evaluator.py @@ -1,7 +1,7 @@ from typing import List, Dict, Any import pandas as pd -from .entities import EvaluationResult +from .entities import EvaluationResult, EvaluationIndices from .evaluation_strategies import EvaluationStrategy, StrictEvaluation, PartialEvaluation, EntityTypeEvaluation from .loaders import DataLoader, ConllLoader, ListLoader, DictLoader @@ -68,6 +68,8 @@ def evaluate(self) -> Dict[str, Any]: """ results = {} entity_results: Dict[str, Dict[str, EvaluationResult]] = {tag: {} for tag in self.tags} + indices = {} + entity_indices: Dict[str, Dict[str, EvaluationIndices]] = {tag: {} for tag in self.tags} # Evaluate each document for doc_idx, (true_doc, pred_doc) in enumerate(zip(self.true, self.pred)): @@ -77,26 +79,37 @@ def evaluate(self) -> Dict[str, Any]: # Evaluate with each strategy for strategy_name, strategy in self.strategies.items(): - result, _ = strategy.evaluate(true_doc, pred_doc, self.tags, doc_idx) + result, doc_indices = strategy.evaluate(true_doc, pred_doc, self.tags, doc_idx) # Update overall results if strategy_name not in results: results[strategy_name] = result + indices[strategy_name] = doc_indices else: self._merge_results(results[strategy_name], result) + self._merge_indices(indices[strategy_name], doc_indices) # Update entity-specific results for tag in self.tags: if tag not in entity_results: entity_results[tag] = {} + entity_indices[tag] = {} if strategy_name not in entity_results[tag]: entity_results[tag][strategy_name] = result + entity_indices[tag][strategy_name] = doc_indices else: self._merge_results(entity_results[tag][strategy_name], result) + self._merge_indices(entity_indices[tag][strategy_name], doc_indices) - return {"overall": results, "entities": entity_results} + return { + "overall": results, + "entities": entity_results, + "overall_indices": indices, + "entity_indices": entity_indices, + } - def _merge_results(self, target: EvaluationResult, source: EvaluationResult) -> None: + @staticmethod + def _merge_results(target: EvaluationResult, source: EvaluationResult) -> None: """Merge two evaluation results.""" target.correct += source.correct target.incorrect += source.incorrect @@ -105,6 +118,15 @@ def _merge_results(self, target: EvaluationResult, source: EvaluationResult) -> target.spurious += source.spurious target.compute_metrics() + @staticmethod + def _merge_indices(target: EvaluationIndices, source: EvaluationIndices) -> None: + """Merge two evaluation indices.""" + target.correct_indices.extend(source.correct_indices) + target.incorrect_indices.extend(source.incorrect_indices) + target.partial_indices.extend(source.partial_indices) + target.missed_indices.extend(source.missed_indices) + target.spurious_indices.extend(source.spurious_indices) + def results_to_dataframe(self) -> Any: """Convert results to a pandas DataFrame.""" results = self.evaluate() diff --git a/src/nervaluate/reporting.py b/src/nervaluate/reporting.py index 94bf57f..bdb2119 100644 --- a/src/nervaluate/reporting.py +++ b/src/nervaluate/reporting.py @@ -1,47 +1,78 @@ -def 
summary_report_ent(results_agg_entities_type: dict, scenario: str = "strict", digits: int = 2) -> str: +from typing import Union +from xml.dom.minidom import Entity + + +def summary_report(results: dict, mode: str = "overall", scenario: str = "strict", digits: int = 2) -> str: """ - Generate a summary report of the evaluation results for a given scenario. + Generate a summary report of the evaluation results. - :param results_agg_entities_type: Dictionary containing the evaluation results. + :param results: Dictionary containing the evaluation results. + :param mode: Either 'overall' for overall metrics or 'entities' for per-entity metrics. :param scenario: The scenario to report on. Must be one of: 'strict', 'ent_type', 'partial', 'exact'. - Defaults to 'strict'. + Only used when mode is 'entities'. Defaults to 'strict'. :param digits: The number of digits to round the results to. :returns: A string containing the summary report. - :raises ValueError: - If the scenario is invalid. + :raises: + ValueError: If the scenario or mode is invalid. """ valid_scenarios = {"strict", "ent_type", "partial", "exact"} - if scenario not in valid_scenarios: + valid_modes = {"overall", "entities"} + + if mode not in valid_modes: + raise ValueError(f"Invalid mode: must be one of {valid_modes}") + + if mode == "entities" and scenario not in valid_scenarios: raise ValueError(f"Invalid scenario: must be one of {valid_scenarios}") - target_names = sorted(results_agg_entities_type.keys()) headers = ["correct", "incorrect", "partial", "missed", "spurious", "precision", "recall", "f1-score"] rows = [headers] - # Aggregate results by entity type for the specified scenario - for ent_type in target_names: - if scenario not in results_agg_entities_type[ent_type]: - raise ValueError(f"Scenario '{scenario}' not found in results for entity type '{ent_type}'") - - results = results_agg_entities_type[ent_type][scenario] - rows.append( - [ - ent_type, - results["correct"], - results["incorrect"], - results["partial"], - results["missed"], - results["spurious"], - results["precision"], - results["recall"], - results["f1"], - ] - ) - - name_width = max(len(cn) for cn in target_names) + if mode == "overall": + # Process overall results + for eval_schema in valid_scenarios: + if eval_schema not in results: + continue + results_schema = results[eval_schema] + rows.append( + [ + eval_schema, + results_schema["correct"], + results_schema["incorrect"], + results_schema["partial"], + results_schema["missed"], + results_schema["spurious"], + results_schema["precision"], + results_schema["recall"], + results_schema["f1"], + ] + ) + else: + # Process entity-specific results + target_names = sorted(results.keys()) + for ent_type in target_names: + if scenario not in results[ent_type]: + raise ValueError(f"Scenario '{scenario}' not found in results for entity type '{ent_type}'") + + results_ent = results[ent_type][scenario] + rows.append( + [ + ent_type, + results_ent["correct"], + results_ent["incorrect"], + results_ent["partial"], + results_ent["missed"], + results_ent["spurious"], + results_ent["precision"], + results_ent["recall"], + results_ent["f1"], + ] + ) + + # Format the report + name_width = max(len(str(row[0])) for row in rows) width = max(name_width, digits) head_fmt = "{:>{width}s} " + " {:>11}" * len(headers) report = head_fmt.format("", *headers, width=width) @@ -54,46 +85,15 @@ def summary_report_ent(results_agg_entities_type: dict, scenario: str = "strict" return report -def summary_report_overall(results: dict, 
digits: int = 2) -> str: - """ - Generate a summary report of the evaluation results for the overall scenario. - - :param results: Dictionary containing the evaluation results. - :param digits: The number of digits to round the results to. - - :returns: - A string containing the summary report. - """ - headers = ["correct", "incorrect", "partial", "missed", "spurious", "precision", "recall", "f1-score"] - rows = [headers] - - for k, v in results.items(): - rows.append( - [ - k, - v["correct"], - v["incorrect"], - v["partial"], - v["missed"], - v["spurious"], - v["precision"], - v["recall"], - v["f1"], - ] - ) - - target_names = sorted(results.keys()) - name_width = max(len(cn) for cn in target_names) - width = max(name_width, digits) - head_fmt = "{:>{width}s} " + " {:>11}" * len(headers) - report = head_fmt.format("", *headers, width=width) - report += "\n\n" - row_fmt = "{:>{width}s} " + " {:>11}" * 5 + " {:>11.{digits}f}" * 3 + "\n" +# For backward compatibility +def summary_report_ent(results_agg_entities_type: dict, scenario: str = "strict", digits: int = 2) -> str: + """Alias for summary_report with mode='entities'""" + return summary_report(results_agg_entities_type, mode="entities", scenario=scenario, digits=digits) - for row in rows[1:]: - report += row_fmt.format(*row, width=width, digits=digits) - return report +def summary_report_overall(results: dict, digits: int = 2) -> str: + """Alias for summary_report with mode='overall'""" + return summary_report(results, mode="overall", digits=digits) def summary_report_ents_indices(evaluation_agg_indices: dict, error_schema: str, preds: list | None = None) -> str: @@ -164,3 +164,177 @@ def summary_report_overall_indices(evaluation_indices: dict, error_schema: str, report += "\n" return report + + +def summary_report_v2(results: dict, mode: str = "overall", scenario: str = "strict", digits: int = 2) -> str: + """ + Generate a summary report of the evaluation results for the new Evaluator class. + + Args: + results: Dictionary containing the evaluation results from the new Evaluator class. + mode: Either 'overall' for overall metrics or 'entities' for per-entity metrics. + scenario: The scenario to report on. Must be one of: 'strict', 'ent_type', 'partial', 'exact'. + Only used when mode is 'entities'. Defaults to 'strict'. + digits: The number of digits to round the results to. + + Returns: + A string containing the summary report. + + Raises: + ValueError: If the scenario or mode is invalid. 
+ """ + valid_scenarios = {"strict", "ent_type", "partial", "exact"} + valid_modes = {"overall", "entities"} + + if mode not in valid_modes: + raise ValueError(f"Invalid mode: must be one of {valid_modes}") + + if mode == "entities" and scenario not in valid_scenarios: + raise ValueError(f"Invalid scenario: must be one of {valid_scenarios}") + + headers = ["correct", "incorrect", "partial", "missed", "spurious", "precision", "recall", "f1-score"] + rows = [headers] + + if mode == "overall": + # Process overall results + results_data = results["overall"] + for eval_schema in valid_scenarios: + if eval_schema not in results_data: + continue + results_schema = results_data[eval_schema] + rows.append( + [ + eval_schema, + results_schema.correct, + results_schema.incorrect, + results_schema.partial, + results_schema.missed, + results_schema.spurious, + results_schema.precision, + results_schema.recall, + results_schema.f1, + ] + ) + else: + # Process entity-specific results + results_data = results["entities"] + target_names = sorted(results_data.keys()) + for ent_type in target_names: + if scenario not in results_data[ent_type]: + raise ValueError(f"Scenario '{scenario}' not found in results for entity type '{ent_type}'") + + results_ent = results_data[ent_type][scenario] + rows.append( + [ + ent_type, + results_ent.correct, + results_ent.incorrect, + results_ent.partial, + results_ent.missed, + results_ent.spurious, + results_ent.precision, + results_ent.recall, + results_ent.f1, + ] + ) + + # Format the report + name_width = max(len(str(row[0])) for row in rows) + width = max(name_width, digits) + head_fmt = "{:>{width}s} " + " {:>11}" * len(headers) + report = head_fmt.format("", *headers, width=width) + report += "\n\n" + row_fmt = "{:>{width}s} " + " {:>11}" * 5 + " {:>11.{digits}f}" * 3 + "\n" + + for row in rows[1:]: + report += row_fmt.format(*row, width=width, digits=digits) + + return report + + +def summary_report_indices_v2( # pylint: disable=too-many-branches + results: dict, mode: str = "overall", scenario: str = "strict", preds: list | None = None +) -> str: + """ + Generate a summary report of the evaluation indices for the new Evaluator class. + + Args: + results: Dictionary containing the evaluation results from the new Evaluator class. + mode: Either 'overall' for overall metrics or 'entities' for per-entity metrics. + scenario: The scenario to report on. Must be one of: 'strict', 'ent_type', 'partial', 'exact'. + Only used when mode is 'entities'. Defaults to 'strict'. + preds: List of predicted named entities. Can be either: + - List of lists of entity objects with label, start, end attributes + - List of lists of strings (BIO tags) + + Returns: + A string containing the summary report of indices. + + Raises: + ValueError: If the scenario or mode is invalid. 
+ """ + valid_scenarios = {"strict", "ent_type", "partial", "exact"} + valid_modes = {"overall", "entities"} + + if mode not in valid_modes: + raise ValueError(f"Invalid mode: must be one of {valid_modes}") + + if mode == "entities" and scenario not in valid_scenarios: + raise ValueError(f"Invalid scenario: must be one of {valid_scenarios}") + + if preds is None: + preds = [[]] + + def get_prediction_info(pred: Union[Entity, str]) -> str: + """Helper function to get prediction info based on pred type.""" + if isinstance(pred, Entity): + return f"Label={pred.label}, Start={pred.start}, End={pred.end}" # type: ignore + # String (BIO tag) + return f"Tag={pred}" + + report = "" + if mode == "overall": + # Get the indices from the overall results + indices_data = results["overall_indices"][scenario] + report += f"Indices for error schema '{scenario}':\n\n" + + for category, indices in indices_data.__dict__.items(): + if not category.endswith("_indices"): + continue + category_name = category.replace("_indices", "").replace("_", " ").capitalize() + report += f"{category_name}:\n" + if indices: + for instance_index, entity_index in indices: + if preds != [[]]: + pred = preds[instance_index][entity_index] + prediction_info = get_prediction_info(pred) + report += f" - Instance {instance_index}, Entity {entity_index}: {prediction_info}\n" + else: + report += f" - Instance {instance_index}, Entity {entity_index}\n" + else: + report += " - None\n" + report += "\n" + else: + # Get the indices from the entity-specific results + for entity_type, entity_results in results["entity_indices"].items(): + report += f"\nEntity Type: {entity_type}\n" + error_data = entity_results[scenario] + report += f" Error Schema: '{scenario}'\n" + + for category, indices in error_data.__dict__.items(): + if not category.endswith("_indices"): + continue + category_name = category.replace("_indices", "").replace("_", " ").capitalize() + report += f" ({entity_type}) {category_name}:\n" + if indices: + for instance_index, entity_index in indices: + if preds != [[]]: + pred = preds[instance_index][entity_index] + prediction_info = get_prediction_info(pred) + report += f" - Instance {instance_index}, Entity {entity_index}: {prediction_info}\n" + else: + report += f" - Instance {instance_index}, Entity {entity_index}\n" + else: + report += " - None\n" + + return report diff --git a/tests/test_evaluator_new.py b/tests/test_evaluator_new.py index 32794cc..55f3468 100644 --- a/tests/test_evaluator_new.py +++ b/tests/test_evaluator_new.py @@ -1,5 +1,6 @@ import pytest from nervaluate.evaluator import Evaluator +from nervaluate.reporting import summary_report_v2, summary_report_indices_v2 @pytest.fixture @@ -164,3 +165,33 @@ def test_evaluator_with_invalid_tags(sample_data): assert results["overall"][strategy].partial == 0 assert results["overall"][strategy].missed == 0 assert results["overall"][strategy].spurious == 0 + + +def test_evaluator_full(): + true = [ + ["O", "B-PER", "I-PER", "O", "O", "O", "B-ORG", "I-ORG"], + ["O", "B-LOC", "B-PER", "I-PER", "O", "O", "B-DATE"], + ] + + pred = [ + ["O", "O", "B-PER", "I-PER", "O", "O", "B-ORG", "I-ORG", "O"], + ["O", "B-LOC", "I-LOC", "O", "B-PER", "I-PER", "O", "B-DATE", "I-DATE", "O"], + ] + + evaluator = Evaluator(true, pred, tags=["PER", "ORG", "LOC", "DATE"], loader="list") + results = evaluator.evaluate() + print("\n\n") + + # For metrics report + report_overall = summary_report_v2(results, mode="overall") + print(report_overall) + + report_entities = summary_report_v2(results, 
mode="entities", scenario="strict") + print(report_entities) + + # For indices report + report_indices_overall = summary_report_indices_v2(results, mode="overall", preds=pred) + print(report_indices_overall) + + report_indices_entities = summary_report_indices_v2(results, mode="entities", scenario="strict", preds=pred) + print(report_indices_entities) From 42fb3ad69459c38b2e29eca0a782245fc85a60a7 Mon Sep 17 00:00:00 2001 From: "David S. Batista" Date: Wed, 21 May 2025 00:50:36 +0200 Subject: [PATCH 13/41] moving reporting to the Evaluator class --- src/nervaluate/evaluator.py | 168 +++++++++++++++++++++++++++++++++- src/nervaluate/reporting.py | 174 ------------------------------------ tests/test_evaluator_new.py | 18 ++-- 3 files changed, 174 insertions(+), 186 deletions(-) diff --git a/src/nervaluate/evaluator.py b/src/nervaluate/evaluator.py index 59e6122..c83b4e0 100644 --- a/src/nervaluate/evaluator.py +++ b/src/nervaluate/evaluator.py @@ -1,10 +1,10 @@ -from typing import List, Dict, Any +from typing import List, Dict, Any, Union import pandas as pd from .entities import EvaluationResult, EvaluationIndices from .evaluation_strategies import EvaluationStrategy, StrictEvaluation, PartialEvaluation, EntityTypeEvaluation from .loaders import DataLoader, ConllLoader, ListLoader, DictLoader - +from .entities import Entity class Evaluator: """Main evaluator class for NER evaluation.""" @@ -140,3 +140,167 @@ def results_to_dataframe(self) -> Any: flat_results[key] = value return pd.DataFrame([flat_results]) + + def summary_report(self, mode: str = "overall", scenario: str = "strict", digits: int = 2) -> str: + """ + Generate a summary report of the evaluation results. + + Args: + mode: Either 'overall' for overall metrics or 'entities' for per-entity metrics. + scenario: The scenario to report on. Must be one of: 'strict', 'ent_type', 'partial', 'exact'. + Only used when mode is 'entities'. Defaults to 'strict'. + digits: The number of digits to round the results to. + + Returns: + A string containing the summary report. + + Raises: + ValueError: If the scenario or mode is invalid. 
+ """ + valid_scenarios = {"strict", "ent_type", "partial", "exact"} + valid_modes = {"overall", "entities"} + + if mode not in valid_modes: + raise ValueError(f"Invalid mode: must be one of {valid_modes}") + + if mode == "entities" and scenario not in valid_scenarios: + raise ValueError(f"Invalid scenario: must be one of {valid_scenarios}") + + headers = ["correct", "incorrect", "partial", "missed", "spurious", "precision", "recall", "f1-score"] + rows = [headers] + + results = self.evaluate() + if mode == "overall": + # Process overall results + results_data = results["overall"] + for eval_schema in sorted(valid_scenarios): # Sort to ensure consistent order + if eval_schema not in results_data: + continue + results_schema = results_data[eval_schema] + rows.append( + [ + eval_schema, + results_schema.correct, + results_schema.incorrect, + results_schema.partial, + results_schema.missed, + results_schema.spurious, + results_schema.precision, + results_schema.recall, + results_schema.f1, + ] + ) + else: + # Process entity-specific results + results_data = results["entities"] + target_names = sorted(results_data.keys()) + for ent_type in target_names: + if scenario not in results_data[ent_type]: + raise ValueError(f"Scenario '{scenario}' not found in results for entity type '{ent_type}'") + + results_ent = results_data[ent_type][scenario] + rows.append( + [ + ent_type, + results_ent.correct, + results_ent.incorrect, + results_ent.partial, + results_ent.missed, + results_ent.spurious, + results_ent.precision, + results_ent.recall, + results_ent.f1, + ] + ) + + # Format the report + name_width = max(len(str(row[0])) for row in rows) + width = max(name_width, digits) + head_fmt = "{:>{width}s} " + " {:>11}" * len(headers) + report = head_fmt.format("", *headers, width=width) + report += "\n\n" + row_fmt = "{:>{width}s} " + " {:>11}" * 5 + " {:>11.{digits}f}" * 3 + "\n" + + for row in rows[1:]: + report += row_fmt.format(*row, width=width, digits=digits) + + return report + + def summary_report_indices(self, mode: str = "overall", scenario: str = "strict") -> str: + """ + Generate a summary report of the evaluation indices. + + Args: + mode: Either 'overall' for overall metrics or 'entities' for per-entity metrics. + scenario: The scenario to report on. Must be one of: 'strict', 'ent_type', 'partial', 'exact'. + Only used when mode is 'entities'. Defaults to 'strict'. + + Returns: + A string containing the summary report of indices. + + Raises: + ValueError: If the scenario or mode is invalid. 
+ """ + valid_scenarios = {"strict", "ent_type", "partial", "exact"} + valid_modes = {"overall", "entities"} + + if mode not in valid_modes: + raise ValueError(f"Invalid mode: must be one of {valid_modes}") + + if mode == "entities" and scenario not in valid_scenarios: + raise ValueError(f"Invalid scenario: must be one of {valid_scenarios}") + + def get_prediction_info(pred: Union[Entity, str]) -> str: + """Helper function to get prediction info based on pred type.""" + if isinstance(pred, Entity): + return f"Label={pred.label}, Start={pred.start}, End={pred.end}" # type: ignore + # String (BIO tag) + return f"Tag={pred}" + + results = self.evaluate() + report = "" + if mode == "overall": + # Get the indices from the overall results + indices_data = results["overall_indices"][scenario] + report += f"Indices for error schema '{scenario}':\n\n" + + for category, indices in indices_data.__dict__.items(): + if not category.endswith("_indices"): + continue + category_name = category.replace("_indices", "").replace("_", " ").capitalize() + report += f"{category_name}:\n" + if indices: + for instance_index, entity_index in indices: + if self.pred != [[]]: + pred = self.pred[instance_index][entity_index] + prediction_info = get_prediction_info(pred) + report += f" - Instance {instance_index}, Entity {entity_index}: {prediction_info}\n" + else: + report += f" - Instance {instance_index}, Entity {entity_index}\n" + else: + report += " - None\n" + report += "\n" + else: + # Get the indices from the entity-specific results + for entity_type, entity_results in results["entity_indices"].items(): + report += f"\nEntity Type: {entity_type}\n" + error_data = entity_results[scenario] + report += f" Error Schema: '{scenario}'\n" + + for category, indices in error_data.__dict__.items(): + if not category.endswith("_indices"): + continue + category_name = category.replace("_indices", "").replace("_", " ").capitalize() + report += f" ({entity_type}) {category_name}:\n" + if indices: + for instance_index, entity_index in indices: + if self.pred != [[]]: + pred = self.pred[instance_index][entity_index] + prediction_info = get_prediction_info(pred) + report += f" - Instance {instance_index}, Entity {entity_index}: {prediction_info}\n" + else: + report += f" - Instance {instance_index}, Entity {entity_index}\n" + else: + report += " - None\n" + + return report diff --git a/src/nervaluate/reporting.py b/src/nervaluate/reporting.py index bdb2119..c32def3 100644 --- a/src/nervaluate/reporting.py +++ b/src/nervaluate/reporting.py @@ -164,177 +164,3 @@ def summary_report_overall_indices(evaluation_indices: dict, error_schema: str, report += "\n" return report - - -def summary_report_v2(results: dict, mode: str = "overall", scenario: str = "strict", digits: int = 2) -> str: - """ - Generate a summary report of the evaluation results for the new Evaluator class. - - Args: - results: Dictionary containing the evaluation results from the new Evaluator class. - mode: Either 'overall' for overall metrics or 'entities' for per-entity metrics. - scenario: The scenario to report on. Must be one of: 'strict', 'ent_type', 'partial', 'exact'. - Only used when mode is 'entities'. Defaults to 'strict'. - digits: The number of digits to round the results to. - - Returns: - A string containing the summary report. - - Raises: - ValueError: If the scenario or mode is invalid. 
- """ - valid_scenarios = {"strict", "ent_type", "partial", "exact"} - valid_modes = {"overall", "entities"} - - if mode not in valid_modes: - raise ValueError(f"Invalid mode: must be one of {valid_modes}") - - if mode == "entities" and scenario not in valid_scenarios: - raise ValueError(f"Invalid scenario: must be one of {valid_scenarios}") - - headers = ["correct", "incorrect", "partial", "missed", "spurious", "precision", "recall", "f1-score"] - rows = [headers] - - if mode == "overall": - # Process overall results - results_data = results["overall"] - for eval_schema in valid_scenarios: - if eval_schema not in results_data: - continue - results_schema = results_data[eval_schema] - rows.append( - [ - eval_schema, - results_schema.correct, - results_schema.incorrect, - results_schema.partial, - results_schema.missed, - results_schema.spurious, - results_schema.precision, - results_schema.recall, - results_schema.f1, - ] - ) - else: - # Process entity-specific results - results_data = results["entities"] - target_names = sorted(results_data.keys()) - for ent_type in target_names: - if scenario not in results_data[ent_type]: - raise ValueError(f"Scenario '{scenario}' not found in results for entity type '{ent_type}'") - - results_ent = results_data[ent_type][scenario] - rows.append( - [ - ent_type, - results_ent.correct, - results_ent.incorrect, - results_ent.partial, - results_ent.missed, - results_ent.spurious, - results_ent.precision, - results_ent.recall, - results_ent.f1, - ] - ) - - # Format the report - name_width = max(len(str(row[0])) for row in rows) - width = max(name_width, digits) - head_fmt = "{:>{width}s} " + " {:>11}" * len(headers) - report = head_fmt.format("", *headers, width=width) - report += "\n\n" - row_fmt = "{:>{width}s} " + " {:>11}" * 5 + " {:>11.{digits}f}" * 3 + "\n" - - for row in rows[1:]: - report += row_fmt.format(*row, width=width, digits=digits) - - return report - - -def summary_report_indices_v2( # pylint: disable=too-many-branches - results: dict, mode: str = "overall", scenario: str = "strict", preds: list | None = None -) -> str: - """ - Generate a summary report of the evaluation indices for the new Evaluator class. - - Args: - results: Dictionary containing the evaluation results from the new Evaluator class. - mode: Either 'overall' for overall metrics or 'entities' for per-entity metrics. - scenario: The scenario to report on. Must be one of: 'strict', 'ent_type', 'partial', 'exact'. - Only used when mode is 'entities'. Defaults to 'strict'. - preds: List of predicted named entities. Can be either: - - List of lists of entity objects with label, start, end attributes - - List of lists of strings (BIO tags) - - Returns: - A string containing the summary report of indices. - - Raises: - ValueError: If the scenario or mode is invalid. 
- """ - valid_scenarios = {"strict", "ent_type", "partial", "exact"} - valid_modes = {"overall", "entities"} - - if mode not in valid_modes: - raise ValueError(f"Invalid mode: must be one of {valid_modes}") - - if mode == "entities" and scenario not in valid_scenarios: - raise ValueError(f"Invalid scenario: must be one of {valid_scenarios}") - - if preds is None: - preds = [[]] - - def get_prediction_info(pred: Union[Entity, str]) -> str: - """Helper function to get prediction info based on pred type.""" - if isinstance(pred, Entity): - return f"Label={pred.label}, Start={pred.start}, End={pred.end}" # type: ignore - # String (BIO tag) - return f"Tag={pred}" - - report = "" - if mode == "overall": - # Get the indices from the overall results - indices_data = results["overall_indices"][scenario] - report += f"Indices for error schema '{scenario}':\n\n" - - for category, indices in indices_data.__dict__.items(): - if not category.endswith("_indices"): - continue - category_name = category.replace("_indices", "").replace("_", " ").capitalize() - report += f"{category_name}:\n" - if indices: - for instance_index, entity_index in indices: - if preds != [[]]: - pred = preds[instance_index][entity_index] - prediction_info = get_prediction_info(pred) - report += f" - Instance {instance_index}, Entity {entity_index}: {prediction_info}\n" - else: - report += f" - Instance {instance_index}, Entity {entity_index}\n" - else: - report += " - None\n" - report += "\n" - else: - # Get the indices from the entity-specific results - for entity_type, entity_results in results["entity_indices"].items(): - report += f"\nEntity Type: {entity_type}\n" - error_data = entity_results[scenario] - report += f" Error Schema: '{scenario}'\n" - - for category, indices in error_data.__dict__.items(): - if not category.endswith("_indices"): - continue - category_name = category.replace("_indices", "").replace("_", " ").capitalize() - report += f" ({entity_type}) {category_name}:\n" - if indices: - for instance_index, entity_index in indices: - if preds != [[]]: - pred = preds[instance_index][entity_index] - prediction_info = get_prediction_info(pred) - report += f" - Instance {instance_index}, Entity {entity_index}: {prediction_info}\n" - else: - report += f" - Instance {instance_index}, Entity {entity_index}\n" - else: - report += " - None\n" - - return report diff --git a/tests/test_evaluator_new.py b/tests/test_evaluator_new.py index 55f3468..0630b04 100644 --- a/tests/test_evaluator_new.py +++ b/tests/test_evaluator_new.py @@ -1,6 +1,5 @@ import pytest from nervaluate.evaluator import Evaluator -from nervaluate.reporting import summary_report_v2, summary_report_indices_v2 @pytest.fixture @@ -180,18 +179,17 @@ def test_evaluator_full(): evaluator = Evaluator(true, pred, tags=["PER", "ORG", "LOC", "DATE"], loader="list") results = evaluator.evaluate() - print("\n\n") # For metrics report - report_overall = summary_report_v2(results, mode="overall") - print(report_overall) + report_overall = evaluator.summary_report(mode="overall") + assert report_overall is not None # ToDo: Add actual assertions - report_entities = summary_report_v2(results, mode="entities", scenario="strict") - print(report_entities) + report_entities = evaluator.summary_report(mode="entities", scenario="strict") + assert report_entities is not None # ToDo: Add actual assertions # For indices report - report_indices_overall = summary_report_indices_v2(results, mode="overall", preds=pred) - print(report_indices_overall) + report_indices_overall = 
evaluator.summary_report_indices(mode="overall") + assert report_indices_overall is not None # ToDo: Add actual assertions - report_indices_entities = summary_report_indices_v2(results, mode="entities", scenario="strict", preds=pred) - print(report_indices_entities) + report_indices_entities = evaluator.summary_report_indices(mode="entities", scenario="strict") + assert report_indices_entities is not None # ToDo: Add actual assertions From 0798ce32404c3dafafa82d4299c516ccd6de1b85 Mon Sep 17 00:00:00 2001 From: "David S. Batista" Date: Wed, 21 May 2025 12:56:42 +0200 Subject: [PATCH 14/41] fixing imports --- src/nervaluate/evaluator.py | 11 ++++++++--- src/nervaluate/loaders.py | 4 ++-- src/nervaluate/reporting.py | 4 ---- tests/test_evaluator_new.py | 9 ++++----- 4 files changed, 14 insertions(+), 14 deletions(-) diff --git a/src/nervaluate/evaluator.py b/src/nervaluate/evaluator.py index c83b4e0..3fe5997 100644 --- a/src/nervaluate/evaluator.py +++ b/src/nervaluate/evaluator.py @@ -6,6 +6,7 @@ from .loaders import DataLoader, ConllLoader, ListLoader, DictLoader from .entities import Entity + class Evaluator: """Main evaluator class for NER evaluation.""" @@ -226,7 +227,9 @@ def summary_report(self, mode: str = "overall", scenario: str = "strict", digits return report - def summary_report_indices(self, mode: str = "overall", scenario: str = "strict") -> str: + def summary_report_indices( # pylint: disable=too-many-branches + self, mode: str = "overall", scenario: str = "strict" + ) -> str: """ Generate a summary report of the evaluation indices. @@ -253,7 +256,7 @@ def summary_report_indices(self, mode: str = "overall", scenario: str = "strict" def get_prediction_info(pred: Union[Entity, str]) -> str: """Helper function to get prediction info based on pred type.""" if isinstance(pred, Entity): - return f"Label={pred.label}, Start={pred.start}, End={pred.end}" # type: ignore + return f"Label={pred.label}, Start={pred.start}, End={pred.end}" # String (BIO tag) return f"Tag={pred}" @@ -297,7 +300,9 @@ def get_prediction_info(pred: Union[Entity, str]) -> str: if self.pred != [[]]: pred = self.pred[instance_index][entity_index] prediction_info = get_prediction_info(pred) - report += f" - Instance {instance_index}, Entity {entity_index}: {prediction_info}\n" + report += ( + f" - Instance {instance_index}, Entity {entity_index}: {prediction_info}\n" + ) else: report += f" - Instance {instance_index}, Entity {entity_index}\n" else: diff --git a/src/nervaluate/loaders.py b/src/nervaluate/loaders.py index 8cd9ac6..ad8b82c 100644 --- a/src/nervaluate/loaders.py +++ b/src/nervaluate/loaders.py @@ -65,7 +65,7 @@ def load(self, data: str) -> List[List[Entity]]: elif ent_type != token_tag[2:] or (ent_type == token_tag[2:] and token_tag[:1] == "B"): end_offset = offset - 1 - current_doc.append(Entity(label=ent_type, start=start_offset, end=end_offset)) + current_doc.append(Entity(label=ent_type, start=start_offset, end=end_offset)) # type: ignore # start of a new entity if not (token_tag.startswith("B-") or token_tag.startswith("I-")): @@ -127,7 +127,7 @@ def load(self, data: List[List[str]]) -> List[List[Entity]]: elif ent_type != token_tag[2:] or (ent_type == token_tag[2:] and token_tag[:1] == "B"): end_offset = offset - 1 - current_doc.append(Entity(label=ent_type, start=start_offset, end=end_offset)) + current_doc.append(Entity(label=ent_type, start=start_offset, end=end_offset)) # type: ignore # start of a new entity if not (token_tag.startswith("B-") or token_tag.startswith("I-")): diff --git 
a/src/nervaluate/reporting.py b/src/nervaluate/reporting.py index c32def3..0c59bbf 100644 --- a/src/nervaluate/reporting.py +++ b/src/nervaluate/reporting.py @@ -1,7 +1,3 @@ -from typing import Union -from xml.dom.minidom import Entity - - def summary_report(results: dict, mode: str = "overall", scenario: str = "strict", digits: int = 2) -> str: """ Generate a summary report of the evaluation results. diff --git a/tests/test_evaluator_new.py b/tests/test_evaluator_new.py index 0630b04..8308fe6 100644 --- a/tests/test_evaluator_new.py +++ b/tests/test_evaluator_new.py @@ -178,18 +178,17 @@ def test_evaluator_full(): ] evaluator = Evaluator(true, pred, tags=["PER", "ORG", "LOC", "DATE"], loader="list") - results = evaluator.evaluate() # For metrics report report_overall = evaluator.summary_report(mode="overall") - assert report_overall is not None # ToDo: Add actual assertions + assert report_overall is not None # ToDo: Add actual assertions report_entities = evaluator.summary_report(mode="entities", scenario="strict") - assert report_entities is not None # ToDo: Add actual assertions + assert report_entities is not None # ToDo: Add actual assertions # For indices report report_indices_overall = evaluator.summary_report_indices(mode="overall") - assert report_indices_overall is not None # ToDo: Add actual assertions + assert report_indices_overall is not None # ToDo: Add actual assertions report_indices_entities = evaluator.summary_report_indices(mode="entities", scenario="strict") - assert report_indices_entities is not None # ToDo: Add actual assertions + assert report_indices_entities is not None # ToDo: Add actual assertions From bd082fa8a4584c1c4384541c0cc68fa7c31abca0 Mon Sep 17 00:00:00 2001 From: "David S. Batista" Date: Wed, 21 May 2025 13:02:45 +0200 Subject: [PATCH 15/41] fxing empty entities --- src/nervaluate/loaders.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/src/nervaluate/loaders.py b/src/nervaluate/loaders.py index ad8b82c..98ba819 100644 --- a/src/nervaluate/loaders.py +++ b/src/nervaluate/loaders.py @@ -15,7 +15,7 @@ def load(self, data: Any) -> List[List[Entity]]: class ConllLoader(DataLoader): """Loader for CoNLL format data.""" - def load(self, data: str) -> List[List[Entity]]: + def load(self, data: str) -> List[List[Entity]]: # pylint: disable=too-many-branches """Load CoNLL format data into a list of Entity lists.""" if not isinstance(data, str): raise ValueError("ConllLoader expects string input") @@ -51,7 +51,8 @@ def load(self, data: str) -> List[List[Entity]]: if token_tag == "O": if ent_type is not None and start_offset is not None: end_offset = offset - 1 - current_doc.append(Entity(label=ent_type, start=start_offset, end=end_offset)) + if isinstance(start_offset, int) and isinstance(end_offset, int): + current_doc.append(Entity(label=ent_type, start=start_offset, end=end_offset)) start_offset = None end_offset = None ent_type = None @@ -65,7 +66,8 @@ def load(self, data: str) -> List[List[Entity]]: elif ent_type != token_tag[2:] or (ent_type == token_tag[2:] and token_tag[:1] == "B"): end_offset = offset - 1 - current_doc.append(Entity(label=ent_type, start=start_offset, end=end_offset)) # type: ignore + if isinstance(start_offset, int) and isinstance(end_offset, int): + current_doc.append(Entity(label=ent_type, start=start_offset, end=end_offset)) # start of a new entity if not (token_tag.startswith("B-") or token_tag.startswith("I-")): @@ -77,7 +79,8 @@ def load(self, data: str) -> List[List[Entity]]: # 
Catches an entity that goes up until the last token if ent_type is not None and start_offset is not None and end_offset is None: - current_doc.append(Entity(label=ent_type, start=start_offset, end=len(doc.split("\n")) - 1)) + if isinstance(start_offset, int): + current_doc.append(Entity(label=ent_type, start=start_offset, end=len(doc.split("\n")) - 1)) has_entities = True result.append(current_doc if has_entities else []) @@ -88,7 +91,7 @@ def load(self, data: str) -> List[List[Entity]]: class ListLoader(DataLoader): """Loader for list format data.""" - def load(self, data: List[List[str]]) -> List[List[Entity]]: + def load(self, data: List[List[str]]) -> List[List[Entity]]: # pylint: disable=too-many-branches """Load list format data into a list of entity lists.""" if not isinstance(data, list): raise ValueError("ListLoader expects list input") @@ -114,7 +117,8 @@ def load(self, data: List[List[str]]) -> List[List[Entity]]: if token_tag == "O": if ent_type is not None and start_offset is not None: end_offset = offset - 1 - current_doc.append(Entity(label=ent_type, start=start_offset, end=end_offset)) + if isinstance(start_offset, int) and isinstance(end_offset, int): + current_doc.append(Entity(label=ent_type, start=start_offset, end=end_offset)) start_offset = None end_offset = None ent_type = None @@ -127,7 +131,8 @@ def load(self, data: List[List[str]]) -> List[List[Entity]]: elif ent_type != token_tag[2:] or (ent_type == token_tag[2:] and token_tag[:1] == "B"): end_offset = offset - 1 - current_doc.append(Entity(label=ent_type, start=start_offset, end=end_offset)) # type: ignore + if isinstance(start_offset, int) and isinstance(end_offset, int): + current_doc.append(Entity(label=ent_type, start=start_offset, end=end_offset)) # start of a new entity if not (token_tag.startswith("B-") or token_tag.startswith("I-")): @@ -138,7 +143,8 @@ def load(self, data: List[List[str]]) -> List[List[Entity]]: # Catches an entity that goes up until the last token if ent_type is not None and start_offset is not None and end_offset is None: - current_doc.append(Entity(label=ent_type, start=start_offset, end=len(doc) - 1)) + if isinstance(start_offset, int): + current_doc.append(Entity(label=ent_type, start=start_offset, end=len(doc) - 1)) result.append(current_doc) From 4ed62daa522c9263741994f199f8c5421192ec42 Mon Sep 17 00:00:00 2001 From: "David S. Batista" Date: Thu, 22 May 2025 00:14:16 +0200 Subject: [PATCH 16/41] fixes --- README.md | 24 ++++++++++++------------ compare_versions.py | 22 ++++++++++++++++++++++ src/nervaluate/evaluator.py | 10 +++++++--- 3 files changed, 41 insertions(+), 15 deletions(-) create mode 100644 compare_versions.py diff --git a/README.md b/README.md index b057601..df987f9 100644 --- a/README.md +++ b/README.md @@ -19,10 +19,10 @@ based on whether all the tokens that belong to a named entity were classified or entity type was assigned. This full problem is described in detail in the [original blog](http://www.davidsbatista.net/blog/2018/05/09/Named_Entity_Evaluation/) -post by [David Batista](https://github.com/davidsbatista), and extends the code in the[original repository](https://github.com/davidsbatista/NER-Evaluation) +post by [David Batista](https://github.com/davidsbatista), and extends the code in the [original repository](https://github.com/davidsbatista/NER-Evaluation) which accompanied the blog post. 
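As an aside to this README hunk: the loaders.py changes earlier in this patch series tighten the BIO-to-span conversion, in particular the case of an entity that runs up to the last token of a document. A minimal, self-contained sketch of that conversion is shown below; it is an illustration only, not code from the patch, and deliberately skips the extra `isinstance` guards the `ListLoader` adds in this commit.

```
# Sketch only (not from the patch): simplified BIO-tag to span conversion,
# including the case the commit addresses -- an entity reaching the last token.
from typing import List, Tuple


def bio_to_spans(tags: List[str]) -> List[Tuple[str, int, int]]:
    spans = []
    label, start = None, None
    for i, tag in enumerate(tags):
        if tag.startswith("B-") or (tag.startswith("I-") and label != tag[2:]):
            if label is not None:
                spans.append((label, start, i - 1))  # close the previous entity
            label, start = tag[2:], i
        elif tag == "O":
            if label is not None:
                spans.append((label, start, i - 1))
            label, start = None, None
    if label is not None:  # entity that goes up until the last token
        spans.append((label, start, len(tags) - 1))
    return spans


assert bio_to_spans(["O", "B-LOC", "B-PER", "I-PER", "O", "O", "B-DATE"]) == [
    ("LOC", 1, 1),
    ("PER", 2, 3),
    ("DATE", 6, 6),
]
```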
-The code draws heavily on: +The code draws heavily on the paper: * [SemEval-2013 Task 9 : Extraction of Drug-Drug Interactions from Biomedical Texts (DDIExtraction 2013)](https://www.aclweb.org/anthology/S13-2056) @@ -159,16 +159,16 @@ Recall = (COR + 0.5 × PAR)/POS = COR / ACT = TP / (TP + FN) __Putting all together:__ -|Measure|Type|Partial|Exact|Strict| -|---|---|---|---|---| -|Correct|3|3|3|2| -|Incorrect|2|0|2|3| -|Partial|0|2|0|0| -|Missed|1|1|1|1| -|Spurious|1|1|1|1| -|Precision|0.5|0.66|0.5|0.33| -|Recall|0.5|0.66|0.5|0.33| -|F1|0.5|0.66|0.5|0.33| +| Measure | Type | Partial | Exact | Strict | +|-----------|------|---------|-------|--------| +| Correct | 3 | 3 | 3 | 2 | +| Incorrect | 2 | 0 | 2 | 3 | +| Partial | 0 | 2 | 0 | 0 | +| Missed | 1 | 1 | 1 | 1 | +| Spurious | 1 | 1 | 1 | 1 | +| Precision | 0.5 | 0.66 | 0.5 | 0.33 | +| Recall | 0.5 | 0.66 | 0.5 | 0.33 | +| F1 | 0.5 | 0.66 | 0.5 | 0.33 | ## Notes: diff --git a/compare_versions.py b/compare_versions.py new file mode 100644 index 0000000..ee068b7 --- /dev/null +++ b/compare_versions.py @@ -0,0 +1,22 @@ +from nervaluate.evaluator import Evaluator + +true = [ + ['O', 'B-PER', 'I-PER', 'O', 'O', 'O', 'B-ORG', 'I-ORG'], + ['O', 'B-LOC', 'B-PER', 'I-PER', 'O', 'O', 'B-DATE'], +] + +pred = [ + ['O', 'B-PER', 'I-PER', 'O', 'O', 'O', 'B-ORG', 'I-ORG'], + ['O', 'B-LOC', 'I-LOC', 'O', 'B-PER', 'I-PER', 'O', 'B-DATE', 'I-DATE', 'O'], +] + +# Example text for reference: +# "The John Smith who works at Google Inc" +# "In Paris Marie Curie lived in 1895" + +new_evaluator = Evaluator(true, pred, tags=['PER', 'ORG', 'LOC', 'DATE'], loader="list") + +print(new_evaluator.summary_report()) + +from nervaluate import Evaluator +old_evaluator = Evaluator(true, pred, tags=['PER', 'ORG', 'LOC', 'DATE'], loader="list") \ No newline at end of file diff --git a/src/nervaluate/evaluator.py b/src/nervaluate/evaluator.py index 3fe5997..bf8677b 100644 --- a/src/nervaluate/evaluator.py +++ b/src/nervaluate/evaluator.py @@ -148,8 +148,12 @@ def summary_report(self, mode: str = "overall", scenario: str = "strict", digits Args: mode: Either 'overall' for overall metrics or 'entities' for per-entity metrics. - scenario: The scenario to report on. Must be one of: 'strict', 'ent_type', 'partial', 'exact'. - Only used when mode is 'entities'. Defaults to 'strict'. + scenario: The scenario to report on. Defaults to 'strict'. + Must be one of: + - 'strict' exact boundary surface string match and entity type; + - 'exact': exact boundary match over the surface string and entity type; + - 'partial': partial boundary match over the surface string, regardless of the type; + - 'ent_type': exact boundary match over the surface string, regardless of the type; digits: The number of digits to round the results to. Returns: @@ -218,7 +222,7 @@ def summary_report(self, mode: str = "overall", scenario: str = "strict", digits name_width = max(len(str(row[0])) for row in rows) width = max(name_width, digits) head_fmt = "{:>{width}s} " + " {:>11}" * len(headers) - report = head_fmt.format("", *headers, width=width) + report = f"Scenario: {scenario}\n\n" + head_fmt.format("", *headers, width=width) report += "\n\n" row_fmt = "{:>{width}s} " + " {:>11}" * 5 + " {:>11.{digits}f}" * 3 + "\n" From 9a0cd448d93d07e246f1195221fd542baab42901 Mon Sep 17 00:00:00 2001 From: "David S. 
Batista" Date: Sat, 24 May 2025 12:25:07 +0200 Subject: [PATCH 17/41] working on comparative example --- compare_versions.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/compare_versions.py b/compare_versions.py index ee068b7..c41f2cc 100644 --- a/compare_versions.py +++ b/compare_versions.py @@ -18,5 +18,11 @@ print(new_evaluator.summary_report()) +# The old evaluator for comparison + from nervaluate import Evaluator -old_evaluator = Evaluator(true, pred, tags=['PER', 'ORG', 'LOC', 'DATE'], loader="list") \ No newline at end of file +old_evaluator = Evaluator(true, pred, tags=['PER', 'ORG', 'LOC', 'DATE'], loader="list") + +from nervaluate.reporting import summary_report_overall_indices, summary_report_ents_indices, summary_report_overall +results = old_evaluator.evaluate()[0] # Get the first element which contains the overall results +print(summary_report_overall(results)) \ No newline at end of file From 40a2748a166080dd39cb2ee2de82b89a0caa9502 Mon Sep 17 00:00:00 2001 From: "David S. Batista" Date: Sat, 24 May 2025 12:49:23 +0200 Subject: [PATCH 18/41] fixing docs lenghts tests --- compare_versions.py | 2 +- src/nervaluate/evaluator.py | 10 ++++++++++ tests/test_evaluator.py | 1 + tests/test_evaluator_new.py | 32 +++++++++++--------------------- 4 files changed, 23 insertions(+), 22 deletions(-) diff --git a/compare_versions.py b/compare_versions.py index c41f2cc..a3c5e9d 100644 --- a/compare_versions.py +++ b/compare_versions.py @@ -7,7 +7,7 @@ pred = [ ['O', 'B-PER', 'I-PER', 'O', 'O', 'O', 'B-ORG', 'I-ORG'], - ['O', 'B-LOC', 'I-LOC', 'O', 'B-PER', 'I-PER', 'O', 'B-DATE', 'I-DATE', 'O'], + ['O', 'B-LOC', 'I-LOC', 'O', 'B-PER', 'O', 'B-DATE'], ] # Example text for reference: diff --git a/src/nervaluate/evaluator.py b/src/nervaluate/evaluator.py index bf8677b..85b8dd3 100644 --- a/src/nervaluate/evaluator.py +++ b/src/nervaluate/evaluator.py @@ -54,6 +54,16 @@ def _load_data(self, true: Any, pred: Any, loader: str) -> None: if loader not in self.loaders: raise ValueError(f"Unknown loader: {loader}") + # For list loader, check document lengths before loading + if loader == "list": + if len(true) != len(pred): + raise ValueError("Number of predicted documents does not equal true") + + # Check that each document has the same length + for i, (true_doc, pred_doc) in enumerate(zip(true, pred)): + if len(true_doc) != len(pred_doc): + raise ValueError(f"Document {i} has different lengths: true={len(true_doc)}, pred={len(pred_doc)}") + self.true = self.loaders[loader].load(true) self.pred = self.loaders[loader].load(pred) diff --git a/tests/test_evaluator.py b/tests/test_evaluator.py index 7bea536..4ae3997 100644 --- a/tests/test_evaluator.py +++ b/tests/test_evaluator.py @@ -1,4 +1,5 @@ # pylint: disable=too-many-lines +import pytest import pandas as pd from nervaluate import Evaluator diff --git a/tests/test_evaluator_new.py b/tests/test_evaluator_new.py index 8308fe6..826c72b 100644 --- a/tests/test_evaluator_new.py +++ b/tests/test_evaluator_new.py @@ -166,29 +166,19 @@ def test_evaluator_with_invalid_tags(sample_data): assert results["overall"][strategy].spurious == 0 -def test_evaluator_full(): +def test_evaluator_different_document_lengths(): + """Test that Evaluator raises ValueError when documents have different lengths.""" true = [ - ["O", "B-PER", "I-PER", "O", "O", "O", "B-ORG", "I-ORG"], - ["O", "B-LOC", "B-PER", "I-PER", "O", "O", "B-DATE"], + ["O", "B-PER", "I-PER", "O", "O", "O", "B-ORG", "I-ORG"], # 8 tokens + ["O", "B-LOC", "B-PER", "I-PER", "O", 
"O", "B-DATE"], # 7 tokens ] - pred = [ - ["O", "O", "B-PER", "I-PER", "O", "O", "B-ORG", "I-ORG", "O"], - ["O", "B-LOC", "I-LOC", "O", "B-PER", "I-PER", "O", "B-DATE", "I-DATE", "O"], + ["O", "B-PER", "I-PER", "O", "O", "O", "B-ORG", "I-ORG"], # 8 tokens + ["O", "B-LOC", "I-LOC", "O", "B-PER", "I-PER", "O", "B-DATE", "I-DATE", "O"], # 10 tokens ] + tags = ["PER", "ORG", "LOC", "DATE"] - evaluator = Evaluator(true, pred, tags=["PER", "ORG", "LOC", "DATE"], loader="list") - - # For metrics report - report_overall = evaluator.summary_report(mode="overall") - assert report_overall is not None # ToDo: Add actual assertions - - report_entities = evaluator.summary_report(mode="entities", scenario="strict") - assert report_entities is not None # ToDo: Add actual assertions - - # For indices report - report_indices_overall = evaluator.summary_report_indices(mode="overall") - assert report_indices_overall is not None # ToDo: Add actual assertions - - report_indices_entities = evaluator.summary_report_indices(mode="entities", scenario="strict") - assert report_indices_entities is not None # ToDo: Add actual assertions + # Test that ValueError is raised + with pytest.raises(ValueError, match="Document 1 has different lengths: true=7, pred=10"): + evaluator = Evaluator(true=true, pred=pred, tags=tags, loader="list") + evaluator.evaluate() From 40843310d1d83a42497ef8ab7c2fe66cbf4247fa Mon Sep 17 00:00:00 2001 From: "David S. Batista" Date: Sat, 24 May 2025 13:56:10 +0200 Subject: [PATCH 19/41] fixing docs lenghts tests --- README.md | 36 ++++++++++++++++++++++++++++++++++++ compare_versions.py | 36 ++++++++++++++++++++++++++++++++++-- tests/test_evaluator.py | 2 +- 3 files changed, 71 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index df987f9..9becb99 100644 --- a/README.md +++ b/README.md @@ -426,6 +426,42 @@ evaluator = Evaluator(true, pred, tags=['PER', 'ORG', 'LOC', 'DATE'], loader="li results, results_by_tag, result_indices, result_indices_by_tag = evaluator.evaluate() ``` +Sentence: + - "In Paris Marie Curie lived in 1895" + +Tokens: + - ['In', 'Paris', 'Marie', 'Curie', 'lived', 'in', '1895'] + +Gold: + - ['O', 'B-LOC', 'B-PER', 'I-PER', 'O', 'O', 'B-DATE'] + +Predicted: + - ['O', 'B-LOC', 'I-LOC', 'O', 'B-PER', 'O', 'B-DATE'] + +Gold entities: + +| Type | String | +|------|-------------------| +| LOC | Paris | +| PER | Marie Curie | +| DATE | 1895 | + +Predicted entities: + +| Type | String | +|------|--------------| +| LOC | Paris Marie | +| PER | lived | +| DATE | 1895 | + +``` + + + + + + + +* [prodi.gy](https://prodi.gy) style lists of spans. 
### Nested lists ``` +from nervaluate.evaluator import Evaluator + + true = [ + ['O', 'B-PER', 'I-PER', 'O', 'O', 'O', 'B-ORG', 'I-ORG'], + ['O', 'B-LOC', 'B-PER', 'I-PER', 'O', 'O', 'B-DATE'], + ] + + pred = [ + ['O', 'O', 'B-PER', 'I-PER', 'O', 'O', 'B-ORG', 'I-ORG'], + ['O', 'B-LOC', 'I-LOC', 'B-PER', 'O', 'O', 'B-DATE'], + ] + + # Example text for reference: + # "The John Smith who works at Google Inc" + # "In Paris Marie Curie lived in 1895" + + evaluator = Evaluator(true, pred, tags=['PER', 'ORG', 'LOC', 'DATE'], loader="list") + + +Scenario: all + + correct incorrect partial missed spurious precision recall f1-score + +ent_type 5 0 0 0 0 1.00 1.00 1.00 + exact 2 3 0 0 0 0.40 0.40 0.40 + partial 2 0 3 0 0 0.40 0.40 0.40 + strict 2 3 0 0 0 0.40 0.40 0.40 +``` + +and, aggregated by entity type: -from nervaluate import Evaluator - -true = [ - ['O', 'B-PER', 'I-PER', 'O', 'O', 'O', 'B-ORG', 'I-ORG'], - ['O', 'B-LOC', 'B-PER', 'I-PER', 'O', 'O', 'B-DATE'], -] - -pred = [ - ['O', 'O', 'B-PER', 'I-PER', 'O', 'O', 'B-ORG', 'I-ORG', 'O'], - ['O', 'B-LOC', 'I-LOC', 'O', 'B-PER', 'I-PER', 'O', 'B-DATE', 'I-DATE', 'O'], -] - -# Example text for reference: -# "The John Smith who works at Google Inc" -# "In Paris Marie Curie lived in 1895" - -evaluator = Evaluator(true, pred, tags=['PER', 'ORG', 'LOC', 'DATE'], loader="list") - -results, results_by_tag, result_indices, result_indices_by_tag = evaluator.evaluate() ``` +print(evaluator.summary_report(mode='entities')) + +Scenario: strict -Sentence: - - "In Paris Marie Curie lived in 1895" - -Tokens: - - ['In', 'Paris', 'Marie', 'Curie', 'lived', 'in', '1895'] - -Gold: - - ['O', 'B-LOC', 'B-PER', 'I-PER', 'O', 'O', 'B-DATE'] - -Predicted: - - ['O', 'B-LOC', 'I-LOC', 'O', 'B-PER', 'O', 'B-DATE'] - -Gold entities: - -| Type | String | -|------|-------------------| -| LOC | Paris | -| PER | Marie Curie | -| DATE | 1895 | - -Predicted entities: - -| Type | String | -|------|--------------| -| LOC | Paris Marie | -| PER | lived | -| DATE | 1895 | - -``` - - - - - - - - ## Contributing to the `nervaluate` package From 57cf441d22523d34b14f7efe0c8e93af8b5bdd1c Mon Sep 17 00:00:00 2001 From: "David S. 
Batista" Date: Mon, 2 Jun 2025 10:37:54 +0200 Subject: [PATCH 37/41] removing old files --- src/nervaluate/__init__.py | 15 +- tests/test_evaluator.py | 1202 ++--------------------------------- tests/test_evaluator_new.py | 80 --- tests/test_loaders.py | 57 -- tests/test_nervaluate.py | 909 -------------------------- tests/test_reporting.py | 111 ---- tests/test_utils.py | 61 -- 7 files changed, 56 insertions(+), 2379 deletions(-) delete mode 100644 tests/test_evaluator_new.py delete mode 100644 tests/test_nervaluate.py delete mode 100644 tests/test_reporting.py diff --git a/src/nervaluate/__init__.py b/src/nervaluate/__init__.py index f34c277..830a9de 100644 --- a/src/nervaluate/__init__.py +++ b/src/nervaluate/__init__.py @@ -1,15 +1,2 @@ -from .evaluate import ( - Evaluator, - compute_actual_possible, - compute_metrics, - compute_precision_recall, - compute_precision_recall_wrapper, - find_overlap, -) -from .reporting import ( - summary_report_ent, - summary_report_ents_indices, - summary_report_overall, - summary_report_overall_indices, -) +from .evaluator import Evaluator from .utils import collect_named_entities, conll_to_spans, list_to_spans, split_list diff --git a/tests/test_evaluator.py b/tests/test_evaluator.py index 91315d4..f4fc0ef 100644 --- a/tests/test_evaluator.py +++ b/tests/test_evaluator.py @@ -1,1172 +1,80 @@ -# pylint: disable=too-many-lines +import pytest +from nervaluate.evaluator import Evaluator -import pandas as pd -from nervaluate import Evaluator - -def test_results_to_dataframe(): - """ - Test the results_to_dataframe method. - """ - # Setup - evaluator = Evaluator( - true=[["B-LOC", "I-LOC", "O"], ["B-PER", "O", "O"]], - pred=[["B-LOC", "I-LOC", "O"], ["B-PER", "I-PER", "O"]], - tags=["LOC", "PER"], - ) - - # Mock results data for the purpose of this test - evaluator.results = { - "strict": { - "correct": 10, - "incorrect": 5, - "partial": 3, - "missed": 2, - "spurious": 4, - "precision": 0.625, - "recall": 0.6667, - "f1": 0.6452, - "entities": { - "LOC": {"correct": 4, "incorrect": 1, "partial": 0, "missed": 1, "spurious": 2}, - "PER": {"correct": 3, "incorrect": 2, "partial": 1, "missed": 0, "spurious": 1}, - "ORG": {"correct": 3, "incorrect": 2, "partial": 2, "missed": 1, "spurious": 1}, - }, - }, - "ent_type": { - "correct": 8, - "incorrect": 4, - "partial": 1, - "missed": 3, - "spurious": 3, - "precision": 0.5714, - "recall": 0.6154, - "f1": 0.5926, - "entities": { - "LOC": {"correct": 3, "incorrect": 2, "partial": 1, "missed": 1, "spurious": 1}, - "PER": {"correct": 2, "incorrect": 1, "partial": 0, "missed": 2, "spurious": 0}, - "ORG": {"correct": 3, "incorrect": 1, "partial": 0, "missed": 0, "spurious": 2}, - }, - }, - "partial": { - "correct": 7, - "incorrect": 3, - "partial": 4, - "missed": 1, - "spurious": 5, - "precision": 0.5385, - "recall": 0.6364, - "f1": 0.5833, - "entities": { - "LOC": {"correct": 2, "incorrect": 1, "partial": 1, "missed": 1, "spurious": 2}, - "PER": {"correct": 3, "incorrect": 1, "partial": 1, "missed": 0, "spurious": 1}, - "ORG": {"correct": 2, "incorrect": 1, "partial": 2, "missed": 0, "spurious": 2}, - }, - }, - "exact": { - "correct": 9, - "incorrect": 6, - "partial": 2, - "missed": 2, - "spurious": 2, - "precision": 0.6, - "recall": 0.6429, - "f1": 0.6207, - "entities": { - "LOC": {"correct": 4, "incorrect": 1, "partial": 0, "missed": 1, "spurious": 1}, - "PER": {"correct": 3, "incorrect": 3, "partial": 0, "missed": 0, "spurious": 0}, - "ORG": {"correct": 2, "incorrect": 2, "partial": 2, "missed": 1, "spurious": 1}, - }, 
- }, - } - - # Expected DataFrame - expected_data = { - "correct": {"strict": 10, "ent_type": 8, "partial": 7, "exact": 9}, - "incorrect": {"strict": 5, "ent_type": 4, "partial": 3, "exact": 6}, - "partial": {"strict": 3, "ent_type": 1, "partial": 4, "exact": 2}, - "missed": {"strict": 2, "ent_type": 3, "partial": 1, "exact": 2}, - "spurious": {"strict": 4, "ent_type": 3, "partial": 5, "exact": 2}, - "precision": {"strict": 0.625, "ent_type": 0.5714, "partial": 0.5385, "exact": 0.6}, - "recall": {"strict": 0.6667, "ent_type": 0.6154, "partial": 0.6364, "exact": 0.6429}, - "f1": {"strict": 0.6452, "ent_type": 0.5926, "partial": 0.5833, "exact": 0.6207}, - "entities.LOC.correct": {"strict": 4, "ent_type": 3, "partial": 2, "exact": 4}, - "entities.LOC.incorrect": {"strict": 1, "ent_type": 2, "partial": 1, "exact": 1}, - "entities.LOC.partial": {"strict": 0, "ent_type": 1, "partial": 1, "exact": 0}, - "entities.LOC.missed": {"strict": 1, "ent_type": 1, "partial": 1, "exact": 1}, - "entities.LOC.spurious": {"strict": 2, "ent_type": 1, "partial": 2, "exact": 1}, - "entities.PER.correct": {"strict": 3, "ent_type": 2, "partial": 3, "exact": 3}, - "entities.PER.incorrect": {"strict": 2, "ent_type": 1, "partial": 1, "exact": 3}, - "entities.PER.partial": {"strict": 1, "ent_type": 0, "partial": 1, "exact": 0}, - "entities.PER.missed": {"strict": 0, "ent_type": 2, "partial": 0, "exact": 0}, - "entities.PER.spurious": {"strict": 1, "ent_type": 0, "partial": 1, "exact": 0}, - "entities.ORG.correct": {"strict": 3, "ent_type": 3, "partial": 2, "exact": 2}, - "entities.ORG.incorrect": {"strict": 2, "ent_type": 1, "partial": 1, "exact": 2}, - "entities.ORG.partial": {"strict": 2, "ent_type": 0, "partial": 2, "exact": 2}, - "entities.ORG.missed": {"strict": 1, "ent_type": 0, "partial": 0, "exact": 1}, - "entities.ORG.spurious": {"strict": 1, "ent_type": 2, "partial": 2, "exact": 1}, - } - - expected_df = pd.DataFrame(expected_data) - - # Execute - result_df = evaluator.results_to_dataframe() - - # Assert - pd.testing.assert_frame_equal(result_df, expected_df) - - -def test_evaluator_simple_case(): +@pytest.fixture +def sample_data(): true = [ - [{"label": "PER", "start": 2, "end": 4}], - [ - {"label": "LOC", "start": 1, "end": 2}, - {"label": "LOC", "start": 3, "end": 4}, - ], + ["O", "B-PER", "O", "B-ORG", "I-ORG", "B-LOC"], + ["O", "B-PER", "O", "B-ORG"], ] - pred = [ - [{"label": "PER", "start": 2, "end": 4}], - [ - {"label": "LOC", "start": 1, "end": 2}, - {"label": "LOC", "start": 3, "end": 4}, - ], - ] - evaluator = Evaluator(true, pred, tags=["LOC", "PER"]) - results, _, _, _ = evaluator.evaluate() - expected = { - "strict": { - "correct": 3, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "possible": 3, - "actual": 3, - "precision": 1.0, - "recall": 1.0, - "f1": 1.0, - }, - "ent_type": { - "correct": 3, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "possible": 3, - "actual": 3, - "precision": 1.0, - "recall": 1.0, - "f1": 1.0, - }, - "partial": { - "correct": 3, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "possible": 3, - "actual": 3, - "precision": 1.0, - "recall": 1.0, - "f1": 1.0, - }, - "exact": { - "correct": 3, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "possible": 3, - "actual": 3, - "precision": 1.0, - "recall": 1.0, - "f1": 1.0, - }, - } - - assert results["strict"] == expected["strict"] - assert results["ent_type"] == expected["ent_type"] - assert results["partial"] == expected["partial"] - assert 
results["exact"] == expected["exact"] - -def test_evaluator_simple_case_filtered_tags(): - """Check that tags can be excluded by passing the tags argument""" - true = [ - [{"label": "PER", "start": 2, "end": 4}], - [ - {"label": "LOC", "start": 1, "end": 2}, - {"label": "LOC", "start": 3, "end": 4}, - ], - ] pred = [ - [{"label": "PER", "start": 2, "end": 4}], - [ - {"label": "LOC", "start": 1, "end": 2}, - {"label": "LOC", "start": 3, "end": 4}, - ], + ["O", "B-PER", "O", "B-ORG", "O", "B-PER"], + ["O", "B-PER", "O", "B-LOC"], ] - evaluator = Evaluator(true, pred, tags=["PER", "LOC"]) - results, _, _, _ = evaluator.evaluate() - expected = { - "strict": { - "correct": 3, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "possible": 3, - "actual": 3, - "precision": 1.0, - "recall": 1.0, - "f1": 1.0, - }, - "ent_type": { - "correct": 3, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "possible": 3, - "actual": 3, - "precision": 1.0, - "recall": 1.0, - "f1": 1.0, - }, - "partial": { - "correct": 3, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "possible": 3, - "actual": 3, - "precision": 1.0, - "recall": 1.0, - "f1": 1.0, - }, - "exact": { - "correct": 3, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "possible": 3, - "actual": 3, - "precision": 1.0, - "recall": 1.0, - "f1": 1.0, - }, - } - - assert results["strict"] == expected["strict"] - assert results["ent_type"] == expected["ent_type"] - assert results["partial"] == expected["partial"] - assert results["exact"] == expected["exact"] + return true, pred -def test_evaluator_extra_classes(): - """Case when model predicts a class that is not in the gold (true) data""" - true = [ - [{"label": "ORG", "start": 1, "end": 3}], - ] - pred = [ - [{"label": "FOO", "start": 1, "end": 3}], - ] - evaluator = Evaluator(true, pred, tags=["ORG", "FOO"]) - results, _, _, _ = evaluator.evaluate() - expected = { - "strict": { - "correct": 0, - "incorrect": 1, - "partial": 0, - "missed": 0, - "spurious": 0, - "possible": 1, - "actual": 1, - "precision": 0, - "recall": 0.0, - "f1": 0, - }, - "ent_type": { - "correct": 0, - "incorrect": 1, - "partial": 0, - "missed": 0, - "spurious": 0, - "possible": 1, - "actual": 1, - "precision": 0, - "recall": 0.0, - "f1": 0, - }, - "partial": { - "correct": 1, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "possible": 1, - "actual": 1, - "precision": 1.0, - "recall": 1.0, - "f1": 1.0, - }, - "exact": { - "correct": 1, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "possible": 1, - "actual": 1, - "precision": 1.0, - "recall": 1.0, - "f1": 1.0, - }, - } - assert results["strict"] == expected["strict"] - assert results["ent_type"] == expected["ent_type"] - assert results["partial"] == expected["partial"] - assert results["exact"] == expected["exact"] +def test_evaluator_initialization(sample_data): + """Test evaluator initialization.""" + true, pred = sample_data + evaluator = Evaluator(true, pred, ["PER", "ORG", "LOC"], loader="list") + assert len(evaluator.true) == 2 + assert len(evaluator.pred) == 2 + assert evaluator.tags == ["PER", "ORG", "LOC"] -def test_evaluator_no_entities_in_prediction(): - """Case when model predicts a class that is not in the gold (true) data""" - true = [ - [{"label": "PER", "start": 2, "end": 4}], - ] - pred = [ - [], - ] - evaluator = Evaluator(true, pred, tags=["PER"]) - results, _, _, _ = evaluator.evaluate() - expected = { - "strict": { - "correct": 0, - "incorrect": 0, - 
"partial": 0, - "missed": 1, - "spurious": 0, - "possible": 1, - "actual": 0, - "precision": 0, - "recall": 0, - "f1": 0, - }, - "ent_type": { - "correct": 0, - "incorrect": 0, - "partial": 0, - "missed": 1, - "spurious": 0, - "possible": 1, - "actual": 0, - "precision": 0, - "recall": 0, - "f1": 0, - }, - "partial": { - "correct": 0, - "incorrect": 0, - "partial": 0, - "missed": 1, - "spurious": 0, - "possible": 1, - "actual": 0, - "precision": 0, - "recall": 0, - "f1": 0, - }, - "exact": { - "correct": 0, - "incorrect": 0, - "partial": 0, - "missed": 1, - "spurious": 0, - "possible": 1, - "actual": 0, - "precision": 0, - "recall": 0, - "f1": 0, - }, - } - assert results["strict"] == expected["strict"] - assert results["ent_type"] == expected["ent_type"] - assert results["partial"] == expected["partial"] - assert results["exact"] == expected["exact"] +def test_evaluator_evaluation(sample_data): + """Test evaluation process.""" + true, pred = sample_data + evaluator = Evaluator(true, pred, ["PER", "ORG", "LOC"], loader="list") + results = evaluator.evaluate() + # Check that we have results for all strategies + assert "overall" in results + assert "entities" in results + assert "strict" in results["overall"] + assert "partial" in results["overall"] + assert "ent_type" in results["overall"] -def test_evaluator_compare_results_and_results_agg(): - """Check that the label level results match the total results.""" - true = [ - [{"label": "PER", "start": 2, "end": 4}], - ] - pred = [ - [{"label": "PER", "start": 2, "end": 4}], - ] - evaluator = Evaluator(true, pred, tags=["PER"]) - results, results_agg, _, _ = evaluator.evaluate() - expected = { - "strict": { - "correct": 1, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "possible": 1, - "actual": 1, - "precision": 1, - "recall": 1, - "f1": 1, - }, - "ent_type": { - "correct": 1, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "possible": 1, - "actual": 1, - "precision": 1, - "recall": 1, - "f1": 1, - }, - "partial": { - "correct": 1, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "possible": 1, - "actual": 1, - "precision": 1, - "recall": 1, - "f1": 1, - }, - "exact": { - "correct": 1, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "possible": 1, - "actual": 1, - "precision": 1, - "recall": 1, - "f1": 1, - }, - } - expected_agg = { - "PER": { - "strict": { - "correct": 1, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "possible": 1, - "actual": 1, - "precision": 1, - "recall": 1, - "f1": 1, - }, - "ent_type": { - "correct": 1, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "possible": 1, - "actual": 1, - "precision": 1, - "recall": 1, - "f1": 1, - }, - "partial": { - "correct": 1, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "possible": 1, - "actual": 1, - "precision": 1, - "recall": 1, - "f1": 1, - }, - "exact": { - "correct": 1, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "possible": 1, - "actual": 1, - "precision": 1, - "recall": 1, - "f1": 1, - }, - } - } + # Check that we have results for each entity type + for entity in ["PER", "ORG", "LOC"]: + assert entity in results["entities"] + assert "strict" in results["entities"][entity] + assert "partial" in results["entities"][entity] + assert "ent_type" in results["entities"][entity] - assert results_agg["PER"]["strict"] == expected_agg["PER"]["strict"] - assert results_agg["PER"]["ent_type"] == 
expected_agg["PER"]["ent_type"] - assert results_agg["PER"]["partial"] == expected_agg["PER"]["partial"] - assert results_agg["PER"]["exact"] == expected_agg["PER"]["exact"] - assert results["strict"] == expected["strict"] - assert results["ent_type"] == expected["ent_type"] - assert results["partial"] == expected["partial"] - assert results["exact"] == expected["exact"] +def test_evaluator_with_invalid_tags(sample_data): + """Test evaluator with invalid tags.""" + true, pred = sample_data + evaluator = Evaluator(true, pred, ["INVALID"], loader="list") + results = evaluator.evaluate() - assert results["strict"] == expected_agg["PER"]["strict"] - assert results["ent_type"] == expected_agg["PER"]["ent_type"] - assert results["partial"] == expected_agg["PER"]["partial"] - assert results["exact"] == expected_agg["PER"]["exact"] + for strategy in ["strict", "partial", "ent_type"]: + assert results["overall"][strategy].correct == 0 + assert results["overall"][strategy].incorrect == 0 + assert results["overall"][strategy].partial == 0 + assert results["overall"][strategy].missed == 0 + assert results["overall"][strategy].spurious == 0 -def test_evaluator_compare_results_and_results_agg_1(): - """Test case when model predicts a label not in the test data.""" +def test_evaluator_different_document_lengths(): + """Test that Evaluator raises ValueError when documents have different lengths.""" true = [ - [], - [{"label": "ORG", "start": 2, "end": 4}], - [{"label": "MISC", "start": 2, "end": 4}], + ["O", "B-PER", "I-PER", "O", "O", "O", "B-ORG", "I-ORG"], # 8 tokens + ["O", "B-LOC", "B-PER", "I-PER", "O", "O", "B-DATE"], # 7 tokens ] pred = [ - [{"label": "PER", "start": 2, "end": 4}], - [{"label": "ORG", "start": 2, "end": 4}], - [{"label": "MISC", "start": 2, "end": 4}], - ] - evaluator = Evaluator(true, pred, tags=["PER", "ORG", "MISC"]) - results, results_agg, _, _ = evaluator.evaluate() - expected = { - "strict": { - "correct": 2, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 1, - "possible": 2, - "actual": 3, - "precision": 0.6666666666666666, - "recall": 1.0, - "f1": 0.8, - }, - "ent_type": { - "correct": 2, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 1, - "possible": 2, - "actual": 3, - "precision": 0.6666666666666666, - "recall": 1.0, - "f1": 0.8, - }, - "partial": { - "correct": 2, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 1, - "possible": 2, - "actual": 3, - "precision": 0.6666666666666666, - "recall": 1.0, - "f1": 0.8, - }, - "exact": { - "correct": 2, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 1, - "possible": 2, - "actual": 3, - "precision": 0.6666666666666666, - "recall": 1.0, - "f1": 0.8, - }, - } - expected_agg = { - "ORG": { - "strict": { - "correct": 1, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "possible": 1, - "actual": 1, - "precision": 1.0, - "recall": 1, - "f1": 1.0, - }, - "ent_type": { - "correct": 1, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "possible": 1, - "actual": 1, - "precision": 1.0, - "recall": 1, - "f1": 1.0, - }, - "partial": { - "correct": 1, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "possible": 1, - "actual": 1, - "precision": 1, - "recall": 1, - "f1": 1.0, - }, - "exact": { - "correct": 1, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "possible": 1, - "actual": 1, - "precision": 1, - "recall": 1, - "f1": 1, - }, - }, - "MISC": { - "strict": { - "correct": 1, - "incorrect": 0, - "partial": 
0, - "missed": 0, - "spurious": 0, - "possible": 1, - "actual": 1, - "precision": 1, - "recall": 1, - "f1": 1, - }, - "ent_type": { - "correct": 1, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "possible": 1, - "actual": 1, - "precision": 1, - "recall": 1, - "f1": 1, - }, - "partial": { - "correct": 1, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "possible": 1, - "actual": 1, - "precision": 1, - "recall": 1, - "f1": 1, - }, - "exact": { - "correct": 1, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "possible": 1, - "actual": 1, - "precision": 1, - "recall": 1, - "f1": 1, - }, - }, - } - - assert results_agg["ORG"]["strict"] == expected_agg["ORG"]["strict"] - assert results_agg["ORG"]["ent_type"] == expected_agg["ORG"]["ent_type"] - assert results_agg["ORG"]["partial"] == expected_agg["ORG"]["partial"] - assert results_agg["ORG"]["exact"] == expected_agg["ORG"]["exact"] - - assert results_agg["MISC"]["strict"] == expected_agg["MISC"]["strict"] - assert results_agg["MISC"]["ent_type"] == expected_agg["MISC"]["ent_type"] - assert results_agg["MISC"]["partial"] == expected_agg["MISC"]["partial"] - assert results_agg["MISC"]["exact"] == expected_agg["MISC"]["exact"] - - assert results["strict"] == expected["strict"] - assert results["ent_type"] == expected["ent_type"] - assert results["partial"] == expected["partial"] - assert results["exact"] == expected["exact"] - - -def test_evaluator_with_extra_keys_in_pred(): - true = [ - [{"label": "PER", "start": 2, "end": 4}], - [ - {"label": "LOC", "start": 1, "end": 2}, - {"label": "LOC", "start": 3, "end": 4}, - ], + ["O", "B-PER", "I-PER", "O", "O", "O", "B-ORG", "I-ORG"], # 8 tokens + ["O", "B-LOC", "I-LOC", "O", "B-PER", "I-PER", "O", "B-DATE", "I-DATE", "O"], # 10 tokens ] - pred = [ - [{"label": "PER", "start": 2, "end": 4, "token_start": 0, "token_end": 5}], - [ - {"label": "LOC", "start": 1, "end": 2, "token_start": 0, "token_end": 6}, - {"label": "LOC", "start": 3, "end": 4, "token_start": 0, "token_end": 3}, - ], - ] - evaluator = Evaluator(true, pred, tags=["LOC", "PER"]) - results, _, _, _ = evaluator.evaluate() - expected = { - "strict": { - "correct": 3, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "possible": 3, - "actual": 3, - "precision": 1.0, - "recall": 1.0, - "f1": 1.0, - }, - "ent_type": { - "correct": 3, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "possible": 3, - "actual": 3, - "precision": 1.0, - "recall": 1.0, - "f1": 1.0, - }, - "partial": { - "correct": 3, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "possible": 3, - "actual": 3, - "precision": 1.0, - "recall": 1.0, - "f1": 1.0, - }, - "exact": { - "correct": 3, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "possible": 3, - "actual": 3, - "precision": 1.0, - "recall": 1.0, - "f1": 1.0, - }, - } - - assert results["strict"] == expected["strict"] - assert results["ent_type"] == expected["ent_type"] - assert results["partial"] == expected["partial"] - assert results["exact"] == expected["exact"] - - -def test_evaluator_with_extra_keys_in_true(): - true = [ - [{"label": "PER", "start": 2, "end": 4, "token_start": 0, "token_end": 4}], - [ - {"label": "LOC", "start": 1, "end": 2, "token_start": 0, "token_end": 5}, - {"label": "LOC", "start": 3, "end": 4, "token_start": 7, "token_end": 9}, - ], - ] - pred = [ - [{"label": "PER", "start": 2, "end": 4}], - [ - {"label": "LOC", "start": 1, "end": 2}, - {"label": "LOC", "start": 
3, "end": 4}, - ], - ] - evaluator = Evaluator(true, pred, tags=["LOC", "PER"]) - results, _, _, _ = evaluator.evaluate() - expected = { - "strict": { - "correct": 3, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "possible": 3, - "actual": 3, - "precision": 1.0, - "recall": 1.0, - "f1": 1.0, - }, - "ent_type": { - "correct": 3, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "possible": 3, - "actual": 3, - "precision": 1.0, - "recall": 1.0, - "f1": 1.0, - }, - "partial": { - "correct": 3, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "possible": 3, - "actual": 3, - "precision": 1.0, - "recall": 1.0, - "f1": 1.0, - }, - "exact": { - "correct": 3, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "possible": 3, - "actual": 3, - "precision": 1.0, - "recall": 1.0, - "f1": 1.0, - }, - } - - assert results["strict"] == expected["strict"] - assert results["ent_type"] == expected["ent_type"] - assert results["partial"] == expected["partial"] - assert results["exact"] == expected["exact"] - - -def test_issue_29(): - true = [ - [ - {"label": "PER", "start": 1, "end": 2}, - {"label": "PER", "start": 3, "end": 10}, - ] - ] - pred = [ - [ - {"label": "PER", "start": 1, "end": 2}, - {"label": "PER", "start": 3, "end": 5}, - {"label": "PER", "start": 6, "end": 10}, - ] - ] - evaluator = Evaluator(true, pred, tags=["PER"]) - results, _, _, _ = evaluator.evaluate() - expected = { - "strict": { - "correct": 1, - "incorrect": 1, - "partial": 0, - "missed": 0, - "spurious": 1, - "possible": 2, - "actual": 3, - "precision": 0.3333333333333333, - "recall": 0.5, - "f1": 0.4, - }, - "ent_type": { - "correct": 2, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 1, - "possible": 2, - "actual": 3, - "precision": 0.6666666666666666, - "recall": 1.0, - "f1": 0.8, - }, - "partial": { - "correct": 1, - "incorrect": 0, - "partial": 1, - "missed": 0, - "spurious": 1, - "possible": 2, - "actual": 3, - "precision": 0.5, - "recall": 0.75, - "f1": 0.6, - }, - "exact": { - "correct": 1, - "incorrect": 1, - "partial": 0, - "missed": 0, - "spurious": 1, - "possible": 2, - "actual": 3, - "precision": 0.3333333333333333, - "recall": 0.5, - "f1": 0.4, - }, - } - - assert results["strict"] == expected["strict"] - assert results["ent_type"] == expected["ent_type"] - assert results["partial"] == expected["partial"] - assert results["exact"] == expected["exact"] - - -def test_evaluator_compare_results_indices_and_results_agg_indices(): - """Check that the label level results match the total results.""" - true = [ - [{"label": "PER", "start": 2, "end": 4}], - ] - pred = [ - [{"label": "PER", "start": 2, "end": 4}], - ] - evaluator = Evaluator(true, pred, tags=["PER"]) - _, _, evaluation_indices, evaluation_agg_indices = evaluator.evaluate() - expected_evaluation_indices = { - "strict": { - "correct_indices": [(0, 0)], - "incorrect_indices": [], - "partial_indices": [], - "missed_indices": [], - "spurious_indices": [], - }, - "ent_type": { - "correct_indices": [(0, 0)], - "incorrect_indices": [], - "partial_indices": [], - "missed_indices": [], - "spurious_indices": [], - }, - "partial": { - "correct_indices": [(0, 0)], - "incorrect_indices": [], - "partial_indices": [], - "missed_indices": [], - "spurious_indices": [], - }, - "exact": { - "correct_indices": [(0, 0)], - "incorrect_indices": [], - "partial_indices": [], - "missed_indices": [], - "spurious_indices": [], - }, - } - expected_evaluation_agg_indices = { - "PER": { - "strict": { - 
"correct_indices": [(0, 0)], - "incorrect_indices": [], - "partial_indices": [], - "missed_indices": [], - "spurious_indices": [], - }, - "ent_type": { - "correct_indices": [(0, 0)], - "incorrect_indices": [], - "partial_indices": [], - "missed_indices": [], - "spurious_indices": [], - }, - "partial": { - "correct_indices": [(0, 0)], - "incorrect_indices": [], - "partial_indices": [], - "missed_indices": [], - "spurious_indices": [], - }, - "exact": { - "correct_indices": [(0, 0)], - "incorrect_indices": [], - "partial_indices": [], - "missed_indices": [], - "spurious_indices": [], - }, - } - } - assert evaluation_agg_indices["PER"]["strict"] == expected_evaluation_agg_indices["PER"]["strict"] - assert evaluation_agg_indices["PER"]["ent_type"] == expected_evaluation_agg_indices["PER"]["ent_type"] - assert evaluation_agg_indices["PER"]["partial"] == expected_evaluation_agg_indices["PER"]["partial"] - assert evaluation_agg_indices["PER"]["exact"] == expected_evaluation_agg_indices["PER"]["exact"] - - assert evaluation_indices["strict"] == expected_evaluation_indices["strict"] - assert evaluation_indices["ent_type"] == expected_evaluation_indices["ent_type"] - assert evaluation_indices["partial"] == expected_evaluation_indices["partial"] - assert evaluation_indices["exact"] == expected_evaluation_indices["exact"] - - assert evaluation_indices["strict"] == expected_evaluation_agg_indices["PER"]["strict"] - assert evaluation_indices["ent_type"] == expected_evaluation_agg_indices["PER"]["ent_type"] - assert evaluation_indices["partial"] == expected_evaluation_agg_indices["PER"]["partial"] - assert evaluation_indices["exact"] == expected_evaluation_agg_indices["PER"]["exact"] - - -def test_evaluator_compare_results_indices_and_results_agg_indices_1(): - """Test case when model predicts a label not in the test data.""" - true = [ - [], - [{"label": "ORG", "start": 2, "end": 4}], - [{"label": "MISC", "start": 2, "end": 4}], - ] - pred = [ - [{"label": "PER", "start": 2, "end": 4}], - [{"label": "ORG", "start": 2, "end": 4}], - [{"label": "MISC", "start": 2, "end": 4}], - ] - evaluator = Evaluator(true, pred, tags=["PER", "ORG", "MISC"]) - _, _, evaluation_indices, evaluation_agg_indices = evaluator.evaluate() - - expected_evaluation_indices = { - "strict": { - "correct_indices": [(1, 0), (2, 0)], - "incorrect_indices": [], - "partial_indices": [], - "missed_indices": [], - "spurious_indices": [(0, 0)], - }, - "ent_type": { - "correct_indices": [(1, 0), (2, 0)], - "incorrect_indices": [], - "partial_indices": [], - "missed_indices": [], - "spurious_indices": [(0, 0)], - }, - "partial": { - "correct_indices": [(1, 0), (2, 0)], - "incorrect_indices": [], - "partial_indices": [], - "missed_indices": [], - "spurious_indices": [(0, 0)], - }, - "exact": { - "correct_indices": [(1, 0), (2, 0)], - "incorrect_indices": [], - "partial_indices": [], - "missed_indices": [], - "spurious_indices": [(0, 0)], - }, - } - expected_evaluation_agg_indices = { - "PER": { - "strict": { - "correct_indices": [], - "incorrect_indices": [], - "partial_indices": [], - "missed_indices": [], - "spurious_indices": [(0, 0)], - }, - "ent_type": { - "correct_indices": [], - "incorrect_indices": [], - "partial_indices": [], - "missed_indices": [], - "spurious_indices": [(0, 0)], - }, - "partial": { - "correct_indices": [], - "incorrect_indices": [], - "partial_indices": [], - "missed_indices": [], - "spurious_indices": [(0, 0)], - }, - "exact": { - "correct_indices": [], - "incorrect_indices": [], - "partial_indices": [], - 
"missed_indices": [], - "spurious_indices": [(0, 0)], - }, - }, - "ORG": { - "strict": { - "correct_indices": [(1, 0)], - "incorrect_indices": [], - "partial_indices": [], - "missed_indices": [], - "spurious_indices": [], - }, - "ent_type": { - "correct_indices": [(1, 0)], - "incorrect_indices": [], - "partial_indices": [], - "missed_indices": [], - "spurious_indices": [], - }, - "partial": { - "correct_indices": [(1, 0)], - "incorrect_indices": [], - "partial_indices": [], - "missed_indices": [], - "spurious_indices": [], - }, - "exact": { - "correct_indices": [(1, 0)], - "incorrect_indices": [], - "partial_indices": [], - "missed_indices": [], - "spurious_indices": [], - }, - }, - "MISC": { - "strict": { - "correct_indices": [(2, 0)], - "incorrect_indices": [], - "partial_indices": [], - "missed_indices": [], - "spurious_indices": [], - }, - "ent_type": { - "correct_indices": [(2, 0)], - "incorrect_indices": [], - "partial_indices": [], - "missed_indices": [], - "spurious_indices": [], - }, - "partial": { - "correct_indices": [(2, 0)], - "incorrect_indices": [], - "partial_indices": [], - "missed_indices": [], - "spurious_indices": [], - }, - "exact": { - "correct_indices": [(2, 0)], - "incorrect_indices": [], - "partial_indices": [], - "missed_indices": [], - "spurious_indices": [], - }, - }, - } - assert evaluation_agg_indices["ORG"]["strict"] == expected_evaluation_agg_indices["ORG"]["strict"] - assert evaluation_agg_indices["ORG"]["ent_type"] == expected_evaluation_agg_indices["ORG"]["ent_type"] - assert evaluation_agg_indices["ORG"]["partial"] == expected_evaluation_agg_indices["ORG"]["partial"] - assert evaluation_agg_indices["ORG"]["exact"] == expected_evaluation_agg_indices["ORG"]["exact"] - - assert evaluation_agg_indices["MISC"]["strict"] == expected_evaluation_agg_indices["MISC"]["strict"] - assert evaluation_agg_indices["MISC"]["ent_type"] == expected_evaluation_agg_indices["MISC"]["ent_type"] - assert evaluation_agg_indices["MISC"]["partial"] == expected_evaluation_agg_indices["MISC"]["partial"] - assert evaluation_agg_indices["MISC"]["exact"] == expected_evaluation_agg_indices["MISC"]["exact"] + tags = ["PER", "ORG", "LOC", "DATE"] - assert evaluation_indices["strict"] == expected_evaluation_indices["strict"] - assert evaluation_indices["ent_type"] == expected_evaluation_indices["ent_type"] - assert evaluation_indices["partial"] == expected_evaluation_indices["partial"] - assert evaluation_indices["exact"] == expected_evaluation_indices["exact"] + # Test that ValueError is raised + with pytest.raises(ValueError, match="Document 1 has different lengths: true=7, pred=10"): + evaluator = Evaluator(true=true, pred=pred, tags=tags, loader="list") + evaluator.evaluate() diff --git a/tests/test_evaluator_new.py b/tests/test_evaluator_new.py deleted file mode 100644 index f4fc0ef..0000000 --- a/tests/test_evaluator_new.py +++ /dev/null @@ -1,80 +0,0 @@ -import pytest -from nervaluate.evaluator import Evaluator - - -@pytest.fixture -def sample_data(): - true = [ - ["O", "B-PER", "O", "B-ORG", "I-ORG", "B-LOC"], - ["O", "B-PER", "O", "B-ORG"], - ] - - pred = [ - ["O", "B-PER", "O", "B-ORG", "O", "B-PER"], - ["O", "B-PER", "O", "B-LOC"], - ] - - return true, pred - - -def test_evaluator_initialization(sample_data): - """Test evaluator initialization.""" - true, pred = sample_data - evaluator = Evaluator(true, pred, ["PER", "ORG", "LOC"], loader="list") - - assert len(evaluator.true) == 2 - assert len(evaluator.pred) == 2 - assert evaluator.tags == ["PER", "ORG", "LOC"] - - -def 
test_evaluator_evaluation(sample_data): - """Test evaluation process.""" - true, pred = sample_data - evaluator = Evaluator(true, pred, ["PER", "ORG", "LOC"], loader="list") - results = evaluator.evaluate() - - # Check that we have results for all strategies - assert "overall" in results - assert "entities" in results - assert "strict" in results["overall"] - assert "partial" in results["overall"] - assert "ent_type" in results["overall"] - - # Check that we have results for each entity type - for entity in ["PER", "ORG", "LOC"]: - assert entity in results["entities"] - assert "strict" in results["entities"][entity] - assert "partial" in results["entities"][entity] - assert "ent_type" in results["entities"][entity] - - -def test_evaluator_with_invalid_tags(sample_data): - """Test evaluator with invalid tags.""" - true, pred = sample_data - evaluator = Evaluator(true, pred, ["INVALID"], loader="list") - results = evaluator.evaluate() - - for strategy in ["strict", "partial", "ent_type"]: - assert results["overall"][strategy].correct == 0 - assert results["overall"][strategy].incorrect == 0 - assert results["overall"][strategy].partial == 0 - assert results["overall"][strategy].missed == 0 - assert results["overall"][strategy].spurious == 0 - - -def test_evaluator_different_document_lengths(): - """Test that Evaluator raises ValueError when documents have different lengths.""" - true = [ - ["O", "B-PER", "I-PER", "O", "O", "O", "B-ORG", "I-ORG"], # 8 tokens - ["O", "B-LOC", "B-PER", "I-PER", "O", "O", "B-DATE"], # 7 tokens - ] - pred = [ - ["O", "B-PER", "I-PER", "O", "O", "O", "B-ORG", "I-ORG"], # 8 tokens - ["O", "B-LOC", "I-LOC", "O", "B-PER", "I-PER", "O", "B-DATE", "I-DATE", "O"], # 10 tokens - ] - tags = ["PER", "ORG", "LOC", "DATE"] - - # Test that ValueError is raised - with pytest.raises(ValueError, match="Document 1 has different lengths: true=7, pred=10"): - evaluator = Evaluator(true=true, pred=pred, tags=tags, loader="list") - evaluator.evaluate() diff --git a/tests/test_loaders.py b/tests/test_loaders.py index 2c92f23..13a9f92 100644 --- a/tests/test_loaders.py +++ b/tests/test_loaders.py @@ -4,63 +4,6 @@ from nervaluate.loaders import ConllLoader, ListLoader, DictLoader -def test_loaders_produce_the_same_results(): - true_list = [ - ["O", "O", "O", "O", "O", "O"], - ["O", "O", "B-ORG", "I-ORG", "O", "O"], - ["O", "O", "B-MISC", "I-MISC", "O", "O"], - ["B-MISC", "I-MISC", "I-MISC", "I-MISC", "I-MISC", "I-MISC"], - ] - - pred_list = [ - ["O", "O", "B-PER", "I-PER", "O", "O"], - ["O", "O", "B-ORG", "I-ORG", "O", "O"], - ["O", "O", "B-MISC", "I-MISC", "O", "O"], - ["B-MISC", "I-MISC", "I-MISC", "I-MISC", "I-MISC", "I-MISC"], - ] - - true_conll = ( - "word\tO\nword\tO\nword\tO\nword\tO\nword\tO\nword\tO\n\n" - "word\tO\nword\tO\nword\tB-ORG\nword\tI-ORG\nword\tO\nword\tO\n\n" - "word\tO\nword\tO\nword\tB-MISC\nword\tI-MISC\nword\tO\nword\tO\n\n" - "word\tB-MISC\nword\tI-MISC\nword\tI-MISC\nword\tI-MISC\nword\tI-MISC\nword\tI-MISC\n" - ) - - pred_conll = ( - "word\tO\nword\tO\nword\tB-PER\nword\tI-PER\nword\tO\nword\tO\n\n" - "word\tO\nword\tO\nword\tB-ORG\nword\tI-ORG\nword\tO\nword\tO\n\n" - "word\tO\nword\tO\nword\tB-MISC\nword\tI-MISC\nword\tO\nword\tO\n\n" - "word\tB-MISC\nword\tI-MISC\nword\tI-MISC\nword\tI-MISC\nword\tI-MISC\nword\tI-MISC\n" - ) - - true_prod = [ - [], - [{"label": "ORG", "start": 2, "end": 3}], - [{"label": "MISC", "start": 2, "end": 3}], - [{"label": "MISC", "start": 0, "end": 5}], - ] - - pred_prod = [ - [{"label": "PER", "start": 2, "end": 3}], - 
[{"label": "ORG", "start": 2, "end": 3}], - [{"label": "MISC", "start": 2, "end": 3}], - [{"label": "MISC", "start": 0, "end": 5}], - ] - - evaluator_list = Evaluator(true_list, pred_list, tags=["PER", "ORG", "MISC"], loader="list") - - evaluator_conll = Evaluator(true_conll, pred_conll, tags=["PER", "ORG", "MISC"], loader="conll") - - evaluator_prod = Evaluator(true_prod, pred_prod, tags=["PER", "ORG", "MISC"]) - - _, _, _, _ = evaluator_list.evaluate() - _, _, _, _ = evaluator_prod.evaluate() - _, _, _, _ = evaluator_conll.evaluate() - - assert evaluator_prod.pred == evaluator_list.pred == evaluator_conll.pred - assert evaluator_prod.true == evaluator_list.true == evaluator_conll.true - - def test_conll_loader(): """Test CoNLL format loader.""" true_conll = ( diff --git a/tests/test_nervaluate.py b/tests/test_nervaluate.py deleted file mode 100644 index 56fca6e..0000000 --- a/tests/test_nervaluate.py +++ /dev/null @@ -1,909 +0,0 @@ -import pytest - -from nervaluate import ( - Evaluator, - compute_actual_possible, - compute_metrics, - compute_precision_recall, - compute_precision_recall_wrapper, -) - - -def test_compute_metrics_case_1(): - true_named_entities = [ - {"label": "PER", "start": 59, "end": 69}, - {"label": "LOC", "start": 127, "end": 134}, - {"label": "LOC", "start": 164, "end": 174}, - {"label": "LOC", "start": 197, "end": 205}, - {"label": "LOC", "start": 208, "end": 219}, - {"label": "MISC", "start": 230, "end": 240}, - ] - pred_named_entities = [ - {"label": "PER", "start": 24, "end": 30}, - {"label": "LOC", "start": 124, "end": 134}, - {"label": "PER", "start": 164, "end": 174}, - {"label": "LOC", "start": 197, "end": 205}, - {"label": "LOC", "start": 208, "end": 219}, - {"label": "LOC", "start": 225, "end": 243}, - ] - results, _, _, _ = compute_metrics(true_named_entities, pred_named_entities, ["PER", "LOC", "MISC"]) - results = compute_precision_recall_wrapper(results) - expected = { - "strict": { - "correct": 2, - "incorrect": 3, - "partial": 0, - "missed": 1, - "spurious": 1, - "possible": 6, - "actual": 6, - "precision": 0.3333333333333333, - "recall": 0.3333333333333333, - "f1": 0.3333333333333333, - }, - "ent_type": { - "correct": 3, - "incorrect": 2, - "partial": 0, - "missed": 1, - "spurious": 1, - "possible": 6, - "actual": 6, - "precision": 0.5, - "recall": 0.5, - "f1": 0.5, - }, - "partial": { - "correct": 3, - "incorrect": 0, - "partial": 2, - "missed": 1, - "spurious": 1, - "possible": 6, - "actual": 6, - "precision": 0.6666666666666666, - "recall": 0.6666666666666666, - "f1": 0.6666666666666666, - }, - "exact": { - "correct": 3, - "incorrect": 2, - "partial": 0, - "missed": 1, - "spurious": 1, - "possible": 6, - "actual": 6, - "precision": 0.5, - "recall": 0.5, - "f1": 0.5, - }, - } - assert results == expected - - -def test_compute_metrics_agg_scenario_3(): - true_named_entities = [{"label": "PER", "start": 59, "end": 69}] - pred_named_entities = [] - _, results_agg, _, _ = compute_metrics(true_named_entities, pred_named_entities, ["PER"]) - expected_agg = { - "PER": { - "strict": { - "correct": 0, - "incorrect": 0, - "partial": 0, - "missed": 1, - "spurious": 0, - "actual": 0, - "possible": 1, - "precision": 0, - "recall": 0, - "f1": 0, - }, - "ent_type": { - "correct": 0, - "incorrect": 0, - "partial": 0, - "missed": 1, - "spurious": 0, - "actual": 0, - "possible": 1, - "precision": 0, - "recall": 0, - "f1": 0, - }, - "partial": { - "correct": 0, - "incorrect": 0, - "partial": 0, - "missed": 1, - "spurious": 0, - "actual": 0, - "possible": 1, - 
"precision": 0, - "recall": 0, - "f1": 0, - }, - "exact": { - "correct": 0, - "incorrect": 0, - "partial": 0, - "missed": 1, - "spurious": 0, - "actual": 0, - "possible": 1, - "precision": 0, - "recall": 0, - "f1": 0, - }, - } - } - - assert results_agg["PER"]["strict"] == expected_agg["PER"]["strict"] - assert results_agg["PER"]["ent_type"] == expected_agg["PER"]["ent_type"] - assert results_agg["PER"]["partial"] == expected_agg["PER"]["partial"] - assert results_agg["PER"]["exact"] == expected_agg["PER"]["exact"] - - -def test_compute_metrics_agg_scenario_2(): - true_named_entities = [] - pred_named_entities = [{"label": "PER", "start": 59, "end": 69}] - _, results_agg, _, _ = compute_metrics(true_named_entities, pred_named_entities, ["PER"]) - expected_agg = { - "PER": { - "strict": { - "correct": 0, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 1, - "actual": 1, - "possible": 0, - "precision": 0, - "recall": 0, - "f1": 0, - }, - "ent_type": { - "correct": 0, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 1, - "actual": 1, - "possible": 0, - "precision": 0, - "recall": 0, - "f1": 0, - }, - "partial": { - "correct": 0, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 1, - "actual": 1, - "possible": 0, - "precision": 0, - "recall": 0, - "f1": 0, - }, - "exact": { - "correct": 0, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 1, - "actual": 1, - "possible": 0, - "precision": 0, - "recall": 0, - "f1": 0, - }, - } - } - - assert results_agg["PER"]["strict"] == expected_agg["PER"]["strict"] - assert results_agg["PER"]["ent_type"] == expected_agg["PER"]["ent_type"] - assert results_agg["PER"]["partial"] == expected_agg["PER"]["partial"] - assert results_agg["PER"]["exact"] == expected_agg["PER"]["exact"] - - -def test_compute_metrics_agg_scenario_5(): - true_named_entities = [{"label": "PER", "start": 59, "end": 69}] - pred_named_entities = [{"label": "PER", "start": 57, "end": 69}] - _, results_agg, _, _ = compute_metrics(true_named_entities, pred_named_entities, ["PER"]) - expected_agg = { - "PER": { - "strict": { - "correct": 0, - "incorrect": 1, - "partial": 0, - "missed": 0, - "spurious": 0, - "actual": 1, - "possible": 1, - "precision": 0, - "recall": 0, - "f1": 0, - }, - "ent_type": { - "correct": 1, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "actual": 1, - "possible": 1, - "precision": 0, - "recall": 0, - "f1": 0, - }, - "partial": { - "correct": 0, - "incorrect": 0, - "partial": 1, - "missed": 0, - "spurious": 0, - "actual": 1, - "possible": 1, - "precision": 0, - "recall": 0, - "f1": 0, - }, - "exact": { - "correct": 0, - "incorrect": 1, - "partial": 0, - "missed": 0, - "spurious": 0, - "actual": 1, - "possible": 1, - "precision": 0, - "recall": 0, - "f1": 0, - }, - } - } - - assert results_agg["PER"]["strict"] == expected_agg["PER"]["strict"] - assert results_agg["PER"]["ent_type"] == expected_agg["PER"]["ent_type"] - assert results_agg["PER"]["partial"] == expected_agg["PER"]["partial"] - assert results_agg["PER"]["exact"] == expected_agg["PER"]["exact"] - - -def test_compute_metrics_agg_scenario_4(): - true_named_entities = [{"label": "PER", "start": 59, "end": 69}] - pred_named_entities = [{"label": "LOC", "start": 59, "end": 69}] - _, results_agg, _, _ = compute_metrics(true_named_entities, pred_named_entities, ["PER", "LOC"]) - expected_agg = { - "PER": { - "strict": { - "correct": 0, - "incorrect": 1, - "partial": 0, - "missed": 0, - "spurious": 0, - "actual": 1, - "possible": 1, - 
"precision": 0, - "recall": 0, - "f1": 0, - }, - "ent_type": { - "correct": 0, - "incorrect": 1, - "partial": 0, - "missed": 0, - "spurious": 0, - "actual": 1, - "possible": 1, - "precision": 0, - "recall": 0, - "f1": 0, - }, - "partial": { - "correct": 1, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "actual": 1, - "possible": 1, - "precision": 0, - "recall": 0, - "f1": 0, - }, - "exact": { - "correct": 1, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "actual": 1, - "possible": 1, - "precision": 0, - "recall": 0, - "f1": 0, - }, - }, - "LOC": { - "strict": { - "correct": 0, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "actual": 0, - "possible": 0, - "precision": 0, - "recall": 0, - "f1": 0, - }, - "ent_type": { - "correct": 0, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "actual": 0, - "possible": 0, - "precision": 0, - "recall": 0, - "f1": 0, - }, - "partial": { - "correct": 0, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "actual": 0, - "possible": 0, - "precision": 0, - "recall": 0, - "f1": 0, - }, - "exact": { - "correct": 0, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "actual": 0, - "possible": 0, - "precision": 0, - "recall": 0, - "f1": 0, - }, - }, - } - - assert results_agg["PER"]["strict"] == expected_agg["PER"]["strict"] - assert results_agg["PER"]["ent_type"] == expected_agg["PER"]["ent_type"] - assert results_agg["PER"]["partial"] == expected_agg["PER"]["partial"] - assert results_agg["PER"]["exact"] == expected_agg["PER"]["exact"] - assert results_agg["LOC"] == expected_agg["LOC"] - - -def test_compute_metrics_agg_scenario_1(): - true_named_entities = [{"label": "PER", "start": 59, "end": 69}] - pred_named_entities = [{"label": "PER", "start": 59, "end": 69}] - _, results_agg, _, _ = compute_metrics(true_named_entities, pred_named_entities, ["PER"]) - expected_agg = { - "PER": { - "strict": { - "correct": 1, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "actual": 1, - "possible": 1, - "precision": 0, - "recall": 0, - "f1": 0, - }, - "ent_type": { - "correct": 1, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "actual": 1, - "possible": 1, - "precision": 0, - "recall": 0, - "f1": 0, - }, - "partial": { - "correct": 1, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "actual": 1, - "possible": 1, - "precision": 0, - "recall": 0, - "f1": 0, - }, - "exact": { - "correct": 1, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "actual": 1, - "possible": 1, - "precision": 0, - "recall": 0, - "f1": 0, - }, - } - } - - assert results_agg["PER"]["strict"] == expected_agg["PER"]["strict"] - assert results_agg["PER"]["ent_type"] == expected_agg["PER"]["ent_type"] - assert results_agg["PER"]["partial"] == expected_agg["PER"]["partial"] - assert results_agg["PER"]["exact"] == expected_agg["PER"]["exact"] - - -def test_compute_metrics_agg_scenario_6(): - true_named_entities = [{"label": "PER", "start": 59, "end": 69}] - pred_named_entities = [{"label": "LOC", "start": 54, "end": 69}] - _, results_agg, _, _ = compute_metrics(true_named_entities, pred_named_entities, ["PER", "LOC"]) - expected_agg = { - "PER": { - "strict": { - "correct": 0, - "incorrect": 1, - "partial": 0, - "missed": 0, - "spurious": 0, - "actual": 1, - "possible": 1, - "precision": 0, - "recall": 0, - "f1": 0, - }, - "ent_type": { - "correct": 0, - "incorrect": 1, - "partial": 0, - "missed": 0, - "spurious": 0, 
- "actual": 1, - "possible": 1, - "precision": 0, - "recall": 0, - "f1": 0, - }, - "partial": { - "correct": 0, - "incorrect": 0, - "partial": 1, - "missed": 0, - "spurious": 0, - "actual": 1, - "possible": 1, - "precision": 0, - "recall": 0, - "f1": 0, - }, - "exact": { - "correct": 0, - "incorrect": 1, - "partial": 0, - "missed": 0, - "spurious": 0, - "actual": 1, - "possible": 1, - "precision": 0, - "recall": 0, - "f1": 0, - }, - }, - "LOC": { - "strict": { - "correct": 0, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "actual": 0, - "possible": 0, - "precision": 0, - "recall": 0, - "f1": 0, - }, - "ent_type": { - "correct": 0, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "actual": 0, - "possible": 0, - "precision": 0, - "recall": 0, - "f1": 0, - }, - "partial": { - "correct": 0, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "actual": 0, - "possible": 0, - "precision": 0, - "recall": 0, - "f1": 0, - }, - "exact": { - "correct": 0, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "actual": 0, - "possible": 0, - "precision": 0, - "recall": 0, - "f1": 0, - }, - }, - } - - assert results_agg["PER"]["strict"] == expected_agg["PER"]["strict"] - assert results_agg["PER"]["ent_type"] == expected_agg["PER"]["ent_type"] - assert results_agg["PER"]["partial"] == expected_agg["PER"]["partial"] - assert results_agg["PER"]["exact"] == expected_agg["PER"]["exact"] - assert results_agg["LOC"] == expected_agg["LOC"] - - -def test_compute_metrics_extra_tags_in_prediction(): - true_named_entities = [ - {"label": "PER", "start": 50, "end": 52}, - {"label": "ORG", "start": 59, "end": 69}, - {"label": "ORG", "start": 71, "end": 72}, - ] - - pred_named_entities = [ - {"label": "LOC", "start": 50, "end": 52}, # Wrong type - {"label": "ORG", "start": 59, "end": 69}, # Correct - {"label": "MISC", "start": 71, "end": 72}, # Wrong type - ] - results, _, _, _ = compute_metrics(true_named_entities, pred_named_entities, ["PER", "LOC", "ORG"]) - expected = { - "strict": { - "correct": 1, - "incorrect": 1, - "partial": 0, - "missed": 1, - "spurious": 0, - "actual": 2, - "possible": 3, - "precision": 0, - "recall": 0, - "f1": 0, - }, - "ent_type": { - "correct": 1, - "incorrect": 1, - "partial": 0, - "missed": 1, - "spurious": 0, - "actual": 2, - "possible": 3, - "precision": 0, - "recall": 0, - "f1": 0, - }, - "partial": { - "correct": 2, - "incorrect": 0, - "partial": 0, - "missed": 1, - "spurious": 0, - "actual": 2, - "possible": 3, - "precision": 0, - "recall": 0, - "f1": 0, - }, - "exact": { - "correct": 2, - "incorrect": 0, - "partial": 0, - "missed": 1, - "spurious": 0, - "actual": 2, - "possible": 3, - "precision": 0, - "recall": 0, - "f1": 0, - }, - } - - assert results["strict"] == expected["strict"] - assert results["ent_type"] == expected["ent_type"] - assert results["partial"] == expected["partial"] - assert results["exact"] == expected["exact"] - - -def test_compute_metrics_extra_tags_in_true(): - true_named_entities = [ - {"label": "PER", "start": 50, "end": 52}, - {"label": "ORG", "start": 59, "end": 69}, - {"label": "MISC", "start": 71, "end": 72}, - ] - - pred_named_entities = [ - {"label": "LOC", "start": 50, "end": 52}, # Wrong type - {"label": "ORG", "start": 59, "end": 69}, # Correct - {"label": "ORG", "start": 71, "end": 72}, # Spurious - ] - - results, _, _, _ = compute_metrics(true_named_entities, pred_named_entities, ["PER", "LOC", "ORG"]) - - expected = { - "strict": { - "correct": 1, - "incorrect": 1, - 
"partial": 0, - "missed": 0, - "spurious": 1, - "actual": 3, - "possible": 2, - "precision": 0, - "recall": 0, - "f1": 0, - }, - "ent_type": { - "correct": 1, - "incorrect": 1, - "partial": 0, - "missed": 0, - "spurious": 1, - "actual": 3, - "possible": 2, - "precision": 0, - "recall": 0, - "f1": 0, - }, - "partial": { - "correct": 2, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 1, - "actual": 3, - "possible": 2, - "precision": 0, - "recall": 0, - "f1": 0, - }, - "exact": { - "correct": 2, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 1, - "actual": 3, - "possible": 2, - "precision": 0, - "recall": 0, - "f1": 0, - }, - } - - assert results["strict"] == expected["strict"] - assert results["ent_type"] == expected["ent_type"] - assert results["partial"] == expected["partial"] - assert results["exact"] == expected["exact"] - - -def test_compute_metrics_no_predictions(): - true_named_entities = [ - {"label": "PER", "start": 50, "end": 52}, - {"label": "ORG", "start": 59, "end": 69}, - {"label": "MISC", "start": 71, "end": 72}, - ] - pred_named_entities = [] - results, _, _, _ = compute_metrics(true_named_entities, pred_named_entities, ["PER", "ORG", "MISC"]) - expected = { - "strict": { - "correct": 0, - "incorrect": 0, - "partial": 0, - "missed": 3, - "spurious": 0, - "actual": 0, - "possible": 3, - "precision": 0, - "recall": 0, - "f1": 0, - }, - "ent_type": { - "correct": 0, - "incorrect": 0, - "partial": 0, - "missed": 3, - "spurious": 0, - "actual": 0, - "possible": 3, - "precision": 0, - "recall": 0, - "f1": 0, - }, - "partial": { - "correct": 0, - "incorrect": 0, - "partial": 0, - "missed": 3, - "spurious": 0, - "actual": 0, - "possible": 3, - "precision": 0, - "recall": 0, - "f1": 0, - }, - "exact": { - "correct": 0, - "incorrect": 0, - "partial": 0, - "missed": 3, - "spurious": 0, - "actual": 0, - "possible": 3, - "precision": 0, - "recall": 0, - "f1": 0, - }, - } - - assert results["strict"] == expected["strict"] - assert results["ent_type"] == expected["ent_type"] - assert results["partial"] == expected["partial"] - assert results["exact"] == expected["exact"] - - -def test_compute_actual_possible(): - results = { - "correct": 6, - "incorrect": 3, - "partial": 2, - "missed": 4, - "spurious": 2, - } - - expected = { - "correct": 6, - "incorrect": 3, - "partial": 2, - "missed": 4, - "spurious": 2, - "possible": 15, - "actual": 13, - } - - out = compute_actual_possible(results) - - assert out == expected - - -def test_compute_precision_recall(): - results = { - "correct": 6, - "incorrect": 3, - "partial": 2, - "missed": 4, - "spurious": 2, - "possible": 15, - "actual": 13, - } - - expected = { - "correct": 6, - "incorrect": 3, - "partial": 2, - "missed": 4, - "spurious": 2, - "possible": 15, - "actual": 13, - "precision": 0.46153846153846156, - "recall": 0.4, - "f1": 0.42857142857142855, - } - - out = compute_precision_recall(results) - - assert out == expected - - -def test_compute_metrics_one_pred_two_true(): - true_named_entities_1 = [ - {"start": 0, "end": 12, "label": "A"}, - {"start": 14, "end": 17, "label": "B"}, - ] - true_named_entities_2 = [ - {"start": 14, "end": 17, "label": "B"}, - {"start": 0, "end": 12, "label": "A"}, - ] - pred_named_entities = [ - {"start": 0, "end": 17, "label": "A"}, - ] - - results1, _, _, _ = compute_metrics(true_named_entities_1, pred_named_entities, ["A", "B"]) - results2, _, _, _ = compute_metrics(true_named_entities_2, pred_named_entities, ["A", "B"]) - - expected = { - "ent_type": { - "correct": 1, - 
"incorrect": 1, - "partial": 0, - "missed": 0, - "spurious": 0, - "possible": 2, - "actual": 2, - "precision": 0, - "recall": 0, - "f1": 0, - }, - "partial": { - "correct": 0, - "incorrect": 0, - "partial": 2, - "missed": 0, - "spurious": 0, - "possible": 2, - "actual": 2, - "precision": 0, - "recall": 0, - "f1": 0, - }, - "strict": { - "correct": 0, - "incorrect": 2, - "partial": 0, - "missed": 0, - "spurious": 0, - "possible": 2, - "actual": 2, - "precision": 0, - "recall": 0, - "f1": 0, - }, - "exact": { - "correct": 0, - "incorrect": 2, - "partial": 0, - "missed": 0, - "spurious": 0, - "possible": 2, - "actual": 2, - "precision": 0, - "recall": 0, - "f1": 0, - }, - } - - assert results1 == expected - assert results2 == expected - - -def test_evaluator_different_number_of_documents(): - """Test that Evaluator raises ValueError when number of predicted documents doesn't match true documents.""" - - # Create test data with different number of documents - true = [ - [{"label": "PER", "start": 0, "end": 5}], # First document - [{"label": "LOC", "start": 10, "end": 15}], # Second document - ] - pred = [[{"label": "PER", "start": 0, "end": 5}]] # Only one document - tags = ["PER", "LOC"] - - # Test that ValueError is raised - with pytest.raises(ValueError, match="Number of predicted documents does not equal true"): - evaluator = Evaluator(true=true, pred=pred, tags=tags) - evaluator.evaluate() diff --git a/tests/test_reporting.py b/tests/test_reporting.py deleted file mode 100644 index 029e88f..0000000 --- a/tests/test_reporting.py +++ /dev/null @@ -1,111 +0,0 @@ -import pytest - -from nervaluate.reporting import summary_report_ent, summary_report_overall - - -def test_summary_report_ent(): - # Sample input data - results_agg_entities_type = { - "PER": { - "strict": { - "correct": 10, - "incorrect": 2, - "partial": 1, - "missed": 3, - "spurious": 2, - "precision": 0.769, - "recall": 0.714, - "f1": 0.741, - } - }, - "ORG": { - "strict": { - "correct": 15, - "incorrect": 1, - "partial": 0, - "missed": 2, - "spurious": 1, - "precision": 0.882, - "recall": 0.833, - "f1": 0.857, - } - }, - } - - # Call the function - report = summary_report_ent(results_agg_entities_type, scenario="strict", digits=3) - - # Verify the report contains expected content - assert "PER" in report - assert "ORG" in report - assert "correct" in report - assert "incorrect" in report - assert "partial" in report - assert "missed" in report - assert "spurious" in report - assert "precision" in report - assert "recall" in report - assert "f1-score" in report - - # Verify specific values are present - assert "10" in report # PER correct - assert "15" in report # ORG correct - assert "0.769" in report # PER precision - assert "0.857" in report # ORG f1 - - # Test invalid scenario - with pytest.raises(Exception) as exc_info: - summary_report_ent(results_agg_entities_type, scenario="invalid") - assert "Invalid scenario" in str(exc_info.value) - - -def test_summary_report_overall(): - # Sample input data - results = { - "strict": { - "correct": 25, - "incorrect": 3, - "partial": 1, - "missed": 5, - "spurious": 3, - "precision": 0.862, - "recall": 0.806, - "f1": 0.833, - }, - "ent_type": { - "correct": 26, - "incorrect": 2, - "partial": 1, - "missed": 4, - "spurious": 3, - "precision": 0.897, - "recall": 0.839, - "f1": 0.867, - }, - } - - # Call the function - report = summary_report_overall(results, digits=3) - - # Verify the report contains expected content - assert "strict" in report - assert "ent_type" in report - assert 
"correct" in report - assert "incorrect" in report - assert "partial" in report - assert "missed" in report - assert "spurious" in report - assert "precision" in report - assert "recall" in report - assert "f1-score" in report - - # Verify specific values are present - assert "25" in report # strict correct - assert "26" in report # ent_type correct - assert "0.862" in report # strict precision - assert "0.867" in report # ent_type f1 - - # Test with different number of digits - report_2digits = summary_report_overall(results, digits=2) - assert "0.86" in report_2digits # strict precision with 2 digits - assert "0.87" in report_2digits # ent_type f1 with 2 digits diff --git a/tests/test_utils.py b/tests/test_utils.py index db5c76d..d26e3aa 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,7 +1,6 @@ from nervaluate import ( collect_named_entities, conll_to_spans, - find_overlap, list_to_spans, split_list, ) @@ -134,63 +133,3 @@ def test_collect_named_entities_no_entity(): result = collect_named_entities(tags) expected = [] assert result == expected - - -def test_find_overlap_no_overlap(): - pred_entity = {"label": "LOC", "start": 1, "end": 10} - true_entity = {"label": "LOC", "start": 11, "end": 20} - - pred_range = range(pred_entity["start"], pred_entity["end"]) - true_range = range(true_entity["start"], true_entity["end"]) - - pred_set = set(pred_range) - true_set = set(true_range) - - intersect = find_overlap(pred_set, true_set) - - assert not intersect - - -def test_find_overlap_total_overlap(): - pred_entity = {"label": "LOC", "start": 10, "end": 22} - true_entity = {"label": "LOC", "start": 11, "end": 20} - - pred_range = range(pred_entity["start"], pred_entity["end"]) - true_range = range(true_entity["start"], true_entity["end"]) - - pred_set = set(pred_range) - true_set = set(true_range) - - intersect = find_overlap(pred_set, true_set) - - assert intersect - - -def test_find_overlap_start_overlap(): - pred_entity = {"label": "LOC", "start": 5, "end": 12} - true_entity = {"label": "LOC", "start": 11, "end": 20} - - pred_range = range(pred_entity["start"], pred_entity["end"]) - true_range = range(true_entity["start"], true_entity["end"]) - - pred_set = set(pred_range) - true_set = set(true_range) - - intersect = find_overlap(pred_set, true_set) - - assert intersect - - -def test_find_overlap_end_overlap(): - pred_entity = {"label": "LOC", "start": 15, "end": 25} - true_entity = {"label": "LOC", "start": 11, "end": 20} - - pred_range = range(pred_entity["start"], pred_entity["end"]) - true_range = range(true_entity["start"], true_entity["end"]) - - pred_set = set(pred_range) - true_set = set(true_range) - - intersect = find_overlap(pred_set, true_set) - - assert intersect From 128fc13c0e4967057ba843d2a61361c345139cca Mon Sep 17 00:00:00 2001 From: "David S. Batista" Date: Mon, 2 Jun 2025 10:38:20 +0200 Subject: [PATCH 38/41] removing old files --- tests/test_loaders.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_loaders.py b/tests/test_loaders.py index 13a9f92..99316ec 100644 --- a/tests/test_loaders.py +++ b/tests/test_loaders.py @@ -1,6 +1,5 @@ import pytest -from nervaluate import Evaluator from nervaluate.loaders import ConllLoader, ListLoader, DictLoader From 5cf172b0972cb31363330fef6570be13fa8317d6 Mon Sep 17 00:00:00 2001 From: "David S. 
Batista" Date: Sat, 7 Jun 2025 10:31:06 +0200 Subject: [PATCH 39/41] renaming evaluation_strategies to strategies and improving README --- README.md | 159 ++++++++++-------- src/nervaluate/evaluator.py | 2 +- ...evaluation_strategies.py => strategies.py} | 0 ...ation_strategies.py => test_strategies.py} | 6 +- 4 files changed, 89 insertions(+), 78 deletions(-) rename src/nervaluate/{evaluation_strategies.py => strategies.py} (100%) rename tests/{test_evaluation_strategies.py => test_strategies.py} (99%) diff --git a/README.md b/README.md index 572cea1..1346074 100644 --- a/README.md +++ b/README.md @@ -19,16 +19,73 @@ based on whether all the tokens that belong to a named entity were classified or entity type was assigned. This full problem is described in detail in the [original blog](http://www.davidsbatista.net/blog/2018/05/09/Named_Entity_Evaluation/) -post by [David Batista](https://github.com/davidsbatista), and extends the code in the [original repository](https://github.com/davidsbatista/NER-Evaluation) +post by [David Batista](https://github.com/davidsbatista), and this package extends the code in the [original repository](https://github.com/davidsbatista/NER-Evaluation) which accompanied the blog post. The code draws heavily on the papers: * [SemEval-2013 Task 9 : Extraction of Drug-Drug Interactions from Biomedical Texts (DDIExtraction 2013)](https://www.aclweb.org/anthology/S13-2056) - * [SemEval-2013 Task 9.1 - Evaluation Metrics](https://davidsbatista.net/assets/documents/others/semeval_2013-task-9_1-evaluation-metrics.pdf) +# Usage example + +``` +pip install nervaluate +``` + +A possible input format are lists of NER labels, where each list corresponds to a sentence and each label is a token label. +Initialize the `Evaluator` class with the true labels and predicted labels, and specify the entity types we want to evaluate. + +```python +from nervaluate.evaluator import Evaluator + +true = [ + ['O', 'B-PER', 'I-PER', 'O', 'O', 'O', 'B-ORG', 'I-ORG'], # "The John Smith who works at Google Inc" + ['O', 'B-LOC', 'B-PER', 'I-PER', 'O', 'O', 'B-DATE'], # "In Paris Marie Curie lived in 1895" +] + +pred = [ + ['O', 'O', 'B-PER', 'I-PER', 'O', 'O', 'B-ORG', 'I-ORG'], + ['O', 'B-LOC', 'I-LOC', 'B-PER', 'O', 'O', 'B-DATE'], +] + +evaluator = Evaluator(true, pred, tags=['PER', 'ORG', 'LOC', 'DATE'], loader="list") +``` + +Print the summary report for the evaluation, which will show the metrics for each entity type and evaluation scenario: + +```python + +print(evaluator.summary_report()) + +Scenario: all + + correct incorrect partial missed spurious precision recall f1-score + +ent_type 5 0 0 0 0 1.00 1.00 1.00 + exact 2 3 0 0 0 0.40 0.40 0.40 + partial 2 0 3 0 0 0.40 0.40 0.40 + strict 2 3 0 0 0 0.40 0.40 0.40 +``` + +or aggregated by entity type under a specific evaluation scenario: + +```python +print(evaluator.summary_report(mode='entities')) + +Scenario: strict + + correct incorrect partial missed spurious precision recall f1-score + + DATE 1 0 0 0 0 1.00 1.00 1.00 + LOC 0 1 0 0 0 0.00 0.00 0.00 + ORG 1 0 0 0 0 1.00 1.00 1.00 + PER 0 2 0 0 0 0.00 0.00 0.00 +``` + +# Evaluation Scenarios + ## Token level evaluation for NER is too simplistic When running machine learning models for NER, it is common to report metrics at the individual token level. This may @@ -69,12 +126,12 @@ positives, true positives, false negatives and false positives, and subsequently F1-score for each named-entity type. 
However, this simple schema ignores the possibility of partial matches or other scenarios when the NER system gets -the named-entity surface string correct but the type wrong, and we might also want to evaluate these scenarios +the named-entity surface string correct but the type wrong. We might also want to evaluate these scenarios again at a full-entity level. For example: -__IV. System assigns the wrong entity type__ +__IV. System identifies the surface string but assigns the wrong entity type__ | Token | Gold | Prediction | |-------|-------|------------| @@ -103,10 +160,13 @@ __VI. System gets the boundaries and entity type wrong__ | Smith | I-PER | I-ORG | | resigns | O | O | + +## Defining evaluation metrics + How can we incorporate these described scenarios into evaluation metrics? See the [original blog](http://www.davidsbatista.net/blog/2018/05/09/Named_Entity_Evaluation/) -for a great explanation, a summary is included here: +for a great explanation, a summary is included here. -We can use the following five metrics to consider difference categories of errors: +We can define the following five metrics to consider different categories of errors: | Error type | Explanation | |-----------------|--------------------------------------------------------------------------| @@ -136,24 +196,27 @@ These five errors and four evaluation schema interact in the following ways: | I | DRUG | phenytoin | DRUG | phenytoin | COR | COR | COR | COR | | VI | GROUP | contraceptives | DRUG | oral contraceptives | INC | PAR | INC | INC | -Then precision/recall/f1-score are calculated for each different evaluation schema. In order to achieve data, two more -quantities need to be calculated: +Then precision, recall and f1-score are calculated for each different evaluation schema. In order to achieve data, +two more quantities need to be calculated: ``` POSSIBLE (POS) = COR + INC + PAR + MIS = TP + FN ACTUAL (ACT) = COR + INC + PAR + SPU = TP + FP ``` -Then we can compute precision/recall/f1-score, where roughly describing precision is the percentage of correct -named-entities found by the NER system, and recall is the percentage of the named-entities in the golden annotations -that are retrieved by the NER system. This is computed in two different ways depending on whether we want an exact -match (i.e., strict and exact ) or a partial match (i.e., partial and type) scenario: +Then we can compute precision, recall, f1-score, where roughly describing precision is the percentage of correct +named-entities found by the NER system. Recall as the percentage of the named-entities in the golden annotations +that are retrieved by the NER system. + +This is computed in two different ways depending on whether we want an exact match (i.e., strict and exact ) or a +partial match (i.e., partial and type) scenario: __Exact Match (i.e., strict and exact )__ ``` Precision = (COR / ACT) = TP / (TP + FP) Recall = (COR / POS) = TP / (TP+FN) ``` + __Partial Match (i.e., partial and type)__ ``` Precision = (COR + 0.5 × PAR) / ACT = TP / (TP + FP) @@ -184,74 +247,22 @@ but according to the definition of `spurious`: In this case there exists an annotation, but with a different entity type, so we assume it's only incorrect. -## Installation - -``` -pip install nervaluate -``` - -## Example: - -The main `Evaluator` class will accept the following formats: - -* Nested lists containing NER labels. -* CoNLL style tab delimited strings. -* [prodi.gy](https://prodi.gy) style lists of spans. 
- -### Nested lists - -``` -from nervaluate.evaluator import Evaluator - - true = [ - ['O', 'B-PER', 'I-PER', 'O', 'O', 'O', 'B-ORG', 'I-ORG'], - ['O', 'B-LOC', 'B-PER', 'I-PER', 'O', 'O', 'B-DATE'], - ] - - pred = [ - ['O', 'O', 'B-PER', 'I-PER', 'O', 'O', 'B-ORG', 'I-ORG'], - ['O', 'B-LOC', 'I-LOC', 'B-PER', 'O', 'O', 'B-DATE'], - ] - - # Example text for reference: - # "The John Smith who works at Google Inc" - # "In Paris Marie Curie lived in 1895" - - evaluator = Evaluator(true, pred, tags=['PER', 'ORG', 'LOC', 'DATE'], loader="list") - - -Scenario: all - - correct incorrect partial missed spurious precision recall f1-score - -ent_type 5 0 0 0 0 1.00 1.00 1.00 - exact 2 3 0 0 0 0.40 0.40 0.40 - partial 2 0 3 0 0 0.40 0.40 0.40 - strict 2 3 0 0 0 0.40 0.40 0.40 -``` -and, aggregated by entity type: - -``` -print(evaluator.summary_report(mode='entities')) - -Scenario: strict +## Contributing to the `nervaluate` package - correct incorrect partial missed spurious precision recall f1-score +### Extending the package to accept more formats - DATE 1 0 0 0 0 1.00 1.00 1.00 - LOC 0 1 0 0 0 0.00 0.00 0.00 - ORG 1 0 0 0 0 1.00 1.00 1.00 - PER 0 2 0 0 0 0.00 0.00 0.00 -``` +The `Evaluator` accepts the following formats: -## Contributing to the `nervaluate` package +* Nested lists containing NER labels +* CoNLL style tab delimited strings +* [prodi.gy](https://prodi.gy) style lists of spans -### Extending the package to accept more formats +Additional formats can easily be added by creating a new loader class in `nervaluate/loaders.py`. The loader class +should inherit from the `DataLoader` base class and implement the `load` method. -Additional formats can easily be added to the module by creating a new loader class in `nervaluate/loaders.py`. The -loader class should inherit from the `DataLoader` base class and implement the `load` method. The `load` method should - return a list of entity lists, where each entity is represented as a dictionary with `label`, `start`, and `end` keys. +The `load` method should return a list of entity lists, where each entity is represented as a dictionary +with `label`, `start`, and `end` keys. The new loader can then be added to the `_setup_loaders` method in the `Evaluator` class, and can be selected with the `loader` argument when instantiating the `Evaluator` class. 
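For illustration, a minimal sketch of such a loader is shown below. The name `TupleLoader` and the `load(self, data)` signature are assumptions made for this example; check the `DataLoader` base class in `src/nervaluate/loaders.py` for the actual interface:

```python
from typing import Any, Dict, List

# Assumed import path for the base class described above.
from nervaluate.loaders import DataLoader


class TupleLoader(DataLoader):
    """Hypothetical loader for documents given as lists of (start, end, label) tuples."""

    def load(self, data: List[List[tuple]]) -> List[List[Dict[str, Any]]]:
        # Convert each document into a list of entity dicts with the
        # "label", "start" and "end" keys described above.
        return [
            [{"label": label, "start": start, "end": end} for start, end, label in doc]
            for doc in data
        ]
```

Once registered in `_setup_loaders`, such a loader could then be selected with something like `Evaluator(true, pred, tags=["PER"], loader="tuple")`, where `"tuple"` is whatever key it was registered under.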
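As a quick worked example of the `POSSIBLE`/`ACTUAL` and precision/recall formulas described above (the error counts below are purely illustrative):

```python
# Illustrative error counts: correct, incorrect, partial, missed, spurious.
cor, inc, par, mis, spu = 6, 3, 2, 4, 2

possible = cor + inc + par + mis  # COR + INC + PAR + MIS = 15
actual = cor + inc + par + spu    # COR + INC + PAR + SPU = 13

# Exact match scenarios (strict and exact).
precision_exact = cor / actual    # 6 / 13 ≈ 0.4615
recall_exact = cor / possible     # 6 / 15 = 0.4

# Partial match scenarios (partial and type): partial matches count as half.
precision_partial = (cor + 0.5 * par) / actual   # 7 / 13 ≈ 0.5385
recall_partial = (cor + 0.5 * par) / possible    # 7 / 15 ≈ 0.4667
```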
diff --git a/src/nervaluate/evaluator.py b/src/nervaluate/evaluator.py index 1cd9905..f8695a1 100644 --- a/src/nervaluate/evaluator.py +++ b/src/nervaluate/evaluator.py @@ -2,7 +2,7 @@ import pandas as pd from .entities import EvaluationResult, EvaluationIndices -from .evaluation_strategies import ( +from .strategies import ( EvaluationStrategy, StrictEvaluation, PartialEvaluation, diff --git a/src/nervaluate/evaluation_strategies.py b/src/nervaluate/strategies.py similarity index 100% rename from src/nervaluate/evaluation_strategies.py rename to src/nervaluate/strategies.py diff --git a/tests/test_evaluation_strategies.py b/tests/test_strategies.py similarity index 99% rename from tests/test_evaluation_strategies.py rename to tests/test_strategies.py index 54ff180..22eada9 100644 --- a/tests/test_evaluation_strategies.py +++ b/tests/test_strategies.py @@ -1,10 +1,10 @@ import pytest from nervaluate.entities import Entity -from nervaluate.evaluation_strategies import ( - StrictEvaluation, - PartialEvaluation, +from nervaluate.strategies import ( EntityTypeEvaluation, ExactEvaluation, + PartialEvaluation, + StrictEvaluation ) From 2c9cec71f40765f8cbf6162ff6daa58d1e5b2ce7 Mon Sep 17 00:00:00 2001 From: "David S. Batista" Date: Sat, 7 Jun 2025 11:29:35 +0200 Subject: [PATCH 40/41] updating CITATION and removing flake --- CITATION.cff | 11 +++++++---- CONTRIBUTING.md | 14 ++++++-------- pyproject.toml | 7 ------- 3 files changed, 13 insertions(+), 19 deletions(-) diff --git a/CITATION.cff b/CITATION.cff index 7353d55..3c5f673 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -1,5 +1,9 @@ cff-version: 1.2.0 message: "If you use this software, please cite it as below." +title: "nervaluate" +date-released: 2025-06-08 +url: "https://github.com/mantisnlp/nervaluate" +version: 1.0.0 authors: - family-names: "Batista" given-names: "David" @@ -7,7 +11,6 @@ authors: - family-names: "Upson" given-names: "Matthew Antony" orcid: "https://orcid.org/0000-0002-1040-8048" -title: "nervaluate" -version: 0.2.0 -date-released: 2020-10-17 -url: "https://github.com/mantisnlp/nervaluate" + + + diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 323b5ca..795a596 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -10,11 +10,9 @@ Thank you for your interest in contributing to `nervaluate`! This document provi git clone https://github.com/your-username/nervaluate.git cd nervaluate ``` -3. Create a virtual environment and install dependencies: +3. Make sure you have hatch installed, then create a virtual environment: ```bash - python -m venv .venv - source .venv/bin/activate # On Windows: .venv\Scripts\activate - pip install -e ".[dev]" + hatch env create ``` ## Adding Tests @@ -26,6 +24,10 @@ Thank you for your interest in contributing to `nervaluate`! This document provi 3. Test files should be named `test_*.py` 4. Test functions should be named `test_*` 5. Use pytest fixtures when appropriate for test setup and teardown +6. Run tests locally before submitting a pull request: + ```bash + hatch -e + ``` ## Changelog Management @@ -72,10 +74,6 @@ Thank you for your interest in contributing to `nervaluate`! This document provi - Follow PEP 8 guidelines - Use type hints -- Run pre-commit hooks before committing: - ```bash - pre-commit run --all-files - ``` ## Questions? 
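To illustrate the test conventions above (file and function naming plus a pytest fixture), a minimal test might look like the sketch below. It assumes the dict-based `evaluate()` return structure used by the new `Evaluator` (`results["overall"]["strict"].correct`, etc.); adjust it if the current API differs:

```python
import pytest

from nervaluate.evaluator import Evaluator


@pytest.fixture
def perfect_match_docs():
    # A single document where the prediction matches the gold annotation exactly.
    true = [["O", "B-PER", "I-PER", "O"]]
    pred = [["O", "B-PER", "I-PER", "O"]]
    return true, pred


def test_perfect_match_is_strictly_correct(perfect_match_docs):
    true, pred = perfect_match_docs
    evaluator = Evaluator(true, pred, tags=["PER"], loader="list")
    results = evaluator.evaluate()

    # An exactly matching entity is counted as correct under the strict schema.
    assert results["overall"]["strict"].correct == 1
    assert results["overall"]["strict"].missed == 0
    assert results["overall"]["strict"].spurious == 0
```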
diff --git a/pyproject.toml b/pyproject.toml index dd79e5b..05082c5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -98,13 +98,6 @@ default-docstring-type = "numpy" load-plugins = ["pylint.extensions.docparams"] ignore-paths = ["./examples/.*"] -[tool.flake8] -max-line-length = 120 -extend-ignore = ["E203"] -exclude = [".git", "__pycache__", "build", "dist", "./examples/*"] -max-complexity = 10 -per-file-ignores = ["*/__init__.py: F401"] - [tool.mypy] python_version = "3.11" ignore_missing_imports = true From 4742f96bc1baccad7aed987b51b98cf8162481ee Mon Sep 17 00:00:00 2001 From: "David S. Batista" Date: Sat, 7 Jun 2025 11:40:11 +0200 Subject: [PATCH 41/41] wip: using hatch in contributing --- CONTRIBUTING.md | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 795a596..53e529b 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -11,9 +11,7 @@ Thank you for your interest in contributing to `nervaluate`! This document provi cd nervaluate ``` 3. Make sure you have hatch installed, then create a virtual environment: - ```bash - hatch env create - ``` + # ToDo ## Adding Tests