diff --git a/CITATION.cff b/CITATION.cff index 7353d55..3c5f673 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -1,5 +1,9 @@ cff-version: 1.2.0 message: "If you use this software, please cite it as below." +title: "nervaluate" +date-released: 2025-06-08 +url: "https://github.com/mantisnlp/nervaluate" +version: 1.0.0 authors: - family-names: "Batista" given-names: "David" @@ -7,7 +11,6 @@ authors: - family-names: "Upson" given-names: "Matthew Antony" orcid: "https://orcid.org/0000-0002-1040-8048" -title: "nervaluate" -version: 0.2.0 -date-released: 2020-10-17 -url: "https://github.com/mantisnlp/nervaluate" + + + diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 323b5ca..53e529b 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -10,12 +10,8 @@ Thank you for your interest in contributing to `nervaluate`! This document provi git clone https://github.com/your-username/nervaluate.git cd nervaluate ``` -3. Create a virtual environment and install dependencies: - ```bash - python -m venv .venv - source .venv/bin/activate # On Windows: .venv\Scripts\activate - pip install -e ".[dev]" - ``` +3. Make sure you have hatch installed, then create a virtual environment: + # ToDo ## Adding Tests @@ -26,6 +22,10 @@ Thank you for your interest in contributing to `nervaluate`! This document provi 3. Test files should be named `test_*.py` 4. Test functions should be named `test_*` 5. Use pytest fixtures when appropriate for test setup and teardown +6. Run tests locally before submitting a pull request: + ```bash + hatch -e + ``` ## Changelog Management @@ -72,10 +72,6 @@ Thank you for your interest in contributing to `nervaluate`! This document provi - Follow PEP 8 guidelines - Use type hints -- Run pre-commit hooks before committing: - ```bash - pre-commit run --all-files - ``` ## Questions? diff --git a/README.md b/README.md index 31f4b47..1346074 100644 --- a/README.md +++ b/README.md @@ -12,19 +12,80 @@ # nervaluate -`nervaluate` is a python module for evaluating Named Entity Recognition (NER) models as defined in the SemEval 2013 - 9.1 task. +`nervaluate` is a module for evaluating Named Entity Recognition (NER) models as defined in the SemEval 2013 - 9.1 task. The evaluation metrics output by nervaluate go beyond a simple token/tag based schema, and consider different scenarios based on whether all the tokens that belong to a named entity were classified or not, and also whether the correct entity type was assigned. This full problem is described in detail in the [original blog](http://www.davidsbatista.net/blog/2018/05/09/Named_Entity_Evaluation/) -post by [David Batista](https://github.com/davidsbatista), and extends the code in the [original repository](https://github.com/davidsbatista/NER-Evaluation) which accompanied the blog post. +post by [David Batista](https://github.com/davidsbatista), and this package extends the code in the [original repository](https://github.com/davidsbatista/NER-Evaluation) +which accompanied the blog post. -The code draws heavily on: +The code draws heavily on the papers: * [SemEval-2013 Task 9 : Extraction of Drug-Drug Interactions from Biomedical Texts (DDIExtraction 2013)](https://www.aclweb.org/anthology/S13-2056) +* [SemEval-2013 Task 9.1 - Evaluation Metrics](https://davidsbatista.net/assets/documents/others/semeval_2013-task-9_1-evaluation-metrics.pdf) + +# Usage example + +``` +pip install nervaluate +``` + +A possible input format are lists of NER labels, where each list corresponds to a sentence and each label is a token label. 
+Initialize the `Evaluator` class with the true labels and predicted labels, and specify the entity types we want to evaluate. + +```python +from nervaluate.evaluator import Evaluator + +true = [ + ['O', 'B-PER', 'I-PER', 'O', 'O', 'O', 'B-ORG', 'I-ORG'], # "The John Smith who works at Google Inc" + ['O', 'B-LOC', 'B-PER', 'I-PER', 'O', 'O', 'B-DATE'], # "In Paris Marie Curie lived in 1895" +] + +pred = [ + ['O', 'O', 'B-PER', 'I-PER', 'O', 'O', 'B-ORG', 'I-ORG'], + ['O', 'B-LOC', 'I-LOC', 'B-PER', 'O', 'O', 'B-DATE'], +] + +evaluator = Evaluator(true, pred, tags=['PER', 'ORG', 'LOC', 'DATE'], loader="list") +``` + +Print the summary report for the evaluation, which will show the metrics for each entity type and evaluation scenario: + +```python + +print(evaluator.summary_report()) + +Scenario: all + + correct incorrect partial missed spurious precision recall f1-score + +ent_type 5 0 0 0 0 1.00 1.00 1.00 + exact 2 3 0 0 0 0.40 0.40 0.40 + partial 2 0 3 0 0 0.40 0.40 0.40 + strict 2 3 0 0 0 0.40 0.40 0.40 +``` + +or aggregated by entity type under a specific evaluation scenario: + +```python +print(evaluator.summary_report(mode='entities')) + +Scenario: strict + + correct incorrect partial missed spurious precision recall f1-score + + DATE 1 0 0 0 0 1.00 1.00 1.00 + LOC 0 1 0 0 0 0.00 0.00 0.00 + ORG 1 0 0 0 0 1.00 1.00 1.00 + PER 0 2 0 0 0 0.00 0.00 0.00 +``` + +# Evaluation Scenarios + ## Token level evaluation for NER is too simplistic When running machine learning models for NER, it is common to report metrics at the individual token level. This may @@ -35,121 +96,127 @@ When comparing the golden standard annotations with the output of a NER system d __I. Surface string and entity type match__ -|Token|Gold|Prediction| -|---|---|---| -|in|O|O| -|New|B-LOC|B-LOC| -|York|I-LOC|I-LOC| -|.|O|O| +| Token | Gold | Prediction | +|-------|-------|------------| +| in | O | O | +| New | B-LOC | B-LOC | +| York | I-LOC | I-LOC | +| . | O | O | __II. System hypothesized an incorrect entity__ -|Token|Gold|Prediction| -|---|---|---| -|an|O|O| -|Awful|O|B-ORG| -|Headache|O|I-ORG| -|in|O|O| +| Token | Gold | Prediction | +|----------|------|------------| +| an | O | O | +| Awful | O | B-ORG | +| Headache | O | I-ORG | +| in | O | O | __III. System misses an entity__ -|Token|Gold|Prediction| -|---|---|---| -|in|O|O| -|Palo|B-LOC|O| -|Alto|I-LOC|O| -|,|O|O| +| Token | Gold | Prediction | +|-------|-------|------------| +| in | O | O | +| Palo | B-LOC | O | +| Alto | I-LOC | O | +| , | O | O | Based on these three scenarios we have a simple classification evaluation that can be measured in terms of false positives, true positives, false negatives and false positives, and subsequently compute precision, recall and F1-score for each named-entity type. However, this simple schema ignores the possibility of partial matches or other scenarios when the NER system gets -the named-entity surface string correct but the type wrong, and we might also want to evaluate these scenarios +the named-entity surface string correct but the type wrong. We might also want to evaluate these scenarios again at a full-entity level. For example: -__IV. System assigns the wrong entity type__ +__IV. 
System identifies the surface string but assigns the wrong entity type__ -|Token|Gold|Prediction| -|---|---|---| -|I|O|O| -|live|O|O| -|in|O|O| -|Palo|B-LOC|B-ORG| -|Alto|I-LOC|I-ORG| -|,|O|O| +| Token | Gold | Prediction | +|-------|-------|------------| +| I | O | O | +| live | O | O | +| in | O | O | +| Palo | B-LOC | B-ORG | +| Alto | I-LOC | I-ORG | +| , | O | O | __V. System gets the boundaries of the surface string wrong__ -|Token|Gold|Prediction| -|---|---|---| -|Unless|O|B-PER| -|Karl|B-PER|I-PER| -|Smith|I-PER|I-PER| -|resigns|O|O| +| Token | Gold | Prediction | +|---------|-------|------------| +| Unless | O | B-PER | +| Karl | B-PER | I-PER | +| Smith | I-PER | I-PER | +| resigns | O | O | __VI. System gets the boundaries and entity type wrong__ -|Token|Gold|Prediction| -|---|---|---| -|Unless|O|B-ORG| -|Karl|B-PER|I-ORG| -|Smith|I-PER|I-ORG| -|resigns|O|O| +| Token | Gold | Prediction | +|---------|-------|------------| +| Unless | O | B-ORG | +| Karl | B-PER | I-ORG | +| Smith | I-PER | I-ORG | +| resigns | O | O | + + +## Defining evaluation metrics How can we incorporate these described scenarios into evaluation metrics? See the [original blog](http://www.davidsbatista.net/blog/2018/05/09/Named_Entity_Evaluation/) -for a great explanation, a summary is included here: +for a great explanation, a summary is included here. -We can use the following five metrics to consider difference categories of errors: +We can define the following five metrics to consider different categories of errors: -|Error type|Explanation| -|---|---| -|Correct (COR)|both are the same| -|Incorrect (INC)|the output of a system and the golden annotation don’t match| -|Partial (PAR)|system and the golden annotation are somewhat “similar” but not the same| -|Missing (MIS)|a golden annotation is not captured by a system| -|Spurious (SPU)|system produces a response which doesn’t exist in the golden annotation| +| Error type | Explanation | +|-----------------|--------------------------------------------------------------------------| +| Correct (COR) | both are the same | +| Incorrect (INC) | the output of a system and the golden annotation don’t match | +| Partial (PAR) | system and the golden annotation are somewhat “similar” but not the same | +| Missing (MIS) | a golden annotation is not captured by a system | +| Spurious (SPU) | system produces a response which doesn’t exist in the golden annotation | These five metrics can be measured in four different ways: -|Evaluation schema|Explanation| -|---|---| -|Strict|exact boundary surface string match and entity type| -|Exact|exact boundary match over the surface string, regardless of the type| -|Partial|partial boundary match over the surface string, regardless of the type| -|Type|some overlap between the system tagged entity and the gold annotation is required| +| Evaluation schema | Explanation | +|-------------------|-----------------------------------------------------------------------------------| +| Strict | exact boundary surface string match and entity type | +| Exact | exact boundary match over the surface string, regardless of the type | +| Partial | partial boundary match over the surface string, regardless of the type | +| Type | some overlap between the system tagged entity and the gold annotation is required | These five errors and four evaluation schema interact in the following ways: -|Scenario|Gold entity|Gold string|Pred entity|Pred string|Type|Partial|Exact|Strict| -|---|---|---|---|---|---|---|---|---| -|III|BRAND|tikosyn| | 
|MIS|MIS|MIS|MIS| -|II| | |BRAND|healthy|SPU|SPU|SPU|SPU| -|V|DRUG|warfarin|DRUG|of warfarin|COR|PAR|INC|INC| -|IV|DRUG|propranolol|BRAND|propranolol|INC|COR|COR|INC| -|I|DRUG|phenytoin|DRUG|phenytoin|COR|COR|COR|COR| -|VI|GROUP|contraceptives|DRUG|oral contraceptives|INC|PAR|INC|INC| +| Scenario | Gold entity | Gold string | Pred entity | Pred string | Type | Partial | Exact | Strict | +|----------|-------------|----------------|-------------|---------------------|------|---------|-------|--------| +| III | BRAND | tikosyn | | | MIS | MIS | MIS | MIS | +| II | | | BRAND | healthy | SPU | SPU | SPU | SPU | +| V | DRUG | warfarin | DRUG | of warfarin | COR | PAR | INC | INC | +| IV | DRUG | propranolol | BRAND | propranolol | INC | COR | COR | INC | +| I | DRUG | phenytoin | DRUG | phenytoin | COR | COR | COR | COR | +| VI | GROUP | contraceptives | DRUG | oral contraceptives | INC | PAR | INC | INC | -Then precision/recall/f1-score are calculated for each different evaluation schema. In order to achieve data, two more -quantities need to be calculated: +Then precision, recall and f1-score are calculated for each different evaluation schema. In order to achieve data, +two more quantities need to be calculated: ``` POSSIBLE (POS) = COR + INC + PAR + MIS = TP + FN ACTUAL (ACT) = COR + INC + PAR + SPU = TP + FP ``` -Then we can compute precision/recall/f1-score, where roughly describing precision is the percentage of correct -named-entities found by the NER system, and recall is the percentage of the named-entities in the golden annotations -that are retrieved by the NER system. This is computed in two different ways depending on whether we want an exact -match (i.e., strict and exact ) or a partial match (i.e., partial and type) scenario: +Then we can compute precision, recall, f1-score, where roughly describing precision is the percentage of correct +named-entities found by the NER system. Recall as the percentage of the named-entities in the golden annotations +that are retrieved by the NER system. + +This is computed in two different ways depending on whether we want an exact match (i.e., strict and exact ) or a +partial match (i.e., partial and type) scenario: __Exact Match (i.e., strict and exact )__ ``` Precision = (COR / ACT) = TP / (TP + FP) Recall = (COR / POS) = TP / (TP+FN) ``` + __Partial Match (i.e., partial and type)__ ``` Precision = (COR + 0.5 × PAR) / ACT = TP / (TP + FP) @@ -158,16 +225,16 @@ Recall = (COR + 0.5 × PAR)/POS = COR / ACT = TP / (TP + FN) __Putting all together:__ -|Measure|Type|Partial|Exact|Strict| -|---|---|---|---|---| -|Correct|3|3|3|2| -|Incorrect|2|0|2|3| -|Partial|0|2|0|0| -|Missed|1|1|1|1| -|Spurious|1|1|1|1| -|Precision|0.5|0.66|0.5|0.33| -|Recall|0.5|0.66|0.5|0.33| -|F1|0.5|0.66|0.5|0.33| +| Measure | Type | Partial | Exact | Strict | +|-----------|------|---------|-------|--------| +| Correct | 3 | 3 | 3 | 2 | +| Incorrect | 2 | 0 | 2 | 3 | +| Partial | 0 | 2 | 0 | 0 | +| Missed | 1 | 1 | 1 | 1 | +| Spurious | 1 | 1 | 1 | 1 | +| Precision | 0.5 | 0.66 | 0.5 | 0.33 | +| Recall | 0.5 | 0.66 | 0.5 | 0.33 | +| F1 | 0.5 | 0.66 | 0.5 | 0.33 | ## Notes: @@ -180,270 +247,31 @@ but according to the definition of `spurious`: In this case there exists an annotation, but with a different entity type, so we assume it's only incorrect. -## Installation - -``` -pip install nervaluate -``` - -## Example: - -The main `Evaluator` class will accept a number of formats: - -* [prodi.gy](https://prodi.gy) style lists of spans. -* Nested lists containing NER labels. 
-* CoNLL style tab delimited strings. - -### Prodigy spans - -``` -true = [ - [{"label": "PER", "start": 2, "end": 4}], - [{"label": "LOC", "start": 1, "end": 2}, - {"label": "LOC", "start": 3, "end": 4}] -] - -pred = [ - [{"label": "PER", "start": 2, "end": 4}], - [{"label": "LOC", "start": 1, "end": 2}, - {"label": "LOC", "start": 3, "end": 4}] -] -from nervaluate import Evaluator +## Contributing to the `nervaluate` package -evaluator = Evaluator(true, pred, tags=['LOC', 'PER']) +### Extending the package to accept more formats -# Returns overall metrics and metrics for each tag - -results, results_per_tag, result_indices, result_indices_by_tag = evaluator.evaluate() - -print(results) -``` - -``` -{ - 'ent_type':{ - 'correct':3, - 'incorrect':0, - 'partial':0, - 'missed':0, - 'spurious':0, - 'possible':3, - 'actual':3, - 'precision':1.0, - 'recall':1.0 - }, - 'partial':{ - 'correct':3, - 'incorrect':0, - 'partial':0, - 'missed':0, - 'spurious':0, - 'possible':3, - 'actual':3, - 'precision':1.0, - 'recall':1.0 - }, - 'strict':{ - 'correct':3, - 'incorrect':0, - 'partial':0, - 'missed':0, - 'spurious':0, - 'possible':3, - 'actual':3, - 'precision':1.0, - 'recall':1.0 - }, - 'exact':{ - 'correct':3, - 'incorrect':0, - 'partial':0, - 'missed':0, - 'spurious':0, - 'possible':3, - 'actual':3, - 'precision':1.0, - 'recall':1.0 - } -} -``` - -``` -print(results_by_tag) -``` - -``` -{ - 'LOC':{ - 'ent_type':{ - 'correct':2, - 'incorrect':0, - 'partial':0, - 'missed':0, - 'spurious':0, - 'possible':2, - 'actual':2, - 'precision':1.0, - 'recall':1.0 - }, - 'partial':{ - 'correct':2, - 'incorrect':0, - 'partial':0, - 'missed':0, - 'spurious':0, - 'possible':2, - 'actual':2, - 'precision':1.0, - 'recall':1.0 - }, - 'strict':{ - 'correct':2, - 'incorrect':0, - 'partial':0, - 'missed':0, - 'spurious':0, - 'possible':2, - 'actual':2, - 'precision':1.0, - 'recall':1.0 - }, - 'exact':{ - 'correct':2, - 'incorrect':0, - 'partial':0, - 'missed':0, - 'spurious':0, - 'possible':2, - 'actual':2, - 'precision':1.0, - 'recall':1.0 - } - }, - 'PER':{ - 'ent_type':{ - 'correct':1, - 'incorrect':0, - 'partial':0, - 'missed':0, - 'spurious':0, - 'possible':1, - 'actual':1, - 'precision':1.0, - 'recall':1.0 - }, - 'partial':{ - 'correct':1, - 'incorrect':0, - 'partial':0, - 'missed':0, - 'spurious':0, - 'possible':1, - 'actual':1, - 'precision':1.0, - 'recall':1.0 - }, - 'strict':{ - 'correct':1, - 'incorrect':0, - 'partial':0, - 'missed':0, - 'spurious':0, - 'possible':1, - 'actual':1, - 'precision':1.0, - 'recall':1.0 - }, - 'exact':{ - 'correct':1, - 'incorrect':0, - 'partial':0, - 'missed':0, - 'spurious':0, - 'possible':1, - 'actual':1, - 'precision':1.0, - 'recall':1.0 - } - } -} -``` - -``` -from nervaluate import summary_report_overall_indices - -print(summary_report_overall_indices(evaluation_indices=result_indices, error_schema='partial', preds=pred)) -``` - -``` -Indices for error schema 'partial': - -Correct indices: - - Instance 0, Entity 0: Label=PER, Start=2, End=4 - - Instance 1, Entity 0: Label=LOC, Start=1, End=2 - - Instance 1, Entity 1: Label=LOC, Start=3, End=4 - -Incorrect indices: - - None - -Partial indices: - - None - -Missed indices: - - None - -Spurious indices: - - None -``` - -### Nested lists - -``` -true = [ - ['O', 'O', 'B-PER', 'I-PER', 'O'], - ['O', 'B-LOC', 'I-LOC', 'B-LOC', 'I-LOC', 'O'], -] - -pred = [ - ['O', 'O', 'B-PER', 'I-PER', 'O'], - ['O', 'B-LOC', 'I-LOC', 'B-LOC', 'I-LOC', 'O'], -] - -evaluator = Evaluator(true, pred, tags=['LOC', 'PER'], loader="list") - -results, 
results_by_tag, result_indices, result_indices_by_tag = evaluator.evaluate() -``` - -### CoNLL style tab delimited - -``` - -true = "word\tO\nword\tO\B-PER\nword\tI-PER\n" - -pred = "word\tO\nword\tO\B-PER\nword\tI-PER\n" - -evaluator = Evaluator(true, pred, tags=['PER'], loader="conll") - -results, results_by_tag, result_indices, result_indices_by_tag = evaluator.evaluate() - -``` +The `Evaluator` accepts the following formats: -## Extending the package to accept more formats +* Nested lists containing NER labels +* CoNLL style tab delimited strings +* [prodi.gy](https://prodi.gy) style lists of spans -Additional formats can easily be added to the module by creating a conversion function in `nervaluate/utils.py`, -for example `conll_to_spans()`. This function must return the spans in the prodigy style dicts shown in the prodigy -example above. +Additional formats can easily be added by creating a new loader class in `nervaluate/loaders.py`. The loader class +should inherit from the `DataLoader` base class and implement the `load` method. -The new function can then be added to the list of loaders in `nervaluate/nervaluate.py`, and can then be selection -with the `loader` argument when instantiating the `Evaluator` class. +The `load` method should return a list of entity lists, where each entity is represented as a dictionary +with `label`, `start`, and `end` keys. -A list of formats we intend to include is included in https://github.com/ivyleavedtoadflax/nervaluate/issues/3. +The new loader can then be added to the `_setup_loaders` method in the `Evaluator` class, and can be selected with the + `loader` argument when instantiating the `Evaluator` class. +Here is list of formats we intend to [include](https://github.com/MantisAI/nervaluate/issues/3). -## Contributing to the nervaluate package +### General Contributing -Improvements, adding new features and bug fixes are welcome. If you wish to participate in the development of nervaluate +Improvements, adding new features and bug fixes are welcome. If you wish to participate in the development of `nervaluate` please read the guidelines in the [CONTRIBUTING.md](CONTRIBUTING.md) file. --- diff --git a/compare_versions.py b/compare_versions.py new file mode 100644 index 0000000..e4b9d31 --- /dev/null +++ b/compare_versions.py @@ -0,0 +1,194 @@ +from nervaluate.evaluator import Evaluator as NewEvaluator +from nervaluate import Evaluator as OldEvaluator +from nervaluate.reporting import summary_report_overall_indices, summary_report_ents_indices, summary_report + +def list_to_dict_format(data): + """ + Convert list format data to dictionary format. 
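+
+    Illustrative example of the conversion implemented below (0-based token
+    indices; a B-/I- run becomes one dict with label, start and end keys):
+
+        >>> list_to_dict_format([['O', 'B-PER', 'I-PER', 'O']])
+        [[{'label': 'PER', 'start': 1, 'end': 2}]]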
+ + Args: + data: List of lists containing BIO tags + + Returns: + List of lists containing dictionaries with label, start, and end keys + """ + result = [] + for doc in data: + doc_entities = [] + current_entity = None + + for i, tag in enumerate(doc): + if tag.startswith('B-'): + # If we were tracking an entity, add it to the list + if current_entity is not None: + doc_entities.append(current_entity) + # Start tracking a new entity + current_entity = { + 'label': tag[2:], # Remove 'B-' prefix + 'start': i, + 'end': i + } + elif tag.startswith('I-'): + # Continue tracking the current entity + if current_entity is not None: + current_entity['end'] = i + else: # 'O' tag + # If we were tracking an entity, add it to the list + if current_entity is not None: + doc_entities.append(current_entity) + current_entity = None + + # Don't forget to add the last entity if there was one + if current_entity is not None: + doc_entities.append(current_entity) + + result.append(doc_entities) + + return result + + +def generate_synthetic_data(tags, num_samples, min_length=5, max_length=15): + """ + Generate synthetic NER data with ground truth and predictions. + + Args: + tags (list): List of entity tags to use (e.g., ['PER', 'ORG', 'LOC', 'DATE']) + num_samples (int): Number of samples to generate + min_length (int): Minimum sequence length + max_length (int): Maximum sequence length + + Returns: + tuple: (true_sequences, pred_sequences) + """ + import random + + def generate_sequence(length): + sequence = ['O'] * length + # Randomly decide if we'll add an entity + if random.random() < 0.7: # 70% chance to add an entity + # Choose random tag + tag = random.choice(tags) + # Choose random start position + start = random.randint(0, length - 2) + # Choose random length (1 or 2 tokens) + entity_length = random.randint(1, 2) + if start + entity_length <= length: + sequence[start] = f'B-{tag}' + for i in range(1, entity_length): + sequence[start + i] = f'I-{tag}' + return sequence + + def generate_prediction(true_sequence): + pred_sequence = true_sequence.copy() + # Randomly modify some predictions + for i in range(len(pred_sequence)): + if random.random() < 0.2: # 20% chance to modify each token + if pred_sequence[i] == 'O': + # Sometimes predict an entity where there isn't one + if random.random() < 0.3: + tag = random.choice(tags) + pred_sequence[i] = f'B-{tag}' + else: + # Sometimes change the entity type or boundary + if random.random() < 0.3: + tag = random.choice(tags) + if pred_sequence[i].startswith('B-'): + pred_sequence[i] = f'B-{tag}' + elif pred_sequence[i].startswith('I-'): + pred_sequence[i] = f'I-{tag}' + elif random.random() < 0.3: + # Sometimes predict O instead of an entity + pred_sequence[i] = 'O' + return pred_sequence + + true_sequences = [] + pred_sequences = [] + + for _ in range(num_samples): + length = random.randint(min_length, max_length) + true_sequence = generate_sequence(length) + pred_sequence = generate_prediction(true_sequence) + true_sequences.append(true_sequence) + pred_sequences.append(pred_sequence) + + return true_sequences, pred_sequences + + +def overall_report(true, pred): + + new_evaluator = NewEvaluator(true, pred, tags=['PER', 'ORG', 'LOC', 'DATE'], loader="list") + print(new_evaluator.summary_report()) + + print("-"*100) + + old_evaluator = OldEvaluator(true, pred, tags=['PER', 'ORG', 'LOC', 'DATE'], loader="list") + results = old_evaluator.evaluate()[0] # Get the first element which contains the overall results + print(summary_report(results)) + + +def 
entities_report(true, pred): + + new_evaluator = NewEvaluator(true, pred, tags=['PER', 'ORG', 'LOC', 'DATE'], loader="list") + + # entities - strict, exact, partial, ent_type + print(new_evaluator.summary_report(mode="entities", scenario="strict")) + print(new_evaluator.summary_report(mode="entities", scenario="exact")) + print(new_evaluator.summary_report(mode="entities", scenario="partial")) + print(new_evaluator.summary_report(mode="entities", scenario="ent_type")) + + print("-"*100) + + old_evaluator = OldEvaluator(true, pred, tags=['PER', 'ORG', 'LOC', 'DATE'], loader="list") + _, results_agg_entities_type, _, _ = old_evaluator.evaluate() # Get the second element which contains the entity-specific results + print(summary_report(results_agg_entities_type, mode="entities", scenario="strict")) + print(summary_report(results_agg_entities_type, mode="entities", scenario="exact")) + print(summary_report(results_agg_entities_type, mode="entities", scenario="partial")) + print(summary_report(results_agg_entities_type, mode="entities", scenario="ent_type")) + + +def indices_report_overall(true, pred): + + new_evaluator = NewEvaluator(true, pred, tags=['PER', 'LOC', 'DATE'], loader="list") + print(new_evaluator.summary_report_indices(colors=True, mode="overall", scenario="strict")) + print(new_evaluator.summary_report_indices(colors=True, mode="overall", scenario="exact")) + print(new_evaluator.summary_report_indices(colors=True, mode="overall", scenario="partial")) + print(new_evaluator.summary_report_indices(colors=True, mode="overall", scenario="ent_type")) + + old_evaluator = OldEvaluator(true, pred, tags=['LOC', 'PER', 'DATE'], loader="list") + _, _, result_indices, _ = old_evaluator.evaluate() + pred_dict = list_to_dict_format(pred) # convert predictions to dictionary format for reporting + print(summary_report_overall_indices(evaluation_indices=result_indices, error_schema='strict', preds=pred_dict)) + print(summary_report_overall_indices(evaluation_indices=result_indices, error_schema='exact', preds=pred_dict)) + print(summary_report_overall_indices(evaluation_indices=result_indices, error_schema='partial', preds=pred_dict)) + print(summary_report_overall_indices(evaluation_indices=result_indices, error_schema='ent_type', preds=pred_dict)) + + +def indices_report_entities(true, pred): + + new_evaluator = NewEvaluator(true, pred, tags=['PER', 'LOC', 'DATE'], loader="list") + print(new_evaluator.summary_report_indices(colors=True, mode="entities", scenario="strict")) + print(new_evaluator.summary_report_indices(colors=True, mode="entities", scenario="exact")) + print(new_evaluator.summary_report_indices(colors=True, mode="entities", scenario="partial")) + print(new_evaluator.summary_report_indices(colors=True, mode="entities", scenario="ent_type")) + + old_evaluator = OldEvaluator(true, pred, tags=['LOC', 'PER', 'DATE'], loader="list") + _, _, _, result_indices_by_tag = old_evaluator.evaluate() + pred_dict = list_to_dict_format(pred) # convert predictions to dictionary format for reporting + print(summary_report_ents_indices(evaluation_agg_indices=result_indices_by_tag, error_schema='strict', preds=pred_dict)) + print(summary_report_ents_indices(evaluation_agg_indices=result_indices_by_tag, error_schema='exact', preds=pred_dict)) + print(summary_report_ents_indices(evaluation_agg_indices=result_indices_by_tag, error_schema='partial', preds=pred_dict)) + print(summary_report_ents_indices(evaluation_agg_indices=result_indices_by_tag, error_schema='ent_type', preds=pred_dict)) + + +if 
__name__ == "__main__": + tags = ['PER', 'ORG', 'LOC', 'DATE'] + true, pred = generate_synthetic_data(tags, num_samples=10) + + overall_report(true, pred) + print("\n\n" + "="*100 + "\n\n") + entities_report(true, pred) + print("\n\n" + "="*100 + "\n\n") + indices_report_overall(true, pred) + print("\n\n" + "="*100 + "\n\n") + indices_report_entities(true, pred) + print("\n\n" + "="*100 + "\n\n") \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index eaa84bb..05082c5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -71,8 +71,19 @@ disable = [ "I0020", # bad-option-value "R0801", # duplicate-code "W9020", # bad-option-value + "W0621", # redefined-outer-name ] +[tool.pylint.'DESIGN'] +max-args = 38 # Default is 5 +max-attributes = 28 # Default is 7 +max-branches = 14 # Default is 12 +max-locals = 45 # Default is 15 +max-module-lines = 2468 # Default is 1000 +max-nested-blocks = 9 # Default is 5 +max-statements = 206 # Default is 50 +min-public-methods = 1 # Allow classes with just one public method + [tool.pylint.format] max-line-length = 120 @@ -87,13 +98,6 @@ default-docstring-type = "numpy" load-plugins = ["pylint.extensions.docparams"] ignore-paths = ["./examples/.*"] -[tool.flake8] -max-line-length = 120 -extend-ignore = ["E203"] -exclude = [".git", "__pycache__", "build", "dist", "./examples/*"] -max-complexity = 10 -per-file-ignores = ["*/__init__.py: F401"] - [tool.mypy] python_version = "3.11" ignore_missing_imports = true @@ -103,7 +107,6 @@ warn_redundant_casts = true warn_unused_ignores = true warn_unused_configs = true - [[tool.mypy.overrides]] module = "examples.*" follow_imports = "skip" diff --git a/src/nervaluate/__init__.py b/src/nervaluate/__init__.py index f34c277..830a9de 100644 --- a/src/nervaluate/__init__.py +++ b/src/nervaluate/__init__.py @@ -1,15 +1,2 @@ -from .evaluate import ( - Evaluator, - compute_actual_possible, - compute_metrics, - compute_precision_recall, - compute_precision_recall_wrapper, - find_overlap, -) -from .reporting import ( - summary_report_ent, - summary_report_ents_indices, - summary_report_overall, - summary_report_overall_indices, -) +from .evaluator import Evaluator from .utils import collect_named_entities, conll_to_spans, list_to_spans, split_list diff --git a/src/nervaluate/entities.py b/src/nervaluate/entities.py new file mode 100644 index 0000000..4561cab --- /dev/null +++ b/src/nervaluate/entities.py @@ -0,0 +1,74 @@ +from dataclasses import dataclass +from typing import List, Tuple + + +@dataclass +class Entity: + """Represents a named entity with its position and label.""" + + label: str + start: int + end: int + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Entity): + return NotImplemented + return self.label == other.label and self.start == other.start and self.end == other.end + + def __hash__(self) -> int: + return hash((self.label, self.start, self.end)) + + +@dataclass +class EvaluationResult: + """Represents the evaluation metrics for a single entity type or overall.""" + + correct: int = 0 + incorrect: int = 0 + partial: int = 0 + missed: int = 0 + spurious: int = 0 + precision: float = 0.0 + recall: float = 0.0 + f1: float = 0.0 + actual: int = 0 + possible: int = 0 + + def compute_metrics(self, partial_or_type: bool = False) -> None: + """Compute precision, recall and F1 score.""" + self.actual = self.correct + self.incorrect + self.partial + self.spurious + self.possible = self.correct + self.incorrect + self.partial + self.missed + + if partial_or_type: + precision = 
(self.correct + 0.5 * self.partial) / self.actual if self.actual > 0 else 0 + recall = (self.correct + 0.5 * self.partial) / self.possible if self.possible > 0 else 0 + else: + precision = self.correct / self.actual if self.actual > 0 else 0 + recall = self.correct / self.possible if self.possible > 0 else 0 + + self.precision = precision + self.recall = recall + self.f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0 + + +@dataclass +class EvaluationIndices: + """Represents the indices of entities in different evaluation categories.""" + + correct_indices: List[Tuple[int, int]] = None # type: ignore + incorrect_indices: List[Tuple[int, int]] = None # type: ignore + partial_indices: List[Tuple[int, int]] = None # type: ignore + missed_indices: List[Tuple[int, int]] = None # type: ignore + spurious_indices: List[Tuple[int, int]] = None # type: ignore + + def __post_init__(self) -> None: + if self.correct_indices is None: + self.correct_indices = [] + if self.incorrect_indices is None: + self.incorrect_indices = [] + if self.partial_indices is None: + self.partial_indices = [] + if self.missed_indices is None: + self.missed_indices = [] + if self.spurious_indices is None: + self.spurious_indices = [] diff --git a/src/nervaluate/evaluate.py b/src/nervaluate/evaluate.py deleted file mode 100644 index 4f0abb1..0000000 --- a/src/nervaluate/evaluate.py +++ /dev/null @@ -1,609 +0,0 @@ -import logging -import warnings -from copy import deepcopy -from typing import Any - -import pandas as pd - -from .utils import conll_to_spans, find_overlap, list_to_spans, clean_entities - -logger = logging.getLogger(__name__) - - -class Evaluator: # pylint: disable=too-many-instance-attributes, too-few-public-methods - """ - Evaluator class for evaluating named entity recognition (NER) models. - """ - - def __init__( - self, - true: list[list[str]] | list[str] | list[dict] | str | list[list[dict[str, int | str]]], - pred: list[list[str]] | list[str] | list[dict] | str | list[list[dict[str, int | str]]], - tags: list[str], - loader: str = "default", - ) -> None: - """ - Initialize the Evaluator class. - - Args: - true: List of true named entities. - pred: List of predicted named entities. - tags: List of tags to be used. - loader: Loader to be used. - - Raises: - ValueError: If the number of predicted documents does not equal the number of true documents. - """ - self.true = true - self.pred = pred - self.tags = tags - - # Setup dict into which metrics will be stored. - self.metrics_results = { - "correct": 0, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "possible": 0, - "actual": 0, - "precision": 0, - "recall": 0, - "f1": 0, - } - - # Copy results dict to cover the four schemes. 
- self.results = { - "strict": deepcopy(self.metrics_results), - "ent_type": deepcopy(self.metrics_results), - "partial": deepcopy(self.metrics_results), - "exact": deepcopy(self.metrics_results), - } - - # Create an accumulator to store results - self.evaluation_agg_entities_type = {e: deepcopy(self.results) for e in tags} - self.loaders = { - "list": list_to_spans, - "conll": conll_to_spans, - } - - self.loader = loader - - self.eval_indices: dict[str, list[int]] = { - "correct_indices": [], - "incorrect_indices": [], - "partial_indices": [], - "missed_indices": [], - "spurious_indices": [], - } - - # Create dicts to hold indices for correct/spurious/missing/etc examples - self.evaluation_indices = { - "strict": deepcopy(self.eval_indices), - "ent_type": deepcopy(self.eval_indices), - "partial": deepcopy(self.eval_indices), - "exact": deepcopy(self.eval_indices), - } - self.evaluation_agg_indices = {e: deepcopy(self.evaluation_indices) for e in tags} - - def evaluate(self) -> tuple[dict, dict, dict, dict]: # noqa: C901 - warnings.warn( - "The current evaluation method is deprecated and it will the output change in the next release." - "The output will change to a dictionary with the following keys: overall, entities, entity_results, " - "overall_indices, entity_indices.", - DeprecationWarning, - stacklevel=2 - ) - logging.debug("Imported %s predictions for %s true examples", len(self.pred), len(self.true)) - - if self.loader != "default": - loader = self.loaders[self.loader] - self.pred = loader(self.pred) # type: ignore - self.true = loader(self.true) # type: ignore - - if len(self.true) != len(self.pred): - raise ValueError("Number of predicted documents does not equal true") - - for index, (true_ents, pred_ents) in enumerate(zip(self.true, self.pred, strict=False)): - # Compute results for one message - tmp_results, tmp_agg_results, tmp_results_indices, tmp_agg_results_indices = compute_metrics( - true_ents, pred_ents, self.tags, index - ) - - # Cycle through each result and accumulate - # TODO: Combine these loops below: - for eval_schema in self.results: - for metric in self.results[eval_schema]: - self.results[eval_schema][metric] += tmp_results[eval_schema][metric] - - # Accumulate indices for each error type - for error_type in self.evaluation_indices[eval_schema]: - self.evaluation_indices[eval_schema][error_type] += tmp_results_indices[eval_schema][error_type] - - # Calculate global precision and recall - self.results = compute_precision_recall_wrapper(self.results) - - # Aggregate results by entity type - for label in self.tags: - for eval_schema in tmp_agg_results[label]: - for metric in tmp_agg_results[label][eval_schema]: - self.evaluation_agg_entities_type[label][eval_schema][metric] += tmp_agg_results[label][ - eval_schema - ][metric] - - # Accumulate indices for each error type per entity type - for error_type in self.evaluation_agg_indices[label][eval_schema]: - self.evaluation_agg_indices[label][eval_schema][error_type] += tmp_agg_results_indices[label][ - eval_schema - ][error_type] - - # Calculate precision recall at the individual entity level - self.evaluation_agg_entities_type[label] = compute_precision_recall_wrapper( - self.evaluation_agg_entities_type[label] - ) - - return self.results, self.evaluation_agg_entities_type, self.evaluation_indices, self.evaluation_agg_indices - - # Helper method to flatten a nested dictionary - def _flatten_dict(self, d: dict[str, Any], parent_key: str = "", sep: str = ".") -> dict[str, Any]: - """ - Flattens a nested dictionary. 
- - Args: - d (dict): The dictionary to flatten. - parent_key (str): The base key string to prepend to each dictionary key. - sep (str): The separator to use when combining keys. - - Returns: - dict: A flattened dictionary. - """ - items: list[tuple[str, Any]] = [] - for k, v in d.items(): - new_key = f"{parent_key}{sep}{k}" if parent_key else k - if isinstance(v, dict): - items.extend(self._flatten_dict(v, new_key, sep=sep).items()) - else: - items.append((new_key, v)) - return dict(items) - - # Modified results_to_dataframe method using the helper method - def results_to_dataframe(self) -> Any: - if not self.results: - raise ValueError("self.results should be defined.") - - if not isinstance(self.results, dict) or not all(isinstance(v, dict) for v in self.results.values()): - raise ValueError("self.results must be a dictionary of dictionaries.") - - # Flatten the nested results dictionary, including the 'entities' sub-dictionaries - flattened_results: dict[str, dict[str, Any]] = {} - for outer_key, inner_dict in self.results.items(): - flattened_inner_dict = self._flatten_dict(inner_dict) - for inner_key, value in flattened_inner_dict.items(): - if inner_key not in flattened_results: - flattened_results[inner_key] = {} - flattened_results[inner_key][outer_key] = value - - # Convert the flattened results to a pandas DataFrame - try: - return pd.DataFrame(flattened_results) - except Exception as e: - raise RuntimeError("Error converting flattened results to DataFrame") from e - - -def compute_metrics( # type: ignore # noqa: C901 - # pylint: disable=too-many-locals,too-many-branches,too-many-statements,missing-type-doc - true_named_entities, - pred_named_entities, - tags: list[str], - instance_index: int = 0, -) -> tuple[dict, dict, dict, dict]: - """ - Compute metrics on the collected true and predicted named entities - - :param true_named_entities: - collected true named entities output by collect_named_entities - - :param pred_named_entities: - collected predicted named entities output by collect_named_entities - - :param tags: - list of tags to be used - - :param instance_index: - index of the example being evaluated. Used to record indices of correct/missing/spurious/exact/partial - predictions. - """ - - eval_metrics = { - "correct": 0, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "precision": 0, - "recall": 0, - "f1": 0, - } - - # overall results - evaluation = { - "strict": deepcopy(eval_metrics), - "ent_type": deepcopy(eval_metrics), - "partial": deepcopy(eval_metrics), - "exact": deepcopy(eval_metrics), - } - - # results by entity type - evaluation_agg_entities_type = {e: deepcopy(evaluation) for e in tags} - - eval_ent_indices: dict[str, list[tuple[int, int]]] = { - "correct_indices": [], - "incorrect_indices": [], - "partial_indices": [], - "missed_indices": [], - "spurious_indices": [], - } - - # Create dicts to hold indices for correct/spurious/missing/etc examples - evaluation_ent_indices = { - "strict": deepcopy(eval_ent_indices), - "ent_type": deepcopy(eval_ent_indices), - "partial": deepcopy(eval_ent_indices), - "exact": deepcopy(eval_ent_indices), - } - evaluation_agg_ent_indices = {e: deepcopy(evaluation_ent_indices) for e in tags} - - # keep track of entities that overlapped - true_which_overlapped_with_pred = [] - - # Subset into only the tags that we are interested in. - # NOTE: we remove the tags we don't want from both the predicted and the - # true entities. 
This covers the two cases where mismatches can occur: - # - # 1) Where the model predicts a tag that is not present in the true data - # 2) Where there is a tag in the true data that the model is not capable of - # predicting. - - # Strip the spans down to just start, end, label. Note that failing - # to do this results in a bug. The exact cause is not clear. - true_named_entities = [clean_entities(ent) for ent in true_named_entities if ent["label"] in tags] - pred_named_entities = [clean_entities(ent) for ent in pred_named_entities if ent["label"] in tags] - - # Sort the lists to improve the speed of the overlap comparison - true_named_entities.sort(key=lambda x: x["start"]) - pred_named_entities.sort(key=lambda x: x["end"]) - - # go through each predicted named-entity - for within_instance_index, pred in enumerate(pred_named_entities): - found_overlap = False - - # Check each of the potential scenarios in turn. See - # http://www.davidsbatista.net/blog/2018/05/09/Named_Entity_Evaluation/ - # for scenario explanation. - - # Scenario I: Exact match between true and pred - if pred in true_named_entities: - true_which_overlapped_with_pred.append(pred) - evaluation["strict"]["correct"] += 1 - evaluation["ent_type"]["correct"] += 1 - evaluation["exact"]["correct"] += 1 - evaluation["partial"]["correct"] += 1 - evaluation_ent_indices["strict"]["correct_indices"].append((instance_index, within_instance_index)) - evaluation_ent_indices["ent_type"]["correct_indices"].append((instance_index, within_instance_index)) - evaluation_ent_indices["exact"]["correct_indices"].append((instance_index, within_instance_index)) - evaluation_ent_indices["partial"]["correct_indices"].append((instance_index, within_instance_index)) - - # for the agg. by label results - evaluation_agg_entities_type[pred["label"]]["strict"]["correct"] += 1 - evaluation_agg_entities_type[pred["label"]]["ent_type"]["correct"] += 1 - evaluation_agg_entities_type[pred["label"]]["exact"]["correct"] += 1 - evaluation_agg_entities_type[pred["label"]]["partial"]["correct"] += 1 - evaluation_agg_ent_indices[pred["label"]]["strict"]["correct_indices"].append( - (instance_index, within_instance_index) - ) - evaluation_agg_ent_indices[pred["label"]]["ent_type"]["correct_indices"].append( - (instance_index, within_instance_index) - ) - evaluation_agg_ent_indices[pred["label"]]["exact"]["correct_indices"].append( - (instance_index, within_instance_index) - ) - evaluation_agg_ent_indices[pred["label"]]["partial"]["correct_indices"].append( - (instance_index, within_instance_index) - ) - - else: - # check for overlaps with any of the true entities - for true in true_named_entities: - # Only enter this block if an overlap is possible - if pred["end"] < true["start"]: - break - - # overlapping needs to take into account last token as well - pred_range = range(pred["start"], pred["end"] + 1) - true_range = range(true["start"], true["end"] + 1) - - # Scenario IV: Offsets match, but entity type is wrong - if true["start"] == pred["start"] and pred["end"] == true["end"] and true["label"] != pred["label"]: - # overall results - evaluation["strict"]["incorrect"] += 1 - evaluation_ent_indices["strict"]["incorrect_indices"].append( - (instance_index, within_instance_index) - ) - evaluation["ent_type"]["incorrect"] += 1 - evaluation_ent_indices["ent_type"]["incorrect_indices"].append( - (instance_index, within_instance_index) - ) - evaluation["partial"]["correct"] += 1 - evaluation["exact"]["correct"] += 1 - - # aggregated by entity type results - 
evaluation_agg_entities_type[true["label"]]["strict"]["incorrect"] += 1 - evaluation_agg_ent_indices[true["label"]]["strict"]["incorrect_indices"].append( - (instance_index, within_instance_index) - ) - evaluation_agg_entities_type[true["label"]]["ent_type"]["incorrect"] += 1 - evaluation_agg_ent_indices[true["label"]]["ent_type"]["incorrect_indices"].append( - (instance_index, within_instance_index) - ) - evaluation_agg_entities_type[true["label"]]["partial"]["correct"] += 1 - evaluation_agg_entities_type[true["label"]]["exact"]["correct"] += 1 - - true_which_overlapped_with_pred.append(true) - found_overlap = True - break - - # check for an overlap i.e. not exact boundary match, with true entities - # overlaps with true entities must only count once - if find_overlap(true_range, pred_range) and true not in true_which_overlapped_with_pred: - true_which_overlapped_with_pred.append(true) - - # Scenario V: There is an overlap (but offsets do not match - # exactly), and the entity type is the same. - # 2.1 overlaps with the same entity type - if pred["label"] == true["label"]: - # overall results - evaluation["strict"]["incorrect"] += 1 - evaluation_ent_indices["strict"]["incorrect_indices"].append( - (instance_index, within_instance_index) - ) - evaluation["ent_type"]["correct"] += 1 - evaluation["partial"]["partial"] += 1 - evaluation_ent_indices["partial"]["partial_indices"].append( - (instance_index, within_instance_index) - ) - evaluation["exact"]["incorrect"] += 1 - evaluation_ent_indices["exact"]["incorrect_indices"].append( - (instance_index, within_instance_index) - ) - - # aggregated by entity type results - evaluation_agg_entities_type[true["label"]]["strict"]["incorrect"] += 1 - evaluation_agg_ent_indices[true["label"]]["strict"]["incorrect_indices"].append( - (instance_index, within_instance_index) - ) - evaluation_agg_entities_type[true["label"]]["ent_type"]["correct"] += 1 - evaluation_agg_entities_type[true["label"]]["partial"]["partial"] += 1 - evaluation_agg_ent_indices[true["label"]]["partial"]["partial_indices"].append( - (instance_index, within_instance_index) - ) - evaluation_agg_entities_type[true["label"]]["exact"]["incorrect"] += 1 - evaluation_agg_ent_indices[true["label"]]["exact"]["incorrect_indices"].append( - (instance_index, within_instance_index) - ) - - found_overlap = True - - else: - # Scenario VI: Entities overlap, but the entity type is - # different. 
- - # overall results - evaluation["strict"]["incorrect"] += 1 - evaluation_ent_indices["strict"]["incorrect_indices"].append( - (instance_index, within_instance_index) - ) - evaluation["ent_type"]["incorrect"] += 1 - evaluation_ent_indices["ent_type"]["incorrect_indices"].append( - (instance_index, within_instance_index) - ) - evaluation["partial"]["partial"] += 1 - evaluation_ent_indices["partial"]["partial_indices"].append( - (instance_index, within_instance_index) - ) - evaluation["exact"]["incorrect"] += 1 - evaluation_ent_indices["exact"]["incorrect_indices"].append( - (instance_index, within_instance_index) - ) - - # aggregated by entity type results - # Results against the true entity - - evaluation_agg_entities_type[true["label"]]["strict"]["incorrect"] += 1 - evaluation_agg_ent_indices[true["label"]]["strict"]["incorrect_indices"].append( - (instance_index, within_instance_index) - ) - evaluation_agg_entities_type[true["label"]]["partial"]["partial"] += 1 - evaluation_agg_ent_indices[true["label"]]["partial"]["partial_indices"].append( - (instance_index, within_instance_index) - ) - evaluation_agg_entities_type[true["label"]]["ent_type"]["incorrect"] += 1 - evaluation_agg_ent_indices[true["label"]]["ent_type"]["incorrect_indices"].append( - (instance_index, within_instance_index) - ) - evaluation_agg_entities_type[true["label"]]["exact"]["incorrect"] += 1 - evaluation_agg_ent_indices[true["label"]]["exact"]["incorrect_indices"].append( - (instance_index, within_instance_index) - ) - - # Results against the predicted entity - # evaluation_agg_entities_type[pred['label']]['strict']['spurious'] += 1 - found_overlap = True - - # Scenario II: Entities are spurious (i.e., over-generated). - if not found_overlap: - # Overall results - evaluation["strict"]["spurious"] += 1 - evaluation_ent_indices["strict"]["spurious_indices"].append((instance_index, within_instance_index)) - evaluation["ent_type"]["spurious"] += 1 - evaluation_ent_indices["ent_type"]["spurious_indices"].append((instance_index, within_instance_index)) - evaluation["partial"]["spurious"] += 1 - evaluation_ent_indices["partial"]["spurious_indices"].append((instance_index, within_instance_index)) - evaluation["exact"]["spurious"] += 1 - evaluation_ent_indices["exact"]["spurious_indices"].append((instance_index, within_instance_index)) - - # Aggregated by entity type results - - # a over-generated entity with a valid tag should be - # attributed to the respective tag only - if pred["label"] in tags: - spurious_tags = [pred["label"]] - else: - # NOTE: when pred.e_type is not found in valid tags - # or when it simply does not appear in the test set, then it is - # spurious, but it is not clear where to assign it at the tag - # level. In this case, it is applied to all target_tags - # found in this example. This will mean that the sum of the - # evaluation_agg_entities will not equal evaluation. 
- - spurious_tags = tags - - for true in spurious_tags: - evaluation_agg_entities_type[true]["strict"]["spurious"] += 1 - evaluation_agg_ent_indices[true]["strict"]["spurious_indices"].append( - (instance_index, within_instance_index) - ) - evaluation_agg_entities_type[true]["ent_type"]["spurious"] += 1 - evaluation_agg_ent_indices[true]["ent_type"]["spurious_indices"].append( - (instance_index, within_instance_index) - ) - evaluation_agg_entities_type[true]["partial"]["spurious"] += 1 - evaluation_agg_ent_indices[true]["partial"]["spurious_indices"].append( - (instance_index, within_instance_index) - ) - evaluation_agg_entities_type[true]["exact"]["spurious"] += 1 - evaluation_agg_ent_indices[true]["exact"]["spurious_indices"].append( - (instance_index, within_instance_index) - ) - - # Scenario III: Entity was missed entirely. - for within_instance_index, true in enumerate(true_named_entities): - if true in true_which_overlapped_with_pred: - continue - - # overall results - evaluation["strict"]["missed"] += 1 - evaluation_ent_indices["strict"]["missed_indices"].append((instance_index, within_instance_index)) - evaluation["ent_type"]["missed"] += 1 - evaluation_ent_indices["ent_type"]["missed_indices"].append((instance_index, within_instance_index)) - evaluation["partial"]["missed"] += 1 - evaluation_ent_indices["partial"]["missed_indices"].append((instance_index, within_instance_index)) - evaluation["exact"]["missed"] += 1 - evaluation_ent_indices["exact"]["missed_indices"].append((instance_index, within_instance_index)) - - # for the agg. by label - evaluation_agg_entities_type[true["label"]]["strict"]["missed"] += 1 - evaluation_agg_ent_indices[true["label"]]["strict"]["missed_indices"].append( - (instance_index, within_instance_index) - ) - evaluation_agg_entities_type[true["label"]]["ent_type"]["missed"] += 1 - evaluation_agg_ent_indices[true["label"]]["ent_type"]["missed_indices"].append( - (instance_index, within_instance_index) - ) - evaluation_agg_entities_type[true["label"]]["partial"]["missed"] += 1 - evaluation_agg_ent_indices[true["label"]]["partial"]["missed_indices"].append( - (instance_index, within_instance_index) - ) - evaluation_agg_entities_type[true["label"]]["exact"]["missed"] += 1 - evaluation_agg_ent_indices[true["label"]]["exact"]["missed_indices"].append( - (instance_index, within_instance_index) - ) - - # Compute 'possible', 'actual' according to SemEval-2013 Task 9.1 on the - # overall results, and use these to calculate precision and recall. - for eval_type in evaluation: - evaluation[eval_type] = compute_actual_possible(evaluation[eval_type]) - - # Compute 'possible', 'actual', and precision and recall on entity level - # results. Start by cycling through the accumulated results. - for entity_type, entity_level in evaluation_agg_entities_type.items(): - # Cycle through the evaluation types for each dict containing entity - # level results. - - for eval_type in entity_level: - evaluation_agg_entities_type[entity_type][eval_type] = compute_actual_possible(entity_level[eval_type]) - - return evaluation, evaluation_agg_entities_type, evaluation_ent_indices, evaluation_agg_ent_indices - - -def compute_actual_possible(results: dict) -> dict: - """ - Takes a result dict that has been output by compute metrics. - Returns the results' dict with actual, possible populated. - - When the results dicts is from partial or ent_type metrics, then - partial_or_type=True to ensure the right calculation is used for - calculating precision and recall. 
- """ - - correct = results["correct"] - incorrect = results["incorrect"] - partial = results["partial"] - missed = results["missed"] - spurious = results["spurious"] - - # Possible: number annotations in the gold-standard which contribute to the - # final score - possible = correct + incorrect + partial + missed - - # Actual: number of annotations produced by the NER system - actual = correct + incorrect + partial + spurious - - results["actual"] = actual - results["possible"] = possible - - return results - - -def compute_precision_recall(results: dict, partial_or_type: bool = False) -> dict: - """ - Takes a result dict that has been output by compute metrics. - Returns the results' dict with precision and recall populated. - - When the results dicts is from partial or ent_type metrics, then - partial_or_type=True to ensure the right calculation is used for - calculating precision and recall. - """ - - actual = results["actual"] - possible = results["possible"] - partial = results["partial"] - correct = results["correct"] - - if partial_or_type: - precision = (correct + 0.5 * partial) / actual if actual > 0 else 0 - recall = (correct + 0.5 * partial) / possible if possible > 0 else 0 - - else: - precision = correct / actual if actual > 0 else 0 - recall = correct / possible if possible > 0 else 0 - - results["precision"] = precision - results["recall"] = recall - results["f1"] = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0 - - return results - - -def compute_precision_recall_wrapper(results: dict) -> dict: - """ - Wraps the compute_precision_recall function and runs on a dict of results - """ - - results_a = { - key: compute_precision_recall(value, True) for key, value in results.items() if key in ["partial", "ent_type"] - } - results_b = {key: compute_precision_recall(value) for key, value in results.items() if key in ["strict", "exact"]} - - results = {**results_a, **results_b} - - return results diff --git a/src/nervaluate/evaluator.py b/src/nervaluate/evaluator.py new file mode 100644 index 0000000..f8695a1 --- /dev/null +++ b/src/nervaluate/evaluator.py @@ -0,0 +1,437 @@ +from typing import List, Dict, Any, Union +import pandas as pd + +from .entities import EvaluationResult, EvaluationIndices +from .strategies import ( + EvaluationStrategy, + StrictEvaluation, + PartialEvaluation, + EntityTypeEvaluation, + ExactEvaluation, +) +from .loaders import DataLoader, ConllLoader, ListLoader, DictLoader +from .entities import Entity + + +class Evaluator: + """Main evaluator class for NER evaluation.""" + + def __init__(self, true: Any, pred: Any, tags: List[str], loader: str = "default") -> None: + """ + Initialize the evaluator. 
+ + Args: + true: True entities in any supported format + pred: Predicted entities in any supported format + tags: List of valid entity tags + loader: Name of the loader to use + """ + self.tags = tags + self._setup_loaders() + self._load_data(true, pred, loader) + self._setup_evaluation_strategies() + + def _setup_loaders(self) -> None: + """Setup available data loaders.""" + self.loaders: Dict[str, DataLoader] = {"conll": ConllLoader(), "list": ListLoader(), "dict": DictLoader()} + + def _setup_evaluation_strategies(self) -> None: + """Setup evaluation strategies.""" + self.strategies: Dict[str, EvaluationStrategy] = { + "strict": StrictEvaluation(), + "partial": PartialEvaluation(), + "ent_type": EntityTypeEvaluation(), + "exact": ExactEvaluation(), + } + + def _load_data(self, true: Any, pred: Any, loader: str) -> None: + """Load the true and predicted data.""" + if loader == "default": + # Try to infer the loader based on input type + if isinstance(true, str): + loader = "conll" + elif isinstance(true, list) and true and isinstance(true[0], list): + if isinstance(true[0][0], dict): + loader = "dict" + else: + loader = "list" + else: + raise ValueError("Could not infer loader from input type") + + if loader not in self.loaders: + raise ValueError(f"Unknown loader: {loader}") + + # For list loader, check document lengths before loading + if loader == "list": + if len(true) != len(pred): + raise ValueError("Number of predicted documents does not equal true") + + # Check that each document has the same length + for i, (true_doc, pred_doc) in enumerate(zip(true, pred)): + if len(true_doc) != len(pred_doc): + raise ValueError(f"Document {i} has different lengths: true={len(true_doc)}, pred={len(pred_doc)}") + + self.true = self.loaders[loader].load(true) + self.pred = self.loaders[loader].load(pred) + + if len(self.true) != len(self.pred): + raise ValueError("Number of predicted documents does not equal true") + + def evaluate(self) -> Dict[str, Any]: + """ + Run the evaluation. 
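+
+        The returned dictionary is keyed by "overall", "entities",
+        "overall_indices" and "entity_indices" (matching the return statement
+        below). Illustrative access:
+
+            results = evaluator.evaluate()
+            strict_f1 = results["overall"]["strict"].f1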
+ + Returns: + Dictionary containing evaluation results for each strategy and entity type + """ + results = {} + # Get unique tags that appear in either true or predicted data + used_tags = set() # type: ignore + for doc in self.true: + used_tags.update(e.label for e in doc) + for doc in self.pred: + used_tags.update(e.label for e in doc) + # Only keep tags that are both used and in the allowed tags list + used_tags = used_tags.intersection(set(self.tags)) + + entity_results: Dict[str, Dict[str, EvaluationResult]] = {tag: {} for tag in used_tags} + indices = {} + entity_indices: Dict[str, Dict[str, EvaluationIndices]] = {tag: {} for tag in used_tags} + + # Evaluate each document + for doc_idx, (true_doc, pred_doc) in enumerate(zip(self.true, self.pred)): + # Filter entities by valid tags + true_doc = [e for e in true_doc if e.label in self.tags] + pred_doc = [e for e in pred_doc if e.label in self.tags] + + # Evaluate with each strategy + for strategy_name, strategy in self.strategies.items(): + result, doc_indices = strategy.evaluate(true_doc, pred_doc, self.tags, doc_idx) + + # Update overall results + if strategy_name not in results: + results[strategy_name] = result + indices[strategy_name] = doc_indices + else: + self._merge_results(results[strategy_name], result) + self._merge_indices(indices[strategy_name], doc_indices) + + # Update entity-specific results + for tag in used_tags: + # Filter entities for this specific tag + true_tag_doc = [e for e in true_doc if e.label == tag] + pred_tag_doc = [e for e in pred_doc if e.label == tag] + + # Evaluate only entities of this tag + tag_result, tag_indices = strategy.evaluate(true_tag_doc, pred_tag_doc, [tag], doc_idx) + + if tag not in entity_results: + entity_results[tag] = {} + entity_indices[tag] = {} + if strategy_name not in entity_results[tag]: + entity_results[tag][strategy_name] = tag_result + entity_indices[tag][strategy_name] = tag_indices + else: + self._merge_results(entity_results[tag][strategy_name], tag_result) + self._merge_indices(entity_indices[tag][strategy_name], tag_indices) + + return { + "overall": results, + "entities": entity_results, + "overall_indices": indices, + "entity_indices": entity_indices, + } + + @staticmethod + def _merge_results(target: EvaluationResult, source: EvaluationResult) -> None: + """Merge two evaluation results.""" + target.correct += source.correct + target.incorrect += source.incorrect + target.partial += source.partial + target.missed += source.missed + target.spurious += source.spurious + target.compute_metrics() + + @staticmethod + def _merge_indices(target: EvaluationIndices, source: EvaluationIndices) -> None: + """Merge two evaluation indices.""" + target.correct_indices.extend(source.correct_indices) + target.incorrect_indices.extend(source.incorrect_indices) + target.partial_indices.extend(source.partial_indices) + target.missed_indices.extend(source.missed_indices) + target.spurious_indices.extend(source.spurious_indices) + + def results_to_dataframe(self) -> Any: + """Convert results to a pandas DataFrame.""" + results = self.evaluate() + + # Flatten the results structure + flat_results = {} + for category, category_results in results.items(): + for strategy, strategy_results in category_results.items(): + for metric, value in strategy_results.__dict__.items(): + key = f"{category}.{strategy}.{metric}" + flat_results[key] = value + + return pd.DataFrame([flat_results]) + + def summary_report(self, mode: str = "overall", scenario: str = "strict", digits: int = 2) -> str: + """ + 
Generate a summary report of the evaluation results.
+
+        Args:
+            mode: Either 'overall' for overall metrics or 'entities' for per-entity metrics.
+            scenario: The scenario to report on. Only used when mode is 'entities'.
+                Must be one of:
+                    - 'strict': exact boundary match over the surface string and the entity type;
+                    - 'exact': exact boundary match over the surface string, regardless of the type;
+                    - 'partial': partial boundary match over the surface string, regardless of the type;
+                    - 'ent_type': some overlap with a true entity is required, together with the correct entity type;
+            digits: The number of digits to round the results to.
+
+        Returns:
+            A string containing the summary report.
+
+        Raises:
+            ValueError: If the scenario or mode is invalid.
+        """
+        valid_scenarios = {"strict", "ent_type", "partial", "exact"}
+        valid_modes = {"overall", "entities"}
+
+        if mode not in valid_modes:
+            raise ValueError(f"Invalid mode: must be one of {valid_modes}")
+
+        if mode == "entities" and scenario not in valid_scenarios:
+            raise ValueError(f"Invalid scenario: must be one of {valid_scenarios}")
+
+        headers = ["correct", "incorrect", "partial", "missed", "spurious", "precision", "recall", "f1-score"]
+        rows = [headers]
+
+        results = self.evaluate()
+        if mode == "overall":
+            # Process overall results - show all scenarios
+            results_data = results["overall"]
+            for eval_schema in sorted(valid_scenarios):  # Sort to ensure consistent order
+                if eval_schema not in results_data:
+                    continue
+                results_schema = results_data[eval_schema]
+                rows.append(
+                    [
+                        eval_schema,
+                        results_schema.correct,
+                        results_schema.incorrect,
+                        results_schema.partial,
+                        results_schema.missed,
+                        results_schema.spurious,
+                        results_schema.precision,
+                        results_schema.recall,
+                        results_schema.f1,
+                    ]
+                )
+        else:
+            # Process entity-specific results for the specified scenario only
+            results_data = results["entities"]
+            target_names = sorted(results_data.keys())
+            for ent_type in target_names:
+                if scenario not in results_data[ent_type]:
+                    continue  # Skip if scenario not available for this entity type
+
+                results_ent = results_data[ent_type][scenario]
+                rows.append(
+                    [
+                        ent_type,
+                        results_ent.correct,
+                        results_ent.incorrect,
+                        results_ent.partial,
+                        results_ent.missed,
+                        results_ent.spurious,
+                        results_ent.precision,
+                        results_ent.recall,
+                        results_ent.f1,
+                    ]
+                )
+
+        # Format the report
+        name_width = max(len(str(row[0])) for row in rows)
+        width = max(name_width, digits)
+        head_fmt = "{:>{width}s} " + " {:>11}" * len(headers)
+        report = f"Scenario: {scenario if mode == 'entities' else 'all'}\n\n" + head_fmt.format(
+            "", *headers, width=width
+        )
+        report += "\n\n"
+        row_fmt = "{:>{width}s} " + " {:>11}" * 5 + " {:>11.{digits}f}" * 3 + "\n"
+
+        for row in rows[1:]:
+            report += row_fmt.format(*row, width=width, digits=digits)
+
+        return report
+
+    def summary_report_indices(  # pylint: disable=too-many-branches
+        self, mode: str = "overall", scenario: str = "strict", colors: bool = False
+    ) -> str:
+        """
+        Generate a summary report of the evaluation indices.
+
+        Args:
+            mode: Either 'overall' for overall metrics or 'entities' for per-entity metrics.
+            scenario: The scenario to report on. Must be one of: 'strict', 'ent_type', 'partial', 'exact'.
+                Defaults to 'strict'.
+            colors: Whether to use colors in the output. Defaults to False.
+
+        Returns:
+            A string containing the summary report of indices.
+
+        Raises:
+            ValueError: If the scenario or mode is invalid.
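+
+        Example of an illustrative call::
+
+            print(evaluator.summary_report_indices(mode="entities", scenario="partial"))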
+ """ + valid_scenarios = {"strict", "ent_type", "partial", "exact"} + valid_modes = {"overall", "entities"} + + if mode not in valid_modes: + raise ValueError(f"Invalid mode: must be one of {valid_modes}") + + if mode == "entities" and scenario not in valid_scenarios: + raise ValueError(f"Invalid scenario: must be one of {valid_scenarios}") + + # ANSI color codes + COLORS = { + "reset": "\033[0m", + "bold": "\033[1m", + "red": "\033[91m", + "green": "\033[92m", + "yellow": "\033[93m", + "blue": "\033[94m", + "magenta": "\033[95m", + "cyan": "\033[96m", + "white": "\033[97m", + } + + def colorize(text: str, color: str) -> str: + """Helper function to colorize text if colors are enabled.""" + if colors: + return f"{COLORS[color]}{text}{COLORS['reset']}" + return text + + def get_prediction_info(pred: Union[Entity, str]) -> str: + """Helper function to get prediction info based on pred type.""" + if isinstance(pred, Entity): + return f"Label={pred.label}, Start={pred.start}, End={pred.end}" + # String (BIO tag) + return f"Tag={pred}" + + results = self.evaluate() + report = "" + + # Create headers for the table + headers = ["Category", "Instance", "Entity", "Details"] + header_fmt = "{:<20} {:<10} {:<8} {:<25}" + row_fmt = "{:<20} {:<10} {:<8} {:<10}" + + if mode == "overall": + # Get the indices from the overall results + indices_data = results["overall_indices"][scenario] + report += f"\n{colorize('Indices for error schema', 'bold')} '{colorize(scenario, 'cyan')}':\n\n" + report += colorize(header_fmt.format(*headers), "bold") + "\n" + report += colorize("-" * 78, "white") + "\n" + + for category, indices in indices_data.__dict__.items(): + if not category.endswith("_indices"): + continue + category_name = category.replace("_indices", "").replace("_", " ").capitalize() + + # Color mapping for categories + category_colors = { + "Correct": "green", + "Incorrect": "red", + "Partial": "yellow", + "Missed": "magenta", + "Spurious": "blue", + } + + if indices: + for instance_index, entity_index in indices: + if self.pred != [[]]: + pred = self.pred[instance_index][entity_index] + prediction_info = get_prediction_info(pred) + report += ( + row_fmt.format( + colorize(category_name, category_colors.get(category_name, "white")), + f"{instance_index}", + f"{entity_index}", + prediction_info, + ) + + "\n" + ) + else: + report += ( + row_fmt.format( + colorize(category_name, category_colors.get(category_name, "white")), + f"{instance_index}", + f"{entity_index}", + "No prediction info", + ) + + "\n" + ) + else: + report += ( + row_fmt.format( + colorize(category_name, category_colors.get(category_name, "white")), "-", "-", "None" + ) + + "\n" + ) + else: + # Get the indices from the entity-specific results + for entity_type, entity_results in results["entity_indices"].items(): + report += f"\n{colorize('Entity Type', 'bold')}: {colorize(entity_type, 'cyan')}\n" + report += f"{colorize('Error Schema', 'bold')}: '{colorize(scenario, 'cyan')}'\n\n" + report += colorize(header_fmt.format(*headers), "bold") + "\n" + report += colorize("-" * 78, "white") + "\n" + + error_data = entity_results[scenario] + for category, indices in error_data.__dict__.items(): + if not category.endswith("_indices"): + continue + category_name = category.replace("_indices", "").replace("_", " ").capitalize() + + # Color mapping for categories + category_colors = { + "Correct": "green", + "Incorrect": "red", + "Partial": "yellow", + "Missed": "magenta", + "Spurious": "blue", + } + + if indices: + for instance_index, entity_index 
in indices: + if self.pred != [[]]: + pred = self.pred[instance_index][entity_index] + prediction_info = get_prediction_info(pred) + report += ( + row_fmt.format( + colorize(category_name, category_colors.get(category_name, "white")), + f"{instance_index}", + f"{entity_index}", + prediction_info, + ) + + "\n" + ) + else: + report += ( + row_fmt.format( + colorize(category_name, category_colors.get(category_name, "white")), + f"{instance_index}", + f"{entity_index}", + "No prediction info", + ) + + "\n" + ) + else: + report += ( + row_fmt.format( + colorize(category_name, category_colors.get(category_name, "white")), "-", "-", "None" + ) + + "\n" + ) + + return report diff --git a/src/nervaluate/loaders.py b/src/nervaluate/loaders.py new file mode 100644 index 0000000..98ba819 --- /dev/null +++ b/src/nervaluate/loaders.py @@ -0,0 +1,189 @@ +from abc import ABC, abstractmethod +from typing import List, Dict, Any + +from .entities import Entity + + +class DataLoader(ABC): + """Abstract base class for data loaders.""" + + @abstractmethod + def load(self, data: Any) -> List[List[Entity]]: + """Load data into a list of entity lists.""" + + +class ConllLoader(DataLoader): + """Loader for CoNLL format data.""" + + def load(self, data: str) -> List[List[Entity]]: # pylint: disable=too-many-branches + """Load CoNLL format data into a list of Entity lists.""" + if not isinstance(data, str): + raise ValueError("ConllLoader expects string input") + + if not data: + return [] + + result: List[List[Entity]] = [] + # Strip trailing whitespace and newlines to avoid empty documents + documents = data.rstrip().split("\n\n") + + for doc in documents: + if not doc.strip(): + result.append([]) + continue + + current_doc = [] + start_offset = None + end_offset = None + ent_type = None + has_entities = False + + for offset, line in enumerate(doc.split("\n")): + if not line.strip(): + continue + + parts = line.split("\t") + if len(parts) < 2: + raise ValueError(f"Invalid CoNLL format: line '{line}' does not contain a tab separator") + + token_tag = parts[1] + + if token_tag == "O": + if ent_type is not None and start_offset is not None: + end_offset = offset - 1 + if isinstance(start_offset, int) and isinstance(end_offset, int): + current_doc.append(Entity(label=ent_type, start=start_offset, end=end_offset)) + start_offset = None + end_offset = None + ent_type = None + + elif ent_type is None: + if not (token_tag.startswith("B-") or token_tag.startswith("I-")): + raise ValueError(f"Invalid tag format: {token_tag}") + ent_type = token_tag[2:] # Remove B- or I- prefix + start_offset = offset + has_entities = True + + elif ent_type != token_tag[2:] or (ent_type == token_tag[2:] and token_tag[:1] == "B"): + end_offset = offset - 1 + if isinstance(start_offset, int) and isinstance(end_offset, int): + current_doc.append(Entity(label=ent_type, start=start_offset, end=end_offset)) + + # start of a new entity + if not (token_tag.startswith("B-") or token_tag.startswith("I-")): + raise ValueError(f"Invalid tag format: {token_tag}") + ent_type = token_tag[2:] + start_offset = offset + end_offset = None + has_entities = True + + # Catches an entity that goes up until the last token + if ent_type is not None and start_offset is not None and end_offset is None: + if isinstance(start_offset, int): + current_doc.append(Entity(label=ent_type, start=start_offset, end=len(doc.split("\n")) - 1)) + has_entities = True + + result.append(current_doc if has_entities else []) + + return result + + +class ListLoader(DataLoader): + 
"""Loader for list format data.""" + + def load(self, data: List[List[str]]) -> List[List[Entity]]: # pylint: disable=too-many-branches + """Load list format data into a list of entity lists.""" + if not isinstance(data, list): + raise ValueError("ListLoader expects list input") + + if not data: + return [] + + result = [] + + for doc in data: + if not isinstance(doc, list): + raise ValueError("Each document must be a list of tags") + + current_doc = [] + start_offset = None + end_offset = None + ent_type = None + + for offset, token_tag in enumerate(doc): + if not isinstance(token_tag, str): + raise ValueError(f"Invalid tag type: {type(token_tag)}") + + if token_tag == "O": + if ent_type is not None and start_offset is not None: + end_offset = offset - 1 + if isinstance(start_offset, int) and isinstance(end_offset, int): + current_doc.append(Entity(label=ent_type, start=start_offset, end=end_offset)) + start_offset = None + end_offset = None + ent_type = None + + elif ent_type is None: + if not (token_tag.startswith("B-") or token_tag.startswith("I-")): + raise ValueError(f"Invalid tag format: {token_tag}") + ent_type = token_tag[2:] # Remove B- or I- prefix + start_offset = offset + + elif ent_type != token_tag[2:] or (ent_type == token_tag[2:] and token_tag[:1] == "B"): + end_offset = offset - 1 + if isinstance(start_offset, int) and isinstance(end_offset, int): + current_doc.append(Entity(label=ent_type, start=start_offset, end=end_offset)) + + # start of a new entity + if not (token_tag.startswith("B-") or token_tag.startswith("I-")): + raise ValueError(f"Invalid tag format: {token_tag}") + ent_type = token_tag[2:] + start_offset = offset + end_offset = None + + # Catches an entity that goes up until the last token + if ent_type is not None and start_offset is not None and end_offset is None: + if isinstance(start_offset, int): + current_doc.append(Entity(label=ent_type, start=start_offset, end=len(doc) - 1)) + + result.append(current_doc) + + return result + + +class DictLoader(DataLoader): + """Loader for dictionary format data.""" + + def load(self, data: List[List[Dict[str, Any]]]) -> List[List[Entity]]: + """Load dictionary format data into a list of entity lists.""" + if not isinstance(data, list): + raise ValueError("DictLoader expects list input") + + if not data: + return [] + + result = [] + + for doc in data: + if not isinstance(doc, list): + raise ValueError("Each document must be a list of entity dictionaries") + + current_doc = [] + for entity in doc: + if not isinstance(entity, dict): + raise ValueError(f"Invalid entity type: {type(entity)}") + + required_keys = {"label", "start", "end"} + if not all(key in entity for key in required_keys): + raise ValueError(f"Entity missing required keys: {required_keys}") + + if not isinstance(entity["label"], str): + raise ValueError("Entity label must be a string") + + if not isinstance(entity["start"], int) or not isinstance(entity["end"], int): + raise ValueError("Entity start and end must be integers") + + current_doc.append(Entity(label=entity["label"], start=entity["start"], end=entity["end"])) + result.append(current_doc) + + return result diff --git a/src/nervaluate/reporting.py b/src/nervaluate/reporting.py deleted file mode 100644 index 0e1b85f..0000000 --- a/src/nervaluate/reporting.py +++ /dev/null @@ -1,198 +0,0 @@ -import warnings - - -def summary_report_ent(results_agg_entities_type: dict, scenario: str = "strict", digits: int = 2) -> str: - """ - Generate a summary report of the evaluation results for a given 
scenario. - - :param results_agg_entities_type: Dictionary containing the evaluation results. - :param scenario: The scenario to report on. Must be one of: 'strict', 'ent_type', 'partial', 'exact'. - Defaults to 'strict'. - :param digits: The number of digits to round the results to. - - :returns: - A string containing the summary report. - - :raises ValueError: - If the scenario is invalid. - """ - warnings.warn( - "summary_report_ent() is deprecated and will be removed in a future release. " - "In the future the Evaluator will contain a method `summary_report` with the same functionality.", - DeprecationWarning, - stacklevel=2 - ) - - valid_scenarios = {"strict", "ent_type", "partial", "exact"} - if scenario not in valid_scenarios: - raise ValueError(f"Invalid scenario: must be one of {valid_scenarios}") - - target_names = sorted(results_agg_entities_type.keys()) - headers = ["correct", "incorrect", "partial", "missed", "spurious", "precision", "recall", "f1-score"] - rows = [headers] - - # Aggregate results by entity type for the specified scenario - for ent_type in target_names: - if scenario not in results_agg_entities_type[ent_type]: - raise ValueError(f"Scenario '{scenario}' not found in results for entity type '{ent_type}'") - - results = results_agg_entities_type[ent_type][scenario] - rows.append( - [ - ent_type, - results["correct"], - results["incorrect"], - results["partial"], - results["missed"], - results["spurious"], - results["precision"], - results["recall"], - results["f1"], - ] - ) - - name_width = max(len(cn) for cn in target_names) - width = max(name_width, digits) - head_fmt = "{:>{width}s} " + " {:>11}" * len(headers) - report = head_fmt.format("", *headers, width=width) - report += "\n\n" - row_fmt = "{:>{width}s} " + " {:>11}" * 5 + " {:>11.{digits}f}" * 3 + "\n" - - for row in rows[1:]: - report += row_fmt.format(*row, width=width, digits=digits) - - return report - - -def summary_report_overall(results: dict, digits: int = 2) -> str: - """ - Generate a summary report of the evaluation results for the overall scenario. - - :param results: Dictionary containing the evaluation results. - :param digits: The number of digits to round the results to. - - :returns: - A string containing the summary report. - """ - warnings.warn( - "summary_report_overall() is deprecated and will be removed in a future. " - "In the future the Evaluator will contain a method `summary_report` with the same functionality.", - DeprecationWarning, - stacklevel=2 - ) - - - headers = ["correct", "incorrect", "partial", "missed", "spurious", "precision", "recall", "f1-score"] - rows = [headers] - - for k, v in results.items(): - rows.append( - [ - k, - v["correct"], - v["incorrect"], - v["partial"], - v["missed"], - v["spurious"], - v["precision"], - v["recall"], - v["f1"], - ] - ) - - target_names = sorted(results.keys()) - name_width = max(len(cn) for cn in target_names) - width = max(name_width, digits) - head_fmt = "{:>{width}s} " + " {:>11}" * len(headers) - report = head_fmt.format("", *headers, width=width) - report += "\n\n" - row_fmt = "{:>{width}s} " + " {:>11}" * 5 + " {:>11.{digits}f}" * 3 + "\n" - - for row in rows[1:]: - report += row_fmt.format(*row, width=width, digits=digits) - - return report - - -def summary_report_ents_indices(evaluation_agg_indices: dict, error_schema: str, preds: list | None = None) -> str: - """ - Generate a summary report of the evaluation results for the overall scenario. - - :param evaluation_agg_indices: Dictionary containing the evaluation results. 
- :param error_schema: The error schema to report on. - :param preds: List of predicted named entities. - - :returns: - A string containing the summary report. - """ - warnings.warn( - "summary_report_ents_indices() is deprecated and will be made part of the Evaluator class in the future. " - "In the future the Evaluator will contain a method `summary_report_indices` with the same functionality.", - DeprecationWarning, - stacklevel=2 - ) - - - if preds is None: - preds = [[]] - report = "" - for entity_type, entity_results in evaluation_agg_indices.items(): - report += f"\nEntity Type: {entity_type}\n" - error_data = entity_results[error_schema] - report += f" Error Schema: '{error_schema}'\n" - for category, indices in error_data.items(): - category_name = category.replace("_", " ").capitalize() - report += f" ({entity_type}) {category_name}:\n" - if indices: - for instance_index, entity_index in indices: - if preds is not None and preds != [[]]: - pred = preds[instance_index][entity_index] - prediction_info = f"Label={pred['label']}, Start={pred['start']}, End={pred['end']}" - report += f" - Instance {instance_index}, Entity {entity_index}: {prediction_info}\n" - else: - report += f" - Instance {instance_index}, Entity {entity_index}\n" - else: - report += " - None\n" - return report - - -def summary_report_overall_indices(evaluation_indices: dict, error_schema: str, preds: list | None = None) -> str: - """ - Generate a summary report of the evaluation results for the overall scenario. - - :param evaluation_indices: Dictionary containing the evaluation results. - :param error_schema: The error schema to report on. - :param preds: List of predicted named entities. - - :returns: - A string containing the summary report. - """ - warnings.warn( - "summary_report_ents_indices() is deprecated and will be removed in a future release. " - "In the future the Evaluator will contain a method `summary_report_indices` with the same functionality.", - DeprecationWarning, - stacklevel=2 - ) - report = "" - assert error_schema in evaluation_indices, f"Error schema '{error_schema}' not found in the results." 
- - error_data = evaluation_indices[error_schema] - report += f"Indices for error schema '{error_schema}':\n\n" - - for category, indices in error_data.items(): - category_name = category.replace("_", " ").capitalize() - report += f"{category_name}:\n" - if indices: - for instance_index, entity_index in indices: - if preds != [[]]: - # Retrieve the corresponding prediction - pred = preds[instance_index][entity_index] # type: ignore - prediction_info = f"Label={pred['label']}, Start={pred['start']}, End={pred['end']}" - report += f" - Instance {instance_index}, Entity {entity_index}: {prediction_info}\n" - else: - report += f" - Instance {instance_index}, Entity {entity_index}\n" - else: - report += " - None\n" - report += "\n" - - return report diff --git a/src/nervaluate/strategies.py b/src/nervaluate/strategies.py new file mode 100644 index 0000000..bcdc85d --- /dev/null +++ b/src/nervaluate/strategies.py @@ -0,0 +1,237 @@ +from abc import ABC, abstractmethod +from typing import List, Tuple + +from .entities import Entity, EvaluationResult, EvaluationIndices + + +class EvaluationStrategy(ABC): + """Abstract base class for evaluation strategies.""" + + @abstractmethod + def evaluate( + self, true_entities: List[Entity], pred_entities: List[Entity], tags: List[str], instance_index: int = 0 + ) -> Tuple[EvaluationResult, EvaluationIndices]: + """Evaluate the predicted entities against the true entities.""" + + +class StrictEvaluation(EvaluationStrategy): + """ + Strict evaluation strategy - entities must match exactly. + + If there's a predicted entity that perfectly matches a true entity and they have the same label + we mark it as correct. + If there's a predicted entity that doesn't perfectly match any true entity, we mark it as spurious. + If there's a true entity that doesn't perfecly match any predicted entity, we mark it as missed. + All other cases are marked as incorrect. + """ + + def evaluate( + self, true_entities: List[Entity], pred_entities: List[Entity], tags: List[str], instance_index: int = 0 + ) -> Tuple[EvaluationResult, EvaluationIndices]: + """ + Evaluate the predicted entities against the true entities using strict matching. + """ + result = EvaluationResult() + indices = EvaluationIndices() + matched_true = set() + + for pred_idx, pred in enumerate(pred_entities): + found_match = False + found_incorrect = False + + for true_idx, true in enumerate(true_entities): + if true_idx in matched_true: + continue + + # Check for perfect match (same boundaries and label) + if pred.label == true.label and pred.start == true.start and pred.end == true.end: + result.correct += 1 + indices.correct_indices.append((instance_index, pred_idx)) + matched_true.add(true_idx) + found_match = True + break + # Check for any overlap + if pred.start <= true.end and pred.end >= true.start: + result.incorrect += 1 + indices.incorrect_indices.append((instance_index, pred_idx)) + matched_true.add(true_idx) + found_incorrect = True + break + + if not found_match and not found_incorrect: + result.spurious += 1 + indices.spurious_indices.append((instance_index, pred_idx)) + + for true_idx, true in enumerate(true_entities): + if true_idx not in matched_true: + result.missed += 1 + indices.missed_indices.append((instance_index, true_idx)) + + result.compute_metrics() + return result, indices + + +class PartialEvaluation(EvaluationStrategy): + """ + Partial evaluation strategy - allows for partial matches. + + If there's a predicted entity that perfectly matches a true entity, we mark it as correct. 
+    If there's a predicted entity that has some minimum overlap with a true entity, we mark it as partial.
+    If there's a predicted entity that doesn't match any true entity, we mark it as spurious.
+    If there's a true entity that doesn't match any predicted entity, we mark it as missed.
+
+    There's never entity type/label checking in this strategy, and there's never an entity marked as incorrect.
+    """
+
+    def evaluate(
+        self, true_entities: List[Entity], pred_entities: List[Entity], tags: List[str], instance_index: int = 0
+    ) -> Tuple[EvaluationResult, EvaluationIndices]:
+        result = EvaluationResult()
+        indices = EvaluationIndices()
+        matched_true = set()
+
+        for pred_idx, pred in enumerate(pred_entities):
+            found_match = False
+
+            for true_idx, true in enumerate(true_entities):
+                if true_idx in matched_true:
+                    continue
+
+                # Check for overlap
+                if pred.start <= true.end and pred.end >= true.start:
+                    if pred.start == true.start and pred.end == true.end:
+                        result.correct += 1
+                        indices.correct_indices.append((instance_index, pred_idx))
+                    else:
+                        result.partial += 1
+                        indices.partial_indices.append((instance_index, pred_idx))
+                    matched_true.add(true_idx)
+                    found_match = True
+                    break
+
+            if not found_match:
+                result.spurious += 1
+                indices.spurious_indices.append((instance_index, pred_idx))
+
+        for true_idx, true in enumerate(true_entities):
+            if true_idx not in matched_true:
+                result.missed += 1
+                indices.missed_indices.append((instance_index, true_idx))
+
+        result.compute_metrics(partial_or_type=True)
+        return result, indices
+
+
+class EntityTypeEvaluation(EvaluationStrategy):
+    """
+    Entity type evaluation strategy - only checks entity types.
+
+    In this strategy, we check for overlap between the predicted entity and the true entity.
+
+    If there's a predicted entity that perfectly matches or has some minimum overlap with a
+    true entity, and it has the same label, we mark it as correct.
+    If there's a predicted entity that has some minimum overlap with or perfectly matches a
+    true entity but has the wrong label, we mark it as incorrect.
+    If there's a predicted entity that doesn't match any true entity, we mark it as spurious.
+    If there's a true entity that doesn't match any predicted entity, we mark it as missed.
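+
+    For example, with a gold ``Entity(label="LOC", start=2, end=4)`` (token offsets are
+    illustrative), a prediction ``Entity(label="LOC", start=3, end=4)`` counts as correct,
+    while ``Entity(label="ORG", start=3, end=4)`` counts as incorrect.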
+ + # ToDo: define a minimum overlap threshold - see: https://github.com/MantisAI/nervaluate/pull/83 + """ + + def evaluate( + self, true_entities: List[Entity], pred_entities: List[Entity], tags: List[str], instance_index: int = 0 + ) -> Tuple[EvaluationResult, EvaluationIndices]: + result = EvaluationResult() + indices = EvaluationIndices() + matched_true = set() + + for pred_idx, pred in enumerate(pred_entities): + found_match = False + found_incorrect = False + + for true_idx, true in enumerate(true_entities): + if true_idx in matched_true: + continue + + # Check for any overlap (perfect or minimum) + if pred.start <= true.end and pred.end >= true.start: + if pred.label == true.label: + result.correct += 1 + indices.correct_indices.append((instance_index, pred_idx)) + matched_true.add(true_idx) + found_match = True + else: + result.incorrect += 1 + indices.incorrect_indices.append((instance_index, pred_idx)) + matched_true.add(true_idx) + found_incorrect = True + break + + if not found_match and not found_incorrect: + result.spurious += 1 + indices.spurious_indices.append((instance_index, pred_idx)) + + for true_idx, true in enumerate(true_entities): + if true_idx not in matched_true: + result.missed += 1 + indices.missed_indices.append((instance_index, true_idx)) + + result.compute_metrics(partial_or_type=True) + return result, indices + + +class ExactEvaluation(EvaluationStrategy): + """ + Exact evaluation strategy - exact boundary match over the surface string, regardless of the type. + + If there's a predicted entity that perfectly matches a true entity, regardless of the label, we mark it as correct. + If there's a predicted entity that has only some minimum overlap with a true entity, we mark it as incorrect. + If there's a predicted entity that doesn't match any true entity, we mark it as spurious. + If there's a true entity that doesn't match any predicted entity, we mark it as missed. + """ + + def evaluate( + self, true_entities: List[Entity], pred_entities: List[Entity], tags: List[str], instance_index: int = 0 + ) -> Tuple[EvaluationResult, EvaluationIndices]: + """ + Evaluate the predicted entities against the true entities using exact boundary matching. + Entity type is not considered in the matching. 
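+
+        For example, a gold ``Entity(label="LOC", start=2, end=4)`` and a predicted
+        ``Entity(label="ORG", start=2, end=4)`` count as correct here (offsets are
+        illustrative), whereas a prediction spanning only tokens 2-3 counts as incorrect.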
+ """ + result = EvaluationResult() + indices = EvaluationIndices() + matched_true = set() + + for pred_idx, pred in enumerate(pred_entities): + found_match = False + found_incorrect = False + + for true_idx, true in enumerate(true_entities): + if true_idx in matched_true: + continue + + # Check for exact boundary match (regardless of label) + if pred.start == true.start and pred.end == true.end: + result.correct += 1 + indices.correct_indices.append((instance_index, pred_idx)) + matched_true.add(true_idx) + found_match = True + break + # Check for any overlap + if pred.start <= true.end and pred.end >= true.start: + result.incorrect += 1 + indices.incorrect_indices.append((instance_index, pred_idx)) + matched_true.add(true_idx) + found_incorrect = True + break + + if not found_match and not found_incorrect: + result.spurious += 1 + indices.spurious_indices.append((instance_index, pred_idx)) + + for true_idx, true in enumerate(true_entities): + if true_idx not in matched_true: + result.missed += 1 + indices.missed_indices.append((instance_index, true_idx)) + + result.compute_metrics() + return result, indices diff --git a/tests/test_entities.py b/tests/test_entities.py new file mode 100644 index 0000000..31b962a --- /dev/null +++ b/tests/test_entities.py @@ -0,0 +1,46 @@ +from nervaluate.entities import Entity, EvaluationResult + + +def test_entity_equality(): + """Test Entity equality comparison.""" + entity1 = Entity(label="PER", start=0, end=1) + entity2 = Entity(label="PER", start=0, end=1) + entity3 = Entity(label="ORG", start=0, end=1) + + assert entity1 == entity2 + assert entity1 != entity3 + assert entity1 != "not an entity" + + +def test_entity_hash(): + """Test Entity hashing.""" + entity1 = Entity(label="PER", start=0, end=1) + entity2 = Entity(label="PER", start=0, end=1) + entity3 = Entity(label="ORG", start=0, end=1) + + assert hash(entity1) == hash(entity2) + assert hash(entity1) != hash(entity3) + + +def test_evaluation_result_compute_metrics(): + """Test computation of evaluation metrics.""" + result = EvaluationResult(correct=5, incorrect=2, partial=1, missed=1, spurious=1) + + # Test strict metrics + result.compute_metrics(partial_or_type=False) + assert result.precision == 5 / 9 # 5/(5+2+1+1) + assert result.recall == 5 / (5 + 2 + 1 + 1) + + # Test partial metrics + result.compute_metrics(partial_or_type=True) + assert result.precision == 5.5 / 9 # (5+0.5*1)/(5+2+1+1) + assert result.recall == (5 + 0.5 * 1) / (5 + 2 + 1 + 1) + + +def test_evaluation_result_zero_cases(): + """Test evaluation metrics with zero values.""" + result = EvaluationResult() + result.compute_metrics() + assert result.precision == 0 + assert result.recall == 0 + assert result.f1 == 0 diff --git a/tests/test_evaluator.py b/tests/test_evaluator.py index 6c12d4d..f4fc0ef 100644 --- a/tests/test_evaluator.py +++ b/tests/test_evaluator.py @@ -1,1172 +1,80 @@ -# pylint: disable=too-many-lines -import pandas as pd +import pytest +from nervaluate.evaluator import Evaluator -from nervaluate import Evaluator - -def test_results_to_dataframe(): - """ - Test the results_to_dataframe method. 
- """ - # Setup - evaluator = Evaluator( - true=[["B-LOC", "I-LOC", "O"], ["B-PER", "O", "O"]], - pred=[["B-LOC", "I-LOC", "O"], ["B-PER", "I-PER", "O"]], - tags=["LOC", "PER"], - ) - - # Mock results data for the purpose of this test - evaluator.results = { - "strict": { - "correct": 10, - "incorrect": 5, - "partial": 3, - "missed": 2, - "spurious": 4, - "precision": 0.625, - "recall": 0.6667, - "f1": 0.6452, - "entities": { - "LOC": {"correct": 4, "incorrect": 1, "partial": 0, "missed": 1, "spurious": 2}, - "PER": {"correct": 3, "incorrect": 2, "partial": 1, "missed": 0, "spurious": 1}, - "ORG": {"correct": 3, "incorrect": 2, "partial": 2, "missed": 1, "spurious": 1}, - }, - }, - "ent_type": { - "correct": 8, - "incorrect": 4, - "partial": 1, - "missed": 3, - "spurious": 3, - "precision": 0.5714, - "recall": 0.6154, - "f1": 0.5926, - "entities": { - "LOC": {"correct": 3, "incorrect": 2, "partial": 1, "missed": 1, "spurious": 1}, - "PER": {"correct": 2, "incorrect": 1, "partial": 0, "missed": 2, "spurious": 0}, - "ORG": {"correct": 3, "incorrect": 1, "partial": 0, "missed": 0, "spurious": 2}, - }, - }, - "partial": { - "correct": 7, - "incorrect": 3, - "partial": 4, - "missed": 1, - "spurious": 5, - "precision": 0.5385, - "recall": 0.6364, - "f1": 0.5833, - "entities": { - "LOC": {"correct": 2, "incorrect": 1, "partial": 1, "missed": 1, "spurious": 2}, - "PER": {"correct": 3, "incorrect": 1, "partial": 1, "missed": 0, "spurious": 1}, - "ORG": {"correct": 2, "incorrect": 1, "partial": 2, "missed": 0, "spurious": 2}, - }, - }, - "exact": { - "correct": 9, - "incorrect": 6, - "partial": 2, - "missed": 2, - "spurious": 2, - "precision": 0.6, - "recall": 0.6429, - "f1": 0.6207, - "entities": { - "LOC": {"correct": 4, "incorrect": 1, "partial": 0, "missed": 1, "spurious": 1}, - "PER": {"correct": 3, "incorrect": 3, "partial": 0, "missed": 0, "spurious": 0}, - "ORG": {"correct": 2, "incorrect": 2, "partial": 2, "missed": 1, "spurious": 1}, - }, - }, - } - - # Expected DataFrame - expected_data = { - "correct": {"strict": 10, "ent_type": 8, "partial": 7, "exact": 9}, - "incorrect": {"strict": 5, "ent_type": 4, "partial": 3, "exact": 6}, - "partial": {"strict": 3, "ent_type": 1, "partial": 4, "exact": 2}, - "missed": {"strict": 2, "ent_type": 3, "partial": 1, "exact": 2}, - "spurious": {"strict": 4, "ent_type": 3, "partial": 5, "exact": 2}, - "precision": {"strict": 0.625, "ent_type": 0.5714, "partial": 0.5385, "exact": 0.6}, - "recall": {"strict": 0.6667, "ent_type": 0.6154, "partial": 0.6364, "exact": 0.6429}, - "f1": {"strict": 0.6452, "ent_type": 0.5926, "partial": 0.5833, "exact": 0.6207}, - "entities.LOC.correct": {"strict": 4, "ent_type": 3, "partial": 2, "exact": 4}, - "entities.LOC.incorrect": {"strict": 1, "ent_type": 2, "partial": 1, "exact": 1}, - "entities.LOC.partial": {"strict": 0, "ent_type": 1, "partial": 1, "exact": 0}, - "entities.LOC.missed": {"strict": 1, "ent_type": 1, "partial": 1, "exact": 1}, - "entities.LOC.spurious": {"strict": 2, "ent_type": 1, "partial": 2, "exact": 1}, - "entities.PER.correct": {"strict": 3, "ent_type": 2, "partial": 3, "exact": 3}, - "entities.PER.incorrect": {"strict": 2, "ent_type": 1, "partial": 1, "exact": 3}, - "entities.PER.partial": {"strict": 1, "ent_type": 0, "partial": 1, "exact": 0}, - "entities.PER.missed": {"strict": 0, "ent_type": 2, "partial": 0, "exact": 0}, - "entities.PER.spurious": {"strict": 1, "ent_type": 0, "partial": 1, "exact": 0}, - "entities.ORG.correct": {"strict": 3, "ent_type": 3, "partial": 2, "exact": 2}, - 
"entities.ORG.incorrect": {"strict": 2, "ent_type": 1, "partial": 1, "exact": 2}, - "entities.ORG.partial": {"strict": 2, "ent_type": 0, "partial": 2, "exact": 2}, - "entities.ORG.missed": {"strict": 1, "ent_type": 0, "partial": 0, "exact": 1}, - "entities.ORG.spurious": {"strict": 1, "ent_type": 2, "partial": 2, "exact": 1}, - } - - expected_df = pd.DataFrame(expected_data) - - # Execute - result_df = evaluator.results_to_dataframe() - - # Assert - pd.testing.assert_frame_equal(result_df, expected_df) - - -def test_evaluator_simple_case(): +@pytest.fixture +def sample_data(): true = [ - [{"label": "PER", "start": 2, "end": 4}], - [ - {"label": "LOC", "start": 1, "end": 2}, - {"label": "LOC", "start": 3, "end": 4}, - ], + ["O", "B-PER", "O", "B-ORG", "I-ORG", "B-LOC"], + ["O", "B-PER", "O", "B-ORG"], ] - pred = [ - [{"label": "PER", "start": 2, "end": 4}], - [ - {"label": "LOC", "start": 1, "end": 2}, - {"label": "LOC", "start": 3, "end": 4}, - ], - ] - evaluator = Evaluator(true, pred, tags=["LOC", "PER"]) - results, _, _, _ = evaluator.evaluate() - expected = { - "strict": { - "correct": 3, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "possible": 3, - "actual": 3, - "precision": 1.0, - "recall": 1.0, - "f1": 1.0, - }, - "ent_type": { - "correct": 3, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "possible": 3, - "actual": 3, - "precision": 1.0, - "recall": 1.0, - "f1": 1.0, - }, - "partial": { - "correct": 3, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "possible": 3, - "actual": 3, - "precision": 1.0, - "recall": 1.0, - "f1": 1.0, - }, - "exact": { - "correct": 3, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "possible": 3, - "actual": 3, - "precision": 1.0, - "recall": 1.0, - "f1": 1.0, - }, - } - - assert results["strict"] == expected["strict"] - assert results["ent_type"] == expected["ent_type"] - assert results["partial"] == expected["partial"] - assert results["exact"] == expected["exact"] - -def test_evaluator_simple_case_filtered_tags(): - """Check that tags can be excluded by passing the tags argument""" - true = [ - [{"label": "PER", "start": 2, "end": 4}], - [ - {"label": "LOC", "start": 1, "end": 2}, - {"label": "LOC", "start": 3, "end": 4}, - ], - ] pred = [ - [{"label": "PER", "start": 2, "end": 4}], - [ - {"label": "LOC", "start": 1, "end": 2}, - {"label": "LOC", "start": 3, "end": 4}, - ], + ["O", "B-PER", "O", "B-ORG", "O", "B-PER"], + ["O", "B-PER", "O", "B-LOC"], ] - evaluator = Evaluator(true, pred, tags=["PER", "LOC"]) - results, _, _, _ = evaluator.evaluate() - expected = { - "strict": { - "correct": 3, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "possible": 3, - "actual": 3, - "precision": 1.0, - "recall": 1.0, - "f1": 1.0, - }, - "ent_type": { - "correct": 3, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "possible": 3, - "actual": 3, - "precision": 1.0, - "recall": 1.0, - "f1": 1.0, - }, - "partial": { - "correct": 3, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "possible": 3, - "actual": 3, - "precision": 1.0, - "recall": 1.0, - "f1": 1.0, - }, - "exact": { - "correct": 3, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "possible": 3, - "actual": 3, - "precision": 1.0, - "recall": 1.0, - "f1": 1.0, - }, - } - - assert results["strict"] == expected["strict"] - assert results["ent_type"] == expected["ent_type"] - assert results["partial"] == expected["partial"] - assert results["exact"] == 
expected["exact"] + return true, pred -def test_evaluator_extra_classes(): - """Case when model predicts a class that is not in the gold (true) data""" - true = [ - [{"label": "ORG", "start": 1, "end": 3}], - ] - pred = [ - [{"label": "FOO", "start": 1, "end": 3}], - ] - evaluator = Evaluator(true, pred, tags=["ORG", "FOO"]) - results, _, _, _ = evaluator.evaluate() - expected = { - "strict": { - "correct": 0, - "incorrect": 1, - "partial": 0, - "missed": 0, - "spurious": 0, - "possible": 1, - "actual": 1, - "precision": 0, - "recall": 0.0, - "f1": 0, - }, - "ent_type": { - "correct": 0, - "incorrect": 1, - "partial": 0, - "missed": 0, - "spurious": 0, - "possible": 1, - "actual": 1, - "precision": 0, - "recall": 0.0, - "f1": 0, - }, - "partial": { - "correct": 1, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "possible": 1, - "actual": 1, - "precision": 1.0, - "recall": 1.0, - "f1": 1.0, - }, - "exact": { - "correct": 1, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "possible": 1, - "actual": 1, - "precision": 1.0, - "recall": 1.0, - "f1": 1.0, - }, - } - assert results["strict"] == expected["strict"] - assert results["ent_type"] == expected["ent_type"] - assert results["partial"] == expected["partial"] - assert results["exact"] == expected["exact"] +def test_evaluator_initialization(sample_data): + """Test evaluator initialization.""" + true, pred = sample_data + evaluator = Evaluator(true, pred, ["PER", "ORG", "LOC"], loader="list") + assert len(evaluator.true) == 2 + assert len(evaluator.pred) == 2 + assert evaluator.tags == ["PER", "ORG", "LOC"] -def test_evaluator_no_entities_in_prediction(): - """Case when model predicts a class that is not in the gold (true) data""" - true = [ - [{"label": "PER", "start": 2, "end": 4}], - ] - pred = [ - [], - ] - evaluator = Evaluator(true, pred, tags=["PER"]) - results, _, _, _ = evaluator.evaluate() - expected = { - "strict": { - "correct": 0, - "incorrect": 0, - "partial": 0, - "missed": 1, - "spurious": 0, - "possible": 1, - "actual": 0, - "precision": 0, - "recall": 0, - "f1": 0, - }, - "ent_type": { - "correct": 0, - "incorrect": 0, - "partial": 0, - "missed": 1, - "spurious": 0, - "possible": 1, - "actual": 0, - "precision": 0, - "recall": 0, - "f1": 0, - }, - "partial": { - "correct": 0, - "incorrect": 0, - "partial": 0, - "missed": 1, - "spurious": 0, - "possible": 1, - "actual": 0, - "precision": 0, - "recall": 0, - "f1": 0, - }, - "exact": { - "correct": 0, - "incorrect": 0, - "partial": 0, - "missed": 1, - "spurious": 0, - "possible": 1, - "actual": 0, - "precision": 0, - "recall": 0, - "f1": 0, - }, - } - assert results["strict"] == expected["strict"] - assert results["ent_type"] == expected["ent_type"] - assert results["partial"] == expected["partial"] - assert results["exact"] == expected["exact"] +def test_evaluator_evaluation(sample_data): + """Test evaluation process.""" + true, pred = sample_data + evaluator = Evaluator(true, pred, ["PER", "ORG", "LOC"], loader="list") + results = evaluator.evaluate() + # Check that we have results for all strategies + assert "overall" in results + assert "entities" in results + assert "strict" in results["overall"] + assert "partial" in results["overall"] + assert "ent_type" in results["overall"] -def test_evaluator_compare_results_and_results_agg(): - """Check that the label level results match the total results.""" - true = [ - [{"label": "PER", "start": 2, "end": 4}], - ] - pred = [ - [{"label": "PER", "start": 2, "end": 4}], - ] - evaluator = 
Evaluator(true, pred, tags=["PER"]) - results, results_agg, _, _ = evaluator.evaluate() - expected = { - "strict": { - "correct": 1, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "possible": 1, - "actual": 1, - "precision": 1, - "recall": 1, - "f1": 1, - }, - "ent_type": { - "correct": 1, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "possible": 1, - "actual": 1, - "precision": 1, - "recall": 1, - "f1": 1, - }, - "partial": { - "correct": 1, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "possible": 1, - "actual": 1, - "precision": 1, - "recall": 1, - "f1": 1, - }, - "exact": { - "correct": 1, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "possible": 1, - "actual": 1, - "precision": 1, - "recall": 1, - "f1": 1, - }, - } - expected_agg = { - "PER": { - "strict": { - "correct": 1, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "possible": 1, - "actual": 1, - "precision": 1, - "recall": 1, - "f1": 1, - }, - "ent_type": { - "correct": 1, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "possible": 1, - "actual": 1, - "precision": 1, - "recall": 1, - "f1": 1, - }, - "partial": { - "correct": 1, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "possible": 1, - "actual": 1, - "precision": 1, - "recall": 1, - "f1": 1, - }, - "exact": { - "correct": 1, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "possible": 1, - "actual": 1, - "precision": 1, - "recall": 1, - "f1": 1, - }, - } - } + # Check that we have results for each entity type + for entity in ["PER", "ORG", "LOC"]: + assert entity in results["entities"] + assert "strict" in results["entities"][entity] + assert "partial" in results["entities"][entity] + assert "ent_type" in results["entities"][entity] - assert results_agg["PER"]["strict"] == expected_agg["PER"]["strict"] - assert results_agg["PER"]["ent_type"] == expected_agg["PER"]["ent_type"] - assert results_agg["PER"]["partial"] == expected_agg["PER"]["partial"] - assert results_agg["PER"]["exact"] == expected_agg["PER"]["exact"] - assert results["strict"] == expected["strict"] - assert results["ent_type"] == expected["ent_type"] - assert results["partial"] == expected["partial"] - assert results["exact"] == expected["exact"] +def test_evaluator_with_invalid_tags(sample_data): + """Test evaluator with invalid tags.""" + true, pred = sample_data + evaluator = Evaluator(true, pred, ["INVALID"], loader="list") + results = evaluator.evaluate() - assert results["strict"] == expected_agg["PER"]["strict"] - assert results["ent_type"] == expected_agg["PER"]["ent_type"] - assert results["partial"] == expected_agg["PER"]["partial"] - assert results["exact"] == expected_agg["PER"]["exact"] + for strategy in ["strict", "partial", "ent_type"]: + assert results["overall"][strategy].correct == 0 + assert results["overall"][strategy].incorrect == 0 + assert results["overall"][strategy].partial == 0 + assert results["overall"][strategy].missed == 0 + assert results["overall"][strategy].spurious == 0 -def test_evaluator_compare_results_and_results_agg_1(): - """Test case when model predicts a label not in the test data.""" +def test_evaluator_different_document_lengths(): + """Test that Evaluator raises ValueError when documents have different lengths.""" true = [ - [], - [{"label": "ORG", "start": 2, "end": 4}], - [{"label": "MISC", "start": 2, "end": 4}], + ["O", "B-PER", "I-PER", "O", "O", "O", "B-ORG", "I-ORG"], # 8 tokens + ["O", 
"B-LOC", "B-PER", "I-PER", "O", "O", "B-DATE"], # 7 tokens ] pred = [ - [{"label": "PER", "start": 2, "end": 4}], - [{"label": "ORG", "start": 2, "end": 4}], - [{"label": "MISC", "start": 2, "end": 4}], - ] - evaluator = Evaluator(true, pred, tags=["PER", "ORG", "MISC"]) - results, results_agg, _, _ = evaluator.evaluate() - expected = { - "strict": { - "correct": 2, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 1, - "possible": 2, - "actual": 3, - "precision": 0.6666666666666666, - "recall": 1.0, - "f1": 0.8, - }, - "ent_type": { - "correct": 2, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 1, - "possible": 2, - "actual": 3, - "precision": 0.6666666666666666, - "recall": 1.0, - "f1": 0.8, - }, - "partial": { - "correct": 2, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 1, - "possible": 2, - "actual": 3, - "precision": 0.6666666666666666, - "recall": 1.0, - "f1": 0.8, - }, - "exact": { - "correct": 2, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 1, - "possible": 2, - "actual": 3, - "precision": 0.6666666666666666, - "recall": 1.0, - "f1": 0.8, - }, - } - expected_agg = { - "ORG": { - "strict": { - "correct": 1, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "possible": 1, - "actual": 1, - "precision": 1.0, - "recall": 1, - "f1": 1.0, - }, - "ent_type": { - "correct": 1, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "possible": 1, - "actual": 1, - "precision": 1.0, - "recall": 1, - "f1": 1.0, - }, - "partial": { - "correct": 1, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "possible": 1, - "actual": 1, - "precision": 1, - "recall": 1, - "f1": 1.0, - }, - "exact": { - "correct": 1, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "possible": 1, - "actual": 1, - "precision": 1, - "recall": 1, - "f1": 1, - }, - }, - "MISC": { - "strict": { - "correct": 1, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "possible": 1, - "actual": 1, - "precision": 1, - "recall": 1, - "f1": 1, - }, - "ent_type": { - "correct": 1, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "possible": 1, - "actual": 1, - "precision": 1, - "recall": 1, - "f1": 1, - }, - "partial": { - "correct": 1, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "possible": 1, - "actual": 1, - "precision": 1, - "recall": 1, - "f1": 1, - }, - "exact": { - "correct": 1, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "possible": 1, - "actual": 1, - "precision": 1, - "recall": 1, - "f1": 1, - }, - }, - } - - assert results_agg["ORG"]["strict"] == expected_agg["ORG"]["strict"] - assert results_agg["ORG"]["ent_type"] == expected_agg["ORG"]["ent_type"] - assert results_agg["ORG"]["partial"] == expected_agg["ORG"]["partial"] - assert results_agg["ORG"]["exact"] == expected_agg["ORG"]["exact"] - - assert results_agg["MISC"]["strict"] == expected_agg["MISC"]["strict"] - assert results_agg["MISC"]["ent_type"] == expected_agg["MISC"]["ent_type"] - assert results_agg["MISC"]["partial"] == expected_agg["MISC"]["partial"] - assert results_agg["MISC"]["exact"] == expected_agg["MISC"]["exact"] - - assert results["strict"] == expected["strict"] - assert results["ent_type"] == expected["ent_type"] - assert results["partial"] == expected["partial"] - assert results["exact"] == expected["exact"] - - -def test_evaluator_with_extra_keys_in_pred(): - true = [ - [{"label": "PER", "start": 2, "end": 4}], - [ - {"label": "LOC", "start": 1, 
"end": 2}, - {"label": "LOC", "start": 3, "end": 4}, - ], + ["O", "B-PER", "I-PER", "O", "O", "O", "B-ORG", "I-ORG"], # 8 tokens + ["O", "B-LOC", "I-LOC", "O", "B-PER", "I-PER", "O", "B-DATE", "I-DATE", "O"], # 10 tokens ] - pred = [ - [{"label": "PER", "start": 2, "end": 4, "token_start": 0, "token_end": 5}], - [ - {"label": "LOC", "start": 1, "end": 2, "token_start": 0, "token_end": 6}, - {"label": "LOC", "start": 3, "end": 4, "token_start": 0, "token_end": 3}, - ], - ] - evaluator = Evaluator(true, pred, tags=["LOC", "PER"]) - results, _, _, _ = evaluator.evaluate() - expected = { - "strict": { - "correct": 3, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "possible": 3, - "actual": 3, - "precision": 1.0, - "recall": 1.0, - "f1": 1.0, - }, - "ent_type": { - "correct": 3, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "possible": 3, - "actual": 3, - "precision": 1.0, - "recall": 1.0, - "f1": 1.0, - }, - "partial": { - "correct": 3, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "possible": 3, - "actual": 3, - "precision": 1.0, - "recall": 1.0, - "f1": 1.0, - }, - "exact": { - "correct": 3, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "possible": 3, - "actual": 3, - "precision": 1.0, - "recall": 1.0, - "f1": 1.0, - }, - } - - assert results["strict"] == expected["strict"] - assert results["ent_type"] == expected["ent_type"] - assert results["partial"] == expected["partial"] - assert results["exact"] == expected["exact"] - - -def test_evaluator_with_extra_keys_in_true(): - true = [ - [{"label": "PER", "start": 2, "end": 4, "token_start": 0, "token_end": 4}], - [ - {"label": "LOC", "start": 1, "end": 2, "token_start": 0, "token_end": 5}, - {"label": "LOC", "start": 3, "end": 4, "token_start": 7, "token_end": 9}, - ], - ] - pred = [ - [{"label": "PER", "start": 2, "end": 4}], - [ - {"label": "LOC", "start": 1, "end": 2}, - {"label": "LOC", "start": 3, "end": 4}, - ], - ] - evaluator = Evaluator(true, pred, tags=["LOC", "PER"]) - results, _, _, _ = evaluator.evaluate() - expected = { - "strict": { - "correct": 3, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "possible": 3, - "actual": 3, - "precision": 1.0, - "recall": 1.0, - "f1": 1.0, - }, - "ent_type": { - "correct": 3, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "possible": 3, - "actual": 3, - "precision": 1.0, - "recall": 1.0, - "f1": 1.0, - }, - "partial": { - "correct": 3, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "possible": 3, - "actual": 3, - "precision": 1.0, - "recall": 1.0, - "f1": 1.0, - }, - "exact": { - "correct": 3, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "possible": 3, - "actual": 3, - "precision": 1.0, - "recall": 1.0, - "f1": 1.0, - }, - } - - assert results["strict"] == expected["strict"] - assert results["ent_type"] == expected["ent_type"] - assert results["partial"] == expected["partial"] - assert results["exact"] == expected["exact"] - - -def test_issue_29(): - true = [ - [ - {"label": "PER", "start": 1, "end": 2}, - {"label": "PER", "start": 3, "end": 10}, - ] - ] - pred = [ - [ - {"label": "PER", "start": 1, "end": 2}, - {"label": "PER", "start": 3, "end": 5}, - {"label": "PER", "start": 6, "end": 10}, - ] - ] - evaluator = Evaluator(true, pred, tags=["PER"]) - results, _, _, _ = evaluator.evaluate() - expected = { - "strict": { - "correct": 1, - "incorrect": 1, - "partial": 0, - "missed": 0, - "spurious": 1, - "possible": 2, - 
"actual": 3, - "precision": 0.3333333333333333, - "recall": 0.5, - "f1": 0.4, - }, - "ent_type": { - "correct": 2, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 1, - "possible": 2, - "actual": 3, - "precision": 0.6666666666666666, - "recall": 1.0, - "f1": 0.8, - }, - "partial": { - "correct": 1, - "incorrect": 0, - "partial": 1, - "missed": 0, - "spurious": 1, - "possible": 2, - "actual": 3, - "precision": 0.5, - "recall": 0.75, - "f1": 0.6, - }, - "exact": { - "correct": 1, - "incorrect": 1, - "partial": 0, - "missed": 0, - "spurious": 1, - "possible": 2, - "actual": 3, - "precision": 0.3333333333333333, - "recall": 0.5, - "f1": 0.4, - }, - } - - assert results["strict"] == expected["strict"] - assert results["ent_type"] == expected["ent_type"] - assert results["partial"] == expected["partial"] - assert results["exact"] == expected["exact"] - - -def test_evaluator_compare_results_indices_and_results_agg_indices(): - """Check that the label level results match the total results.""" - true = [ - [{"label": "PER", "start": 2, "end": 4}], - ] - pred = [ - [{"label": "PER", "start": 2, "end": 4}], - ] - evaluator = Evaluator(true, pred, tags=["PER"]) - _, _, evaluation_indices, evaluation_agg_indices = evaluator.evaluate() - expected_evaluation_indices = { - "strict": { - "correct_indices": [(0, 0)], - "incorrect_indices": [], - "partial_indices": [], - "missed_indices": [], - "spurious_indices": [], - }, - "ent_type": { - "correct_indices": [(0, 0)], - "incorrect_indices": [], - "partial_indices": [], - "missed_indices": [], - "spurious_indices": [], - }, - "partial": { - "correct_indices": [(0, 0)], - "incorrect_indices": [], - "partial_indices": [], - "missed_indices": [], - "spurious_indices": [], - }, - "exact": { - "correct_indices": [(0, 0)], - "incorrect_indices": [], - "partial_indices": [], - "missed_indices": [], - "spurious_indices": [], - }, - } - expected_evaluation_agg_indices = { - "PER": { - "strict": { - "correct_indices": [(0, 0)], - "incorrect_indices": [], - "partial_indices": [], - "missed_indices": [], - "spurious_indices": [], - }, - "ent_type": { - "correct_indices": [(0, 0)], - "incorrect_indices": [], - "partial_indices": [], - "missed_indices": [], - "spurious_indices": [], - }, - "partial": { - "correct_indices": [(0, 0)], - "incorrect_indices": [], - "partial_indices": [], - "missed_indices": [], - "spurious_indices": [], - }, - "exact": { - "correct_indices": [(0, 0)], - "incorrect_indices": [], - "partial_indices": [], - "missed_indices": [], - "spurious_indices": [], - }, - } - } - assert evaluation_agg_indices["PER"]["strict"] == expected_evaluation_agg_indices["PER"]["strict"] - assert evaluation_agg_indices["PER"]["ent_type"] == expected_evaluation_agg_indices["PER"]["ent_type"] - assert evaluation_agg_indices["PER"]["partial"] == expected_evaluation_agg_indices["PER"]["partial"] - assert evaluation_agg_indices["PER"]["exact"] == expected_evaluation_agg_indices["PER"]["exact"] - - assert evaluation_indices["strict"] == expected_evaluation_indices["strict"] - assert evaluation_indices["ent_type"] == expected_evaluation_indices["ent_type"] - assert evaluation_indices["partial"] == expected_evaluation_indices["partial"] - assert evaluation_indices["exact"] == expected_evaluation_indices["exact"] - - assert evaluation_indices["strict"] == expected_evaluation_agg_indices["PER"]["strict"] - assert evaluation_indices["ent_type"] == expected_evaluation_agg_indices["PER"]["ent_type"] - assert evaluation_indices["partial"] == 
expected_evaluation_agg_indices["PER"]["partial"] - assert evaluation_indices["exact"] == expected_evaluation_agg_indices["PER"]["exact"] - - -def test_evaluator_compare_results_indices_and_results_agg_indices_1(): - """Test case when model predicts a label not in the test data.""" - true = [ - [], - [{"label": "ORG", "start": 2, "end": 4}], - [{"label": "MISC", "start": 2, "end": 4}], - ] - pred = [ - [{"label": "PER", "start": 2, "end": 4}], - [{"label": "ORG", "start": 2, "end": 4}], - [{"label": "MISC", "start": 2, "end": 4}], - ] - evaluator = Evaluator(true, pred, tags=["PER", "ORG", "MISC"]) - _, _, evaluation_indices, evaluation_agg_indices = evaluator.evaluate() - - expected_evaluation_indices = { - "strict": { - "correct_indices": [(1, 0), (2, 0)], - "incorrect_indices": [], - "partial_indices": [], - "missed_indices": [], - "spurious_indices": [(0, 0)], - }, - "ent_type": { - "correct_indices": [(1, 0), (2, 0)], - "incorrect_indices": [], - "partial_indices": [], - "missed_indices": [], - "spurious_indices": [(0, 0)], - }, - "partial": { - "correct_indices": [(1, 0), (2, 0)], - "incorrect_indices": [], - "partial_indices": [], - "missed_indices": [], - "spurious_indices": [(0, 0)], - }, - "exact": { - "correct_indices": [(1, 0), (2, 0)], - "incorrect_indices": [], - "partial_indices": [], - "missed_indices": [], - "spurious_indices": [(0, 0)], - }, - } - expected_evaluation_agg_indices = { - "PER": { - "strict": { - "correct_indices": [], - "incorrect_indices": [], - "partial_indices": [], - "missed_indices": [], - "spurious_indices": [(0, 0)], - }, - "ent_type": { - "correct_indices": [], - "incorrect_indices": [], - "partial_indices": [], - "missed_indices": [], - "spurious_indices": [(0, 0)], - }, - "partial": { - "correct_indices": [], - "incorrect_indices": [], - "partial_indices": [], - "missed_indices": [], - "spurious_indices": [(0, 0)], - }, - "exact": { - "correct_indices": [], - "incorrect_indices": [], - "partial_indices": [], - "missed_indices": [], - "spurious_indices": [(0, 0)], - }, - }, - "ORG": { - "strict": { - "correct_indices": [(1, 0)], - "incorrect_indices": [], - "partial_indices": [], - "missed_indices": [], - "spurious_indices": [], - }, - "ent_type": { - "correct_indices": [(1, 0)], - "incorrect_indices": [], - "partial_indices": [], - "missed_indices": [], - "spurious_indices": [], - }, - "partial": { - "correct_indices": [(1, 0)], - "incorrect_indices": [], - "partial_indices": [], - "missed_indices": [], - "spurious_indices": [], - }, - "exact": { - "correct_indices": [(1, 0)], - "incorrect_indices": [], - "partial_indices": [], - "missed_indices": [], - "spurious_indices": [], - }, - }, - "MISC": { - "strict": { - "correct_indices": [(2, 0)], - "incorrect_indices": [], - "partial_indices": [], - "missed_indices": [], - "spurious_indices": [], - }, - "ent_type": { - "correct_indices": [(2, 0)], - "incorrect_indices": [], - "partial_indices": [], - "missed_indices": [], - "spurious_indices": [], - }, - "partial": { - "correct_indices": [(2, 0)], - "incorrect_indices": [], - "partial_indices": [], - "missed_indices": [], - "spurious_indices": [], - }, - "exact": { - "correct_indices": [(2, 0)], - "incorrect_indices": [], - "partial_indices": [], - "missed_indices": [], - "spurious_indices": [], - }, - }, - } - assert evaluation_agg_indices["ORG"]["strict"] == expected_evaluation_agg_indices["ORG"]["strict"] - assert evaluation_agg_indices["ORG"]["ent_type"] == expected_evaluation_agg_indices["ORG"]["ent_type"] - assert 
evaluation_agg_indices["ORG"]["partial"] == expected_evaluation_agg_indices["ORG"]["partial"] - assert evaluation_agg_indices["ORG"]["exact"] == expected_evaluation_agg_indices["ORG"]["exact"] - - assert evaluation_agg_indices["MISC"]["strict"] == expected_evaluation_agg_indices["MISC"]["strict"] - assert evaluation_agg_indices["MISC"]["ent_type"] == expected_evaluation_agg_indices["MISC"]["ent_type"] - assert evaluation_agg_indices["MISC"]["partial"] == expected_evaluation_agg_indices["MISC"]["partial"] - assert evaluation_agg_indices["MISC"]["exact"] == expected_evaluation_agg_indices["MISC"]["exact"] + tags = ["PER", "ORG", "LOC", "DATE"] - assert evaluation_indices["strict"] == expected_evaluation_indices["strict"] - assert evaluation_indices["ent_type"] == expected_evaluation_indices["ent_type"] - assert evaluation_indices["partial"] == expected_evaluation_indices["partial"] - assert evaluation_indices["exact"] == expected_evaluation_indices["exact"] + # Test that ValueError is raised + with pytest.raises(ValueError, match="Document 1 has different lengths: true=7, pred=10"): + evaluator = Evaluator(true=true, pred=pred, tags=tags, loader="list") + evaluator.evaluate() diff --git a/tests/test_loaders.py b/tests/test_loaders.py index 80cc921..99316ec 100644 --- a/tests/test_loaders.py +++ b/tests/test_loaders.py @@ -1,7 +1,61 @@ -from nervaluate import Evaluator +import pytest +from nervaluate.loaders import ConllLoader, ListLoader, DictLoader -def test_loaders_produce_the_same_results(): + +def test_conll_loader(): + """Test CoNLL format loader.""" + true_conll = ( + "word\tO\nword\tO\nword\tO\nword\tO\nword\tO\nword\tO\n\n" + "word\tO\nword\tO\nword\tB-ORG\nword\tI-ORG\nword\tO\nword\tO\n\n" + "word\tO\nword\tO\nword\tB-MISC\nword\tI-MISC\nword\tO\nword\tO\n\n" + "word\tB-MISC\nword\tI-MISC\nword\tI-MISC\nword\tI-MISC\nword\tI-MISC\nword\tI-MISC\n" + ) + + pred_conll = ( + "word\tO\nword\tO\nword\tB-PER\nword\tI-PER\nword\tO\nword\tO\n\n" + "word\tO\nword\tO\nword\tB-ORG\nword\tI-ORG\nword\tO\nword\tO\n\n" + "word\tO\nword\tO\nword\tB-MISC\nword\tI-MISC\nword\tO\nword\tO\n\n" + "word\tB-MISC\nword\tI-MISC\nword\tI-MISC\nword\tI-MISC\nword\tI-MISC\nword\tI-MISC\n" + ) + + loader = ConllLoader() + true_entities = loader.load(true_conll) + pred_entities = loader.load(pred_conll) + + # Test true entities + assert len(true_entities) == 4 # Four documents + assert len(true_entities[0]) == 0 # First document has no entities (all O tags) + assert len(true_entities[1]) == 1 # Second document has 1 entity (ORG) + assert len(true_entities[2]) == 1 # Third document has 1 entity (MISC) + assert len(true_entities[3]) == 1 # Fourth document has 1 entity (MISC) + + # Check first entity in second document + assert true_entities[1][0].label == "ORG" + assert true_entities[1][0].start == 2 + assert true_entities[1][0].end == 3 + + # Test pred entities + assert len(pred_entities) == 4 # Four documents + assert len(pred_entities[0]) == 1 # First document has 1 entity (PER) + assert len(pred_entities[1]) == 1 # Second document has 1 entity (ORG) + assert len(pred_entities[2]) == 1 # Third document has 1 entity (MISC) + assert len(pred_entities[3]) == 1 # Fourth document has 1 entity (MISC) + + # Check first entity in first document + assert pred_entities[0][0].label == "PER" + assert pred_entities[0][0].start == 2 + assert pred_entities[0][0].end == 3 + + # Test empty document handling + empty_doc = "word\tO\nword\tO\nword\tO\n\n" + empty_entities = loader.load(empty_doc) + assert len(empty_entities) == 1 
# One document + assert len(empty_entities[0]) == 0 # Empty list for document with only O tags + + +def test_list_loader(): + """Test list format loader.""" true_list = [ ["O", "O", "O", "O", "O", "O"], ["O", "O", "B-ORG", "I-ORG", "O", "O"], @@ -16,20 +70,51 @@ def test_loaders_produce_the_same_results(): ["B-MISC", "I-MISC", "I-MISC", "I-MISC", "I-MISC", "I-MISC"], ] - true_conll = ( - "word\tO\nword\tO\nword\tO\nword\tO\nword\tO\nword\tO\n\n" - "word\tO\nword\tO\nword\tB-ORG\nword\tI-ORG\nword\tO\nword\tO\n\n" - "word\tO\nword\tO\nword\tB-MISC\nword\tI-MISC\nword\tO\nword\tO\n\n" - "word\tB-MISC\nword\tI-MISC\nword\tI-MISC\nword\tI-MISC\nword\tI-MISC\nword\tI-MISC\n" - ) + loader = ListLoader() + true_entities = loader.load(true_list) + pred_entities = loader.load(pred_list) - pred_conll = ( - "word\tO\nword\tO\nword\tB-PER\nword\tI-PER\nword\tO\nword\tO\n\n" - "word\tO\nword\tO\nword\tB-ORG\nword\tI-ORG\nword\tO\nword\tO\n\n" - "word\tO\nword\tO\nword\tB-MISC\nword\tI-MISC\nword\tO\nword\tO\n\n" - "word\tB-MISC\nword\tI-MISC\nword\tI-MISC\nword\tI-MISC\nword\tI-MISC\nword\tI-MISC\n" - ) + # Test true entities + assert len(true_entities) == 4 # Four documents + assert len(true_entities[0]) == 0 # First document has no entities (all O tags) + assert len(true_entities[1]) == 1 # Second document has 1 entity (ORG) + assert len(true_entities[2]) == 1 # Third document has 1 entity (MISC) + assert len(true_entities[3]) == 1 # Fourth document has 1 entity (MISC) + + # Check no entities in the first document + assert len(true_entities[0]) == 0 + + # Check first entity in second document + assert true_entities[1][0].label == "ORG" + assert true_entities[1][0].start == 2 + assert true_entities[1][0].end == 3 + + # Check only entity in the last document + assert true_entities[3][0].label == "MISC" + assert true_entities[3][0].start == 0 + assert true_entities[3][0].end == 5 + + # Test pred entities + assert len(pred_entities) == 4 # Four documents + assert len(pred_entities[0]) == 1 # First document has 1 entity (PER) + assert len(pred_entities[1]) == 1 # Second document has 1 entity (ORG) + assert len(pred_entities[2]) == 1 # Third document has 1 entity (MISC) + assert len(pred_entities[3]) == 1 # Fourth document has 1 entity (MISC) + + # Check first entity in first document + assert pred_entities[0][0].label == "PER" + assert pred_entities[0][0].start == 2 + assert pred_entities[0][0].end == 3 + # Test empty document handling + empty_doc = [["O", "O", "O"]] + empty_entities = loader.load(empty_doc) + assert len(empty_entities) == 1 # One document + assert len(empty_entities[0]) == 0 # Empty list for document with only O tags + + +def test_dict_loader(): + """Test dictionary format loader.""" true_prod = [ [], [{"label": "ORG", "start": 2, "end": 3}], @@ -44,15 +129,71 @@ def test_loaders_produce_the_same_results(): [{"label": "MISC", "start": 0, "end": 5}], ] - evaluator_list = Evaluator(true_list, pred_list, tags=["PER", "ORG", "MISC"], loader="list") + loader = DictLoader() + true_entities = loader.load(true_prod) + pred_entities = loader.load(pred_prod) + + # Test true entities + assert len(true_entities) == 4 # Four documents + assert len(true_entities[0]) == 0 # First document has no entities + assert len(true_entities[1]) == 1 # Second document has 1 entity (ORG) + assert len(true_entities[2]) == 1 # Third document has 1 entity (MISC) + assert len(true_entities[3]) == 1 # Fourth document has 1 entity (MISC) + + # Check first entity in second document + assert true_entities[1][0].label == 
"ORG" + assert true_entities[1][0].start == 2 + assert true_entities[1][0].end == 3 + + # Check only entity in the last document + assert true_entities[3][0].label == "MISC" + assert true_entities[3][0].start == 0 + assert true_entities[3][0].end == 5 + + # Test pred entities + assert len(pred_entities) == 4 # Four documents + assert len(pred_entities[0]) == 1 # First document has 1 entity (PER) + assert len(pred_entities[1]) == 1 # Second document has 1 entity (ORG) + assert len(pred_entities[2]) == 1 # Third document has 1 entity (MISC) + assert len(pred_entities[3]) == 1 # Fourth document has 1 entity (MISC) + + # Check first entity in first document + assert pred_entities[0][0].label == "PER" + assert pred_entities[0][0].start == 2 + assert pred_entities[0][0].end == 3 + + # Test empty document handling + empty_doc = [[]] + empty_entities = loader.load(empty_doc) + assert len(empty_entities) == 1 # One document + assert len(empty_entities[0]) == 0 # Empty list for empty document + + +def test_loader_with_empty_input(): + """Test loaders with empty input.""" + # Test ConllLoader with empty string + conll_loader = ConllLoader() + entities = conll_loader.load("") + assert len(entities) == 0 + + # Test ListLoader with empty list + list_loader = ListLoader() + entities = list_loader.load([]) + assert len(entities) == 0 + + # Test DictLoader with empty list + dict_loader = DictLoader() + entities = dict_loader.load([]) + assert len(entities) == 0 - evaluator_conll = Evaluator(true_conll, pred_conll, tags=["PER", "ORG", "MISC"], loader="conll") - evaluator_prod = Evaluator(true_prod, pred_prod, tags=["PER", "ORG", "MISC"]) +def test_loader_with_invalid_data(): + """Test loaders with invalid data.""" + with pytest.raises(Exception): + ConllLoader().load("invalid\tdata") - _, _, _, _ = evaluator_list.evaluate() - _, _, _, _ = evaluator_prod.evaluate() - _, _, _, _ = evaluator_conll.evaluate() + with pytest.raises(Exception): + ListLoader().load([["invalid"]]) - assert evaluator_prod.pred == evaluator_list.pred == evaluator_conll.pred - assert evaluator_prod.true == evaluator_list.true == evaluator_conll.true + with pytest.raises(Exception): + DictLoader().load([[{"invalid": "data"}]]) diff --git a/tests/test_nervaluate.py b/tests/test_nervaluate.py deleted file mode 100644 index 56fca6e..0000000 --- a/tests/test_nervaluate.py +++ /dev/null @@ -1,909 +0,0 @@ -import pytest - -from nervaluate import ( - Evaluator, - compute_actual_possible, - compute_metrics, - compute_precision_recall, - compute_precision_recall_wrapper, -) - - -def test_compute_metrics_case_1(): - true_named_entities = [ - {"label": "PER", "start": 59, "end": 69}, - {"label": "LOC", "start": 127, "end": 134}, - {"label": "LOC", "start": 164, "end": 174}, - {"label": "LOC", "start": 197, "end": 205}, - {"label": "LOC", "start": 208, "end": 219}, - {"label": "MISC", "start": 230, "end": 240}, - ] - pred_named_entities = [ - {"label": "PER", "start": 24, "end": 30}, - {"label": "LOC", "start": 124, "end": 134}, - {"label": "PER", "start": 164, "end": 174}, - {"label": "LOC", "start": 197, "end": 205}, - {"label": "LOC", "start": 208, "end": 219}, - {"label": "LOC", "start": 225, "end": 243}, - ] - results, _, _, _ = compute_metrics(true_named_entities, pred_named_entities, ["PER", "LOC", "MISC"]) - results = compute_precision_recall_wrapper(results) - expected = { - "strict": { - "correct": 2, - "incorrect": 3, - "partial": 0, - "missed": 1, - "spurious": 1, - "possible": 6, - "actual": 6, - "precision": 0.3333333333333333, - 
"recall": 0.3333333333333333, - "f1": 0.3333333333333333, - }, - "ent_type": { - "correct": 3, - "incorrect": 2, - "partial": 0, - "missed": 1, - "spurious": 1, - "possible": 6, - "actual": 6, - "precision": 0.5, - "recall": 0.5, - "f1": 0.5, - }, - "partial": { - "correct": 3, - "incorrect": 0, - "partial": 2, - "missed": 1, - "spurious": 1, - "possible": 6, - "actual": 6, - "precision": 0.6666666666666666, - "recall": 0.6666666666666666, - "f1": 0.6666666666666666, - }, - "exact": { - "correct": 3, - "incorrect": 2, - "partial": 0, - "missed": 1, - "spurious": 1, - "possible": 6, - "actual": 6, - "precision": 0.5, - "recall": 0.5, - "f1": 0.5, - }, - } - assert results == expected - - -def test_compute_metrics_agg_scenario_3(): - true_named_entities = [{"label": "PER", "start": 59, "end": 69}] - pred_named_entities = [] - _, results_agg, _, _ = compute_metrics(true_named_entities, pred_named_entities, ["PER"]) - expected_agg = { - "PER": { - "strict": { - "correct": 0, - "incorrect": 0, - "partial": 0, - "missed": 1, - "spurious": 0, - "actual": 0, - "possible": 1, - "precision": 0, - "recall": 0, - "f1": 0, - }, - "ent_type": { - "correct": 0, - "incorrect": 0, - "partial": 0, - "missed": 1, - "spurious": 0, - "actual": 0, - "possible": 1, - "precision": 0, - "recall": 0, - "f1": 0, - }, - "partial": { - "correct": 0, - "incorrect": 0, - "partial": 0, - "missed": 1, - "spurious": 0, - "actual": 0, - "possible": 1, - "precision": 0, - "recall": 0, - "f1": 0, - }, - "exact": { - "correct": 0, - "incorrect": 0, - "partial": 0, - "missed": 1, - "spurious": 0, - "actual": 0, - "possible": 1, - "precision": 0, - "recall": 0, - "f1": 0, - }, - } - } - - assert results_agg["PER"]["strict"] == expected_agg["PER"]["strict"] - assert results_agg["PER"]["ent_type"] == expected_agg["PER"]["ent_type"] - assert results_agg["PER"]["partial"] == expected_agg["PER"]["partial"] - assert results_agg["PER"]["exact"] == expected_agg["PER"]["exact"] - - -def test_compute_metrics_agg_scenario_2(): - true_named_entities = [] - pred_named_entities = [{"label": "PER", "start": 59, "end": 69}] - _, results_agg, _, _ = compute_metrics(true_named_entities, pred_named_entities, ["PER"]) - expected_agg = { - "PER": { - "strict": { - "correct": 0, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 1, - "actual": 1, - "possible": 0, - "precision": 0, - "recall": 0, - "f1": 0, - }, - "ent_type": { - "correct": 0, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 1, - "actual": 1, - "possible": 0, - "precision": 0, - "recall": 0, - "f1": 0, - }, - "partial": { - "correct": 0, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 1, - "actual": 1, - "possible": 0, - "precision": 0, - "recall": 0, - "f1": 0, - }, - "exact": { - "correct": 0, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 1, - "actual": 1, - "possible": 0, - "precision": 0, - "recall": 0, - "f1": 0, - }, - } - } - - assert results_agg["PER"]["strict"] == expected_agg["PER"]["strict"] - assert results_agg["PER"]["ent_type"] == expected_agg["PER"]["ent_type"] - assert results_agg["PER"]["partial"] == expected_agg["PER"]["partial"] - assert results_agg["PER"]["exact"] == expected_agg["PER"]["exact"] - - -def test_compute_metrics_agg_scenario_5(): - true_named_entities = [{"label": "PER", "start": 59, "end": 69}] - pred_named_entities = [{"label": "PER", "start": 57, "end": 69}] - _, results_agg, _, _ = compute_metrics(true_named_entities, pred_named_entities, ["PER"]) - expected_agg = { - "PER": { - "strict": { 
- "correct": 0, - "incorrect": 1, - "partial": 0, - "missed": 0, - "spurious": 0, - "actual": 1, - "possible": 1, - "precision": 0, - "recall": 0, - "f1": 0, - }, - "ent_type": { - "correct": 1, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "actual": 1, - "possible": 1, - "precision": 0, - "recall": 0, - "f1": 0, - }, - "partial": { - "correct": 0, - "incorrect": 0, - "partial": 1, - "missed": 0, - "spurious": 0, - "actual": 1, - "possible": 1, - "precision": 0, - "recall": 0, - "f1": 0, - }, - "exact": { - "correct": 0, - "incorrect": 1, - "partial": 0, - "missed": 0, - "spurious": 0, - "actual": 1, - "possible": 1, - "precision": 0, - "recall": 0, - "f1": 0, - }, - } - } - - assert results_agg["PER"]["strict"] == expected_agg["PER"]["strict"] - assert results_agg["PER"]["ent_type"] == expected_agg["PER"]["ent_type"] - assert results_agg["PER"]["partial"] == expected_agg["PER"]["partial"] - assert results_agg["PER"]["exact"] == expected_agg["PER"]["exact"] - - -def test_compute_metrics_agg_scenario_4(): - true_named_entities = [{"label": "PER", "start": 59, "end": 69}] - pred_named_entities = [{"label": "LOC", "start": 59, "end": 69}] - _, results_agg, _, _ = compute_metrics(true_named_entities, pred_named_entities, ["PER", "LOC"]) - expected_agg = { - "PER": { - "strict": { - "correct": 0, - "incorrect": 1, - "partial": 0, - "missed": 0, - "spurious": 0, - "actual": 1, - "possible": 1, - "precision": 0, - "recall": 0, - "f1": 0, - }, - "ent_type": { - "correct": 0, - "incorrect": 1, - "partial": 0, - "missed": 0, - "spurious": 0, - "actual": 1, - "possible": 1, - "precision": 0, - "recall": 0, - "f1": 0, - }, - "partial": { - "correct": 1, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "actual": 1, - "possible": 1, - "precision": 0, - "recall": 0, - "f1": 0, - }, - "exact": { - "correct": 1, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "actual": 1, - "possible": 1, - "precision": 0, - "recall": 0, - "f1": 0, - }, - }, - "LOC": { - "strict": { - "correct": 0, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "actual": 0, - "possible": 0, - "precision": 0, - "recall": 0, - "f1": 0, - }, - "ent_type": { - "correct": 0, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "actual": 0, - "possible": 0, - "precision": 0, - "recall": 0, - "f1": 0, - }, - "partial": { - "correct": 0, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "actual": 0, - "possible": 0, - "precision": 0, - "recall": 0, - "f1": 0, - }, - "exact": { - "correct": 0, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "actual": 0, - "possible": 0, - "precision": 0, - "recall": 0, - "f1": 0, - }, - }, - } - - assert results_agg["PER"]["strict"] == expected_agg["PER"]["strict"] - assert results_agg["PER"]["ent_type"] == expected_agg["PER"]["ent_type"] - assert results_agg["PER"]["partial"] == expected_agg["PER"]["partial"] - assert results_agg["PER"]["exact"] == expected_agg["PER"]["exact"] - assert results_agg["LOC"] == expected_agg["LOC"] - - -def test_compute_metrics_agg_scenario_1(): - true_named_entities = [{"label": "PER", "start": 59, "end": 69}] - pred_named_entities = [{"label": "PER", "start": 59, "end": 69}] - _, results_agg, _, _ = compute_metrics(true_named_entities, pred_named_entities, ["PER"]) - expected_agg = { - "PER": { - "strict": { - "correct": 1, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "actual": 1, - "possible": 1, - "precision": 0, - "recall": 
0, - "f1": 0, - }, - "ent_type": { - "correct": 1, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "actual": 1, - "possible": 1, - "precision": 0, - "recall": 0, - "f1": 0, - }, - "partial": { - "correct": 1, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "actual": 1, - "possible": 1, - "precision": 0, - "recall": 0, - "f1": 0, - }, - "exact": { - "correct": 1, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "actual": 1, - "possible": 1, - "precision": 0, - "recall": 0, - "f1": 0, - }, - } - } - - assert results_agg["PER"]["strict"] == expected_agg["PER"]["strict"] - assert results_agg["PER"]["ent_type"] == expected_agg["PER"]["ent_type"] - assert results_agg["PER"]["partial"] == expected_agg["PER"]["partial"] - assert results_agg["PER"]["exact"] == expected_agg["PER"]["exact"] - - -def test_compute_metrics_agg_scenario_6(): - true_named_entities = [{"label": "PER", "start": 59, "end": 69}] - pred_named_entities = [{"label": "LOC", "start": 54, "end": 69}] - _, results_agg, _, _ = compute_metrics(true_named_entities, pred_named_entities, ["PER", "LOC"]) - expected_agg = { - "PER": { - "strict": { - "correct": 0, - "incorrect": 1, - "partial": 0, - "missed": 0, - "spurious": 0, - "actual": 1, - "possible": 1, - "precision": 0, - "recall": 0, - "f1": 0, - }, - "ent_type": { - "correct": 0, - "incorrect": 1, - "partial": 0, - "missed": 0, - "spurious": 0, - "actual": 1, - "possible": 1, - "precision": 0, - "recall": 0, - "f1": 0, - }, - "partial": { - "correct": 0, - "incorrect": 0, - "partial": 1, - "missed": 0, - "spurious": 0, - "actual": 1, - "possible": 1, - "precision": 0, - "recall": 0, - "f1": 0, - }, - "exact": { - "correct": 0, - "incorrect": 1, - "partial": 0, - "missed": 0, - "spurious": 0, - "actual": 1, - "possible": 1, - "precision": 0, - "recall": 0, - "f1": 0, - }, - }, - "LOC": { - "strict": { - "correct": 0, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "actual": 0, - "possible": 0, - "precision": 0, - "recall": 0, - "f1": 0, - }, - "ent_type": { - "correct": 0, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "actual": 0, - "possible": 0, - "precision": 0, - "recall": 0, - "f1": 0, - }, - "partial": { - "correct": 0, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "actual": 0, - "possible": 0, - "precision": 0, - "recall": 0, - "f1": 0, - }, - "exact": { - "correct": 0, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 0, - "actual": 0, - "possible": 0, - "precision": 0, - "recall": 0, - "f1": 0, - }, - }, - } - - assert results_agg["PER"]["strict"] == expected_agg["PER"]["strict"] - assert results_agg["PER"]["ent_type"] == expected_agg["PER"]["ent_type"] - assert results_agg["PER"]["partial"] == expected_agg["PER"]["partial"] - assert results_agg["PER"]["exact"] == expected_agg["PER"]["exact"] - assert results_agg["LOC"] == expected_agg["LOC"] - - -def test_compute_metrics_extra_tags_in_prediction(): - true_named_entities = [ - {"label": "PER", "start": 50, "end": 52}, - {"label": "ORG", "start": 59, "end": 69}, - {"label": "ORG", "start": 71, "end": 72}, - ] - - pred_named_entities = [ - {"label": "LOC", "start": 50, "end": 52}, # Wrong type - {"label": "ORG", "start": 59, "end": 69}, # Correct - {"label": "MISC", "start": 71, "end": 72}, # Wrong type - ] - results, _, _, _ = compute_metrics(true_named_entities, pred_named_entities, ["PER", "LOC", "ORG"]) - expected = { - "strict": { - "correct": 1, - "incorrect": 1, - "partial": 0, - 
"missed": 1, - "spurious": 0, - "actual": 2, - "possible": 3, - "precision": 0, - "recall": 0, - "f1": 0, - }, - "ent_type": { - "correct": 1, - "incorrect": 1, - "partial": 0, - "missed": 1, - "spurious": 0, - "actual": 2, - "possible": 3, - "precision": 0, - "recall": 0, - "f1": 0, - }, - "partial": { - "correct": 2, - "incorrect": 0, - "partial": 0, - "missed": 1, - "spurious": 0, - "actual": 2, - "possible": 3, - "precision": 0, - "recall": 0, - "f1": 0, - }, - "exact": { - "correct": 2, - "incorrect": 0, - "partial": 0, - "missed": 1, - "spurious": 0, - "actual": 2, - "possible": 3, - "precision": 0, - "recall": 0, - "f1": 0, - }, - } - - assert results["strict"] == expected["strict"] - assert results["ent_type"] == expected["ent_type"] - assert results["partial"] == expected["partial"] - assert results["exact"] == expected["exact"] - - -def test_compute_metrics_extra_tags_in_true(): - true_named_entities = [ - {"label": "PER", "start": 50, "end": 52}, - {"label": "ORG", "start": 59, "end": 69}, - {"label": "MISC", "start": 71, "end": 72}, - ] - - pred_named_entities = [ - {"label": "LOC", "start": 50, "end": 52}, # Wrong type - {"label": "ORG", "start": 59, "end": 69}, # Correct - {"label": "ORG", "start": 71, "end": 72}, # Spurious - ] - - results, _, _, _ = compute_metrics(true_named_entities, pred_named_entities, ["PER", "LOC", "ORG"]) - - expected = { - "strict": { - "correct": 1, - "incorrect": 1, - "partial": 0, - "missed": 0, - "spurious": 1, - "actual": 3, - "possible": 2, - "precision": 0, - "recall": 0, - "f1": 0, - }, - "ent_type": { - "correct": 1, - "incorrect": 1, - "partial": 0, - "missed": 0, - "spurious": 1, - "actual": 3, - "possible": 2, - "precision": 0, - "recall": 0, - "f1": 0, - }, - "partial": { - "correct": 2, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 1, - "actual": 3, - "possible": 2, - "precision": 0, - "recall": 0, - "f1": 0, - }, - "exact": { - "correct": 2, - "incorrect": 0, - "partial": 0, - "missed": 0, - "spurious": 1, - "actual": 3, - "possible": 2, - "precision": 0, - "recall": 0, - "f1": 0, - }, - } - - assert results["strict"] == expected["strict"] - assert results["ent_type"] == expected["ent_type"] - assert results["partial"] == expected["partial"] - assert results["exact"] == expected["exact"] - - -def test_compute_metrics_no_predictions(): - true_named_entities = [ - {"label": "PER", "start": 50, "end": 52}, - {"label": "ORG", "start": 59, "end": 69}, - {"label": "MISC", "start": 71, "end": 72}, - ] - pred_named_entities = [] - results, _, _, _ = compute_metrics(true_named_entities, pred_named_entities, ["PER", "ORG", "MISC"]) - expected = { - "strict": { - "correct": 0, - "incorrect": 0, - "partial": 0, - "missed": 3, - "spurious": 0, - "actual": 0, - "possible": 3, - "precision": 0, - "recall": 0, - "f1": 0, - }, - "ent_type": { - "correct": 0, - "incorrect": 0, - "partial": 0, - "missed": 3, - "spurious": 0, - "actual": 0, - "possible": 3, - "precision": 0, - "recall": 0, - "f1": 0, - }, - "partial": { - "correct": 0, - "incorrect": 0, - "partial": 0, - "missed": 3, - "spurious": 0, - "actual": 0, - "possible": 3, - "precision": 0, - "recall": 0, - "f1": 0, - }, - "exact": { - "correct": 0, - "incorrect": 0, - "partial": 0, - "missed": 3, - "spurious": 0, - "actual": 0, - "possible": 3, - "precision": 0, - "recall": 0, - "f1": 0, - }, - } - - assert results["strict"] == expected["strict"] - assert results["ent_type"] == expected["ent_type"] - assert results["partial"] == expected["partial"] - assert results["exact"] == 
expected["exact"] - - -def test_compute_actual_possible(): - results = { - "correct": 6, - "incorrect": 3, - "partial": 2, - "missed": 4, - "spurious": 2, - } - - expected = { - "correct": 6, - "incorrect": 3, - "partial": 2, - "missed": 4, - "spurious": 2, - "possible": 15, - "actual": 13, - } - - out = compute_actual_possible(results) - - assert out == expected - - -def test_compute_precision_recall(): - results = { - "correct": 6, - "incorrect": 3, - "partial": 2, - "missed": 4, - "spurious": 2, - "possible": 15, - "actual": 13, - } - - expected = { - "correct": 6, - "incorrect": 3, - "partial": 2, - "missed": 4, - "spurious": 2, - "possible": 15, - "actual": 13, - "precision": 0.46153846153846156, - "recall": 0.4, - "f1": 0.42857142857142855, - } - - out = compute_precision_recall(results) - - assert out == expected - - -def test_compute_metrics_one_pred_two_true(): - true_named_entities_1 = [ - {"start": 0, "end": 12, "label": "A"}, - {"start": 14, "end": 17, "label": "B"}, - ] - true_named_entities_2 = [ - {"start": 14, "end": 17, "label": "B"}, - {"start": 0, "end": 12, "label": "A"}, - ] - pred_named_entities = [ - {"start": 0, "end": 17, "label": "A"}, - ] - - results1, _, _, _ = compute_metrics(true_named_entities_1, pred_named_entities, ["A", "B"]) - results2, _, _, _ = compute_metrics(true_named_entities_2, pred_named_entities, ["A", "B"]) - - expected = { - "ent_type": { - "correct": 1, - "incorrect": 1, - "partial": 0, - "missed": 0, - "spurious": 0, - "possible": 2, - "actual": 2, - "precision": 0, - "recall": 0, - "f1": 0, - }, - "partial": { - "correct": 0, - "incorrect": 0, - "partial": 2, - "missed": 0, - "spurious": 0, - "possible": 2, - "actual": 2, - "precision": 0, - "recall": 0, - "f1": 0, - }, - "strict": { - "correct": 0, - "incorrect": 2, - "partial": 0, - "missed": 0, - "spurious": 0, - "possible": 2, - "actual": 2, - "precision": 0, - "recall": 0, - "f1": 0, - }, - "exact": { - "correct": 0, - "incorrect": 2, - "partial": 0, - "missed": 0, - "spurious": 0, - "possible": 2, - "actual": 2, - "precision": 0, - "recall": 0, - "f1": 0, - }, - } - - assert results1 == expected - assert results2 == expected - - -def test_evaluator_different_number_of_documents(): - """Test that Evaluator raises ValueError when number of predicted documents doesn't match true documents.""" - - # Create test data with different number of documents - true = [ - [{"label": "PER", "start": 0, "end": 5}], # First document - [{"label": "LOC", "start": 10, "end": 15}], # Second document - ] - pred = [[{"label": "PER", "start": 0, "end": 5}]] # Only one document - tags = ["PER", "LOC"] - - # Test that ValueError is raised - with pytest.raises(ValueError, match="Number of predicted documents does not equal true"): - evaluator = Evaluator(true=true, pred=pred, tags=tags) - evaluator.evaluate() diff --git a/tests/test_reporting.py b/tests/test_reporting.py deleted file mode 100644 index 029e88f..0000000 --- a/tests/test_reporting.py +++ /dev/null @@ -1,111 +0,0 @@ -import pytest - -from nervaluate.reporting import summary_report_ent, summary_report_overall - - -def test_summary_report_ent(): - # Sample input data - results_agg_entities_type = { - "PER": { - "strict": { - "correct": 10, - "incorrect": 2, - "partial": 1, - "missed": 3, - "spurious": 2, - "precision": 0.769, - "recall": 0.714, - "f1": 0.741, - } - }, - "ORG": { - "strict": { - "correct": 15, - "incorrect": 1, - "partial": 0, - "missed": 2, - "spurious": 1, - "precision": 0.882, - "recall": 0.833, - "f1": 0.857, - } - }, - } - - # 
Call the function - report = summary_report_ent(results_agg_entities_type, scenario="strict", digits=3) - - # Verify the report contains expected content - assert "PER" in report - assert "ORG" in report - assert "correct" in report - assert "incorrect" in report - assert "partial" in report - assert "missed" in report - assert "spurious" in report - assert "precision" in report - assert "recall" in report - assert "f1-score" in report - - # Verify specific values are present - assert "10" in report # PER correct - assert "15" in report # ORG correct - assert "0.769" in report # PER precision - assert "0.857" in report # ORG f1 - - # Test invalid scenario - with pytest.raises(Exception) as exc_info: - summary_report_ent(results_agg_entities_type, scenario="invalid") - assert "Invalid scenario" in str(exc_info.value) - - -def test_summary_report_overall(): - # Sample input data - results = { - "strict": { - "correct": 25, - "incorrect": 3, - "partial": 1, - "missed": 5, - "spurious": 3, - "precision": 0.862, - "recall": 0.806, - "f1": 0.833, - }, - "ent_type": { - "correct": 26, - "incorrect": 2, - "partial": 1, - "missed": 4, - "spurious": 3, - "precision": 0.897, - "recall": 0.839, - "f1": 0.867, - }, - } - - # Call the function - report = summary_report_overall(results, digits=3) - - # Verify the report contains expected content - assert "strict" in report - assert "ent_type" in report - assert "correct" in report - assert "incorrect" in report - assert "partial" in report - assert "missed" in report - assert "spurious" in report - assert "precision" in report - assert "recall" in report - assert "f1-score" in report - - # Verify specific values are present - assert "25" in report # strict correct - assert "26" in report # ent_type correct - assert "0.862" in report # strict precision - assert "0.867" in report # ent_type f1 - - # Test with different number of digits - report_2digits = summary_report_overall(results, digits=2) - assert "0.86" in report_2digits # strict precision with 2 digits - assert "0.87" in report_2digits # ent_type f1 with 2 digits diff --git a/tests/test_strategies.py b/tests/test_strategies.py new file mode 100644 index 0000000..22eada9 --- /dev/null +++ b/tests/test_strategies.py @@ -0,0 +1,392 @@ +import pytest +from nervaluate.entities import Entity +from nervaluate.strategies import ( + EntityTypeEvaluation, + ExactEvaluation, + PartialEvaluation, + StrictEvaluation +) + + +def create_entities_from_bio(bio_tags): + """Helper function to create entities from BIO tags.""" + entities = [] + current_entity = None + + for i, tag in enumerate(bio_tags): + if tag == "O": + continue + + if tag.startswith("B-"): + if current_entity: + entities.append(current_entity) + current_entity = Entity(tag[2:], i, i + 1) + elif tag.startswith("I-"): + if current_entity: + current_entity.end = i + 1 + else: + # Handle case where I- tag appears without B- + current_entity = Entity(tag[2:], i, i + 1) + + if current_entity: + entities.append(current_entity) + + return entities + + +@pytest.fixture +def base_sequence(): + """Base sequence: 'The John Smith who works at Google Inc'""" + return ["O", "B-PER", "I-PER", "O", "O", "O", "B-ORG", "I-ORG"] + + +class TestStrictEvaluation: + """Test cases for strict evaluation strategy.""" + + def test_perfect_match(self, base_sequence): + """Test case: Perfect match of all entities.""" + true = create_entities_from_bio(base_sequence) + pred = create_entities_from_bio(base_sequence) + + evaluator = StrictEvaluation() + result, _ = 
evaluator.evaluate(true, pred, ["PER", "ORG"]) + + assert result.correct == 2 + assert result.incorrect == 0 + assert result.partial == 0 + assert result.missed == 0 + assert result.spurious == 0 + + def test_missed_entity(self, base_sequence): + """Test case: One entity is missed in prediction.""" + true = create_entities_from_bio(base_sequence) + pred = create_entities_from_bio(["O", "B-PER", "I-PER", "O", "O", "O", "O", "O"]) + + evaluator = StrictEvaluation() + result, _ = evaluator.evaluate(true, pred, ["PER", "ORG"]) + + assert result.correct == 1 + assert result.incorrect == 0 + assert result.partial == 0 + assert result.missed == 1 + assert result.spurious == 0 + + def test_wrong_label(self, base_sequence): + """Test case: Entity with wrong label.""" + true = create_entities_from_bio(base_sequence) + pred = create_entities_from_bio(["O", "B-PER", "I-PER", "O", "O", "O", "B-LOC", "I-LOC"]) + + evaluator = StrictEvaluation() + result, _ = evaluator.evaluate(true, pred, ["PER", "ORG", "LOC"]) + + assert result.correct == 1 + assert result.incorrect == 1 + assert result.partial == 0 + assert result.missed == 0 + assert result.spurious == 0 + + def test_wrong_boundary(self, base_sequence): + """Test case: Entity with wrong boundary.""" + true = create_entities_from_bio(base_sequence) + pred = create_entities_from_bio(["O", "B-PER", "I-PER", "O", "O", "O", "B-LOC", "O"]) + + evaluator = StrictEvaluation() + result, _ = evaluator.evaluate(true, pred, ["PER", "ORG", "LOC"]) + + assert result.correct == 1 + assert result.incorrect == 1 + assert result.partial == 0 + assert result.missed == 0 + assert result.spurious == 0 + + def test_shifted_boundary(self, base_sequence): + """Test case: Entity with shifted boundary.""" + true = create_entities_from_bio(base_sequence) + pred = create_entities_from_bio(["O", "B-PER", "I-PER", "O", "O", "O", "O", "B-LOC"]) + + evaluator = StrictEvaluation() + result, _ = evaluator.evaluate(true, pred, ["PER", "ORG", "LOC"]) + + assert result.correct == 1 + assert result.incorrect == 1 + assert result.partial == 0 + assert result.missed == 0 + assert result.spurious == 0 + + def test_extra_entity(self, base_sequence): + """Test case: Extra entity in prediction.""" + true = create_entities_from_bio(base_sequence) + pred = create_entities_from_bio(["O", "B-PER", "I-PER", "O", "B-PER", "O", "B-LOC", "I-LOC"]) + + evaluator = StrictEvaluation() + result, _ = evaluator.evaluate(true, pred, ["PER", "ORG", "LOC"]) + + assert result.correct == 1 + assert result.incorrect == 1 + assert result.partial == 0 + assert result.missed == 0 + assert result.spurious == 1 + + +class TestEntityTypeEvaluation: + """Test cases for entity type evaluation strategy.""" + + def test_perfect_match(self, base_sequence): + """Test case: Perfect match of all entities.""" + true = create_entities_from_bio(base_sequence) + pred = create_entities_from_bio(base_sequence) + + evaluator = EntityTypeEvaluation() + result, _ = evaluator.evaluate(true, pred, ["PER", "ORG"]) + + assert result.correct == 2 + assert result.incorrect == 0 + assert result.partial == 0 + assert result.missed == 0 + assert result.spurious == 0 + + def test_missed_entity(self, base_sequence): + """Test case: One entity is missed in prediction.""" + true = create_entities_from_bio(base_sequence) + pred = create_entities_from_bio(["O", "B-PER", "I-PER", "O", "O", "O", "O", "O"]) + + evaluator = EntityTypeEvaluation() + result, _ = evaluator.evaluate(true, pred, ["PER", "ORG"]) + + assert result.correct == 1 + assert 
result.incorrect == 0 + assert result.partial == 0 + assert result.missed == 1 + assert result.spurious == 0 + + def test_wrong_label(self, base_sequence): + """Test case: Entity with wrong label.""" + true = create_entities_from_bio(base_sequence) + pred = create_entities_from_bio(["O", "B-PER", "I-PER", "O", "O", "O", "B-LOC", "I-LOC"]) + + evaluator = EntityTypeEvaluation() + result, _ = evaluator.evaluate(true, pred, ["PER", "ORG", "LOC"]) + + assert result.correct == 1 + assert result.incorrect == 1 + assert result.partial == 0 + assert result.missed == 0 + assert result.spurious == 0 + + def test_wrong_boundary(self, base_sequence): + """Test case: Entity with wrong boundary.""" + true = create_entities_from_bio(base_sequence) + pred = create_entities_from_bio(["O", "B-PER", "I-PER", "O", "O", "O", "B-LOC", "O"]) + + evaluator = EntityTypeEvaluation() + result, _ = evaluator.evaluate(true, pred, ["PER", "ORG", "LOC"]) + + assert result.correct == 1 + assert result.incorrect == 1 + assert result.partial == 0 + assert result.missed == 0 + assert result.spurious == 0 + + def test_shifted_boundary(self, base_sequence): + """Test case: Entity with shifted boundary.""" + true = create_entities_from_bio(base_sequence) + pred = create_entities_from_bio(["O", "B-PER", "I-PER", "O", "O", "O", "O", "B-LOC"]) + + evaluator = EntityTypeEvaluation() + result, _ = evaluator.evaluate(true, pred, ["PER", "ORG", "LOC"]) + + assert result.correct == 1 + assert result.incorrect == 1 + assert result.partial == 0 + assert result.missed == 0 + assert result.spurious == 0 + + def test_extra_entity(self, base_sequence): + """Test case: Extra entity in prediction.""" + true = create_entities_from_bio(base_sequence) + pred = create_entities_from_bio(["O", "B-PER", "I-PER", "O", "B-PER", "O", "B-LOC", "I-LOC"]) + + evaluator = EntityTypeEvaluation() + result, _ = evaluator.evaluate(true, pred, ["PER", "ORG", "LOC"]) + + assert result.correct == 1 + assert result.incorrect == 1 + assert result.partial == 0 + assert result.missed == 0 + assert result.spurious == 1 + + +class TestExactEvaluation: + """Test cases for exact evaluation strategy.""" + + def test_perfect_match(self, base_sequence): + """Test case: Perfect match of all entities.""" + true = create_entities_from_bio(base_sequence) + pred = create_entities_from_bio(base_sequence) + + evaluator = ExactEvaluation() + result, _ = evaluator.evaluate(true, pred, ["PER", "ORG"]) + + assert result.correct == 2 + assert result.incorrect == 0 + assert result.partial == 0 + assert result.missed == 0 + assert result.spurious == 0 + + def test_missed_entity(self, base_sequence): + """Test case: One entity is missed in prediction.""" + true = create_entities_from_bio(base_sequence) + pred = create_entities_from_bio(["O", "B-PER", "I-PER", "O", "O", "O", "O", "O"]) + + evaluator = ExactEvaluation() + result, _ = evaluator.evaluate(true, pred, ["PER", "ORG"]) + + assert result.correct == 1 + assert result.incorrect == 0 + assert result.partial == 0 + assert result.missed == 1 + assert result.spurious == 0 + + def test_wrong_label(self, base_sequence): + """Test case: Entity with wrong label.""" + true = create_entities_from_bio(base_sequence) + pred = create_entities_from_bio(["O", "B-PER", "I-PER", "O", "O", "O", "B-LOC", "I-LOC"]) + + evaluator = ExactEvaluation() + result, _ = evaluator.evaluate(true, pred, ["PER", "ORG", "LOC"]) + + assert result.correct == 2 + assert result.incorrect == 0 + assert result.partial == 0 + assert result.missed == 0 + assert 
result.spurious == 0 + + def test_wrong_boundary(self, base_sequence): + """Test case: Entity with wrong boundary.""" + true = create_entities_from_bio(base_sequence) + pred = create_entities_from_bio(["O", "B-PER", "I-PER", "O", "O", "O", "B-LOC", "O"]) + + evaluator = ExactEvaluation() + result, _ = evaluator.evaluate(true, pred, ["PER", "ORG", "LOC"]) + + assert result.correct == 1 + assert result.incorrect == 1 + assert result.partial == 0 + assert result.missed == 0 + assert result.spurious == 0 + + def test_shifted_boundary(self, base_sequence): + """Test case: Entity with shifted boundary.""" + true = create_entities_from_bio(base_sequence) + pred = create_entities_from_bio(["O", "B-PER", "I-PER", "O", "O", "O", "O", "B-LOC"]) + + evaluator = ExactEvaluation() + result, _ = evaluator.evaluate(true, pred, ["PER", "ORG", "LOC"]) + + assert result.correct == 1 + assert result.incorrect == 1 + assert result.partial == 0 + assert result.missed == 0 + assert result.spurious == 0 + + def test_extra_entity(self, base_sequence): + """Test case: Extra entity in prediction.""" + true = create_entities_from_bio(base_sequence) + pred = create_entities_from_bio(["O", "B-PER", "I-PER", "O", "B-PER", "O", "B-LOC", "I-LOC"]) + + evaluator = ExactEvaluation() + result, _ = evaluator.evaluate(true, pred, ["PER", "ORG", "LOC"]) + + assert result.correct == 2 + assert result.incorrect == 0 + assert result.partial == 0 + assert result.missed == 0 + assert result.spurious == 1 + + +class TestPartialEvaluation: + """Test cases for partial evaluation strategy.""" + + def test_perfect_match(self, base_sequence): + """Test case: Perfect match of all entities.""" + true = create_entities_from_bio(base_sequence) + pred = create_entities_from_bio(base_sequence) + + evaluator = PartialEvaluation() + result, _ = evaluator.evaluate(true, pred, ["PER", "ORG"]) + + assert result.correct == 2 + assert result.incorrect == 0 + assert result.partial == 0 + assert result.missed == 0 + assert result.spurious == 0 + + def test_missed_entity(self, base_sequence): + """Test case: One entity is missed in prediction.""" + true = create_entities_from_bio(base_sequence) + pred = create_entities_from_bio(["O", "B-PER", "I-PER", "O", "O", "O", "O", "O"]) + + evaluator = PartialEvaluation() + result, _ = evaluator.evaluate(true, pred, ["PER", "ORG"]) + + assert result.correct == 1 + assert result.incorrect == 0 + assert result.partial == 0 + assert result.missed == 1 + assert result.spurious == 0 + + def test_wrong_label(self, base_sequence): + """Test case: Entity with wrong label.""" + true = create_entities_from_bio(base_sequence) + pred = create_entities_from_bio(["O", "B-PER", "I-PER", "O", "O", "O", "B-LOC", "I-LOC"]) + + evaluator = PartialEvaluation() + result, _ = evaluator.evaluate(true, pred, ["PER", "ORG", "LOC"]) + + assert result.correct == 2 + assert result.incorrect == 0 + assert result.partial == 0 + assert result.missed == 0 + assert result.spurious == 0 + + def test_wrong_boundary(self, base_sequence): + """Test case: Entity with wrong boundary.""" + true = create_entities_from_bio(base_sequence) + pred = create_entities_from_bio(["O", "B-PER", "I-PER", "O", "O", "O", "B-LOC", "O"]) + + evaluator = PartialEvaluation() + result, _ = evaluator.evaluate(true, pred, ["PER", "ORG", "LOC"]) + + assert result.correct == 1 + assert result.incorrect == 0 + assert result.partial == 1 + assert result.missed == 0 + assert result.spurious == 0 + + def test_shifted_boundary(self, base_sequence): + """Test case: Entity with 
shifted boundary.""" + true = create_entities_from_bio(base_sequence) + pred = create_entities_from_bio(["O", "B-PER", "I-PER", "O", "O", "O", "O", "B-LOC"]) + + evaluator = PartialEvaluation() + result, _ = evaluator.evaluate(true, pred, ["PER", "ORG", "LOC"]) + + assert result.correct == 1 + assert result.incorrect == 0 + assert result.partial == 1 + assert result.missed == 0 + assert result.spurious == 0 + + def test_extra_entity(self, base_sequence): + """Test case: Extra entity in prediction.""" + true = create_entities_from_bio(base_sequence) + pred = create_entities_from_bio(["O", "B-PER", "I-PER", "O", "B-PER", "O", "B-LOC", "I-LOC"]) + + evaluator = PartialEvaluation() + result, _ = evaluator.evaluate(true, pred, ["PER", "ORG", "LOC"]) + + assert result.correct == 2 + assert result.incorrect == 0 + assert result.partial == 0 + assert result.missed == 0 + assert result.spurious == 1 diff --git a/tests/test_utils.py b/tests/test_utils.py index db5c76d..d26e3aa 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,7 +1,6 @@ from nervaluate import ( collect_named_entities, conll_to_spans, - find_overlap, list_to_spans, split_list, ) @@ -134,63 +133,3 @@ def test_collect_named_entities_no_entity(): result = collect_named_entities(tags) expected = [] assert result == expected - - -def test_find_overlap_no_overlap(): - pred_entity = {"label": "LOC", "start": 1, "end": 10} - true_entity = {"label": "LOC", "start": 11, "end": 20} - - pred_range = range(pred_entity["start"], pred_entity["end"]) - true_range = range(true_entity["start"], true_entity["end"]) - - pred_set = set(pred_range) - true_set = set(true_range) - - intersect = find_overlap(pred_set, true_set) - - assert not intersect - - -def test_find_overlap_total_overlap(): - pred_entity = {"label": "LOC", "start": 10, "end": 22} - true_entity = {"label": "LOC", "start": 11, "end": 20} - - pred_range = range(pred_entity["start"], pred_entity["end"]) - true_range = range(true_entity["start"], true_entity["end"]) - - pred_set = set(pred_range) - true_set = set(true_range) - - intersect = find_overlap(pred_set, true_set) - - assert intersect - - -def test_find_overlap_start_overlap(): - pred_entity = {"label": "LOC", "start": 5, "end": 12} - true_entity = {"label": "LOC", "start": 11, "end": 20} - - pred_range = range(pred_entity["start"], pred_entity["end"]) - true_range = range(true_entity["start"], true_entity["end"]) - - pred_set = set(pred_range) - true_set = set(true_range) - - intersect = find_overlap(pred_set, true_set) - - assert intersect - - -def test_find_overlap_end_overlap(): - pred_entity = {"label": "LOC", "start": 15, "end": 25} - true_entity = {"label": "LOC", "start": 11, "end": 20} - - pred_range = range(pred_entity["start"], pred_entity["end"]) - true_range = range(true_entity["start"], true_entity["end"]) - - pred_set = set(pred_range) - true_set = set(true_range) - - intersect = find_overlap(pred_set, true_set) - - assert intersect