[TR] [METAL] Added evaluation and improvements for link_amr (#598)
Improved the AMR linking by being smarter about resolving IDs.
Added a quantitative evaluation endpoint for AMR linking
enoriega authored Oct 27, 2023
1 parent a9d0c51 commit dd1e4ec
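For context, a minimal sketch (not part of the commit; paths and AMR type assumed from the example data added below) of exercising the new eval mode from the library:

```python
# Hypothetical usage: link the SIDARTHE example AMR against the manual
# annotations shipped with this commit, treating them as ground truth.
from skema.metal.model_linker.skema_model_linker.link_amr import link_amr

link_amr(
    amr_path="skema/metal/model_linker/examples/data/sidarthe_amr.json",
    attribute_collection="skema/metal/model_linker/examples/data/sidarthe_annotations.json",
    amr_type="petrinet",  # assumed; "regnet" is also supported
    eval_mode=True,       # new flag: treat the attribute collection as manual annotations
    output_path="sidarthe_linked.json",
)
```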
Showing 8 changed files with 4,104 additions and 18 deletions.
449 changes: 449 additions & 0 deletions skema/metal/model_linker/examples/data/sidarthe_amr.json

Large diffs are not rendered by default.

3,378 changes: 3,378 additions & 0 deletions skema/metal/model_linker/examples/data/sidarthe_annotations.json

Large diffs are not rendered by default.

33 changes: 21 additions & 12 deletions skema/metal/model_linker/skema_model_linker/link_amr.py
@@ -26,6 +26,7 @@ def link_amr(
    amr_path: str,  # Path of the AMR model
    attribute_collection: str,  # Path to the attribute collection
    amr_type: str,  # AMR model type, i.e. "petrinet" or "regnet"
    eval_mode: bool = False,  # True when the extractions are manual annotations
    output_path: Optional[str] = None,  # Output file path
    clean_xml_codepoints: Optional[bool] = False,  # Replaces HTML codepoints with the unicode character
    similarity_model: str = "sentence-transformers/all-MiniLM-L6-v2",  # Transformer model to compute similarities
@@ -46,22 +47,30 @@ def link_amr(
    if clean_xml_codepoints:
        amr = replace_xml_codepoints(amr)

    # Handle extractions from the SKEMA service or directly from the library
    try:
        extractions = AttributeCollection.from_json(attribute_collection)
    except KeyError:
        with open(attribute_collection) as f:
            service_output = json.load(f)
        collections = list()
        for collection in service_output['outputs']:
            collection = AttributeCollection.from_json(collection['data'])
            collections.append(collection)

        extractions = AttributeCollection(attributes=list(it.chain.from_iterable(c.attributes for c in collections)))

    linker = Linker(model_name=similarity_model, device=device, sim_threshold=similarity_threshold)

    linked_model = linker.link_model_to_text_extractions(amr, extractions)
    if not eval_mode:
        # Handle extractions from the SKEMA service or directly from the library
        try:
            extractions = AttributeCollection.from_json(attribute_collection)
        except KeyError:
            with open(attribute_collection) as f:
                service_output = json.load(f)
            collections = list()
            for collection in service_output['outputs']:
                collection = AttributeCollection.from_json(collection['data'])
                collections.append(collection)

            extractions = AttributeCollection(
                attributes=list(it.chain.from_iterable(c.attributes for c in collections)))
        linked_model = linker.link_model_to_text_extractions(amr, extractions)
    else:
        with open(attribute_collection) as f:
            annotations = json.load(f)
        annotations = replace_xml_codepoints(annotations)
        linked_model = linker.link_model_to_manual_annotations(amr, annotations)

    if not output_path:
        input_amr_name = str(Path(amr_path).name)
40 changes: 40 additions & 0 deletions skema/metal/model_linker/skema_model_linker/linkers/amr_linker.py
@@ -3,6 +3,7 @@
from collections import defaultdict
from typing import Iterable, Dict, List, Any, Tuple, Optional, Union

import pandas as pd
import torch
from askem_extractions.data_model import Attribute, AnchoredEntity, AttributeCollection, AttributeType
from sentence_transformers import SentenceTransformer, util
@@ -68,6 +69,45 @@ def link_model_to_text_extractions(self, data: Union[Any, Dict[str, Any]], extra

class AMRLinker(Linker, ABC):

    def link_model_to_manual_annotations(self, data: Dict[str, Any], candidates: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Similar to linking a model to text extractions, but links it to ground-truth annotations instead.
        Used mostly for debugging.
        """

        # Make a copy of the amr to avoid mutating the original model
        data = {**data}

        # Filter out the targets from the annotations
        targets = defaultdict(list)
        for candidate in candidates:
            if candidate['type'] == "Highlight" and candidate['color'] == "#f9cd59":  # This color marks an anchored extraction
                key = candidate["text"]
                targets[key].append(candidate)

        walker = self._build_walker(data)

        to_link = list(walker.walk())
        sources = self._generate_linking_sources(to_link)

        pairs = self._align_texts(list(sources.keys()), list(targets.keys()), threshold=self._threshold)

        linked_targets = list()
        for s_key, t_key in pairs:
            source = sources[s_key]
            target = targets[t_key]

            # Get the AMR ID of the source and add it to the target extractions
            for t in target:
                t['amr_element_id'] = source['id']
                linked_targets.append(t)

        # Serialize the attribute collection to json, after alignment
        data["metadata"] = linked_targets

        return data

    def link_model_to_text_extractions(self, data: Dict[str, Any], extractions: AttributeCollection) -> Dict[str, Any]:

        # Make a copy of the amr to avoid mutating the original model
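The eval path consumes raw annotation records rather than an AttributeCollection. A hypothetical record in the shape link_model_to_manual_annotations filters for, with field names taken from the accesses above and values invented for illustration:

```python
# Illustrative only: one manual-annotation candidate as consumed above.
annotation = {
    "type": "Highlight",        # only highlights are considered
    "color": "#f9cd59",         # the color reserved for anchored extractions
    "text": "β, the transmission rate",
}
# After alignment, the linker stamps the matched AMR element's id onto it:
# annotation["amr_element_id"] == "beta"
```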
37 changes: 37 additions & 0 deletions skema/rest/integrated_text_reading_proxy.py
@@ -677,5 +677,42 @@ def quantitative_eval() -> TextReadingEvaluationResults:
    return compute_text_reading_evaluation(gt_data, extractions)


@router.post("/eval", response_model=TextReadingEvaluationResults, status_code=200)
def quantitative_eval(extractions_file: UploadFile, gt_annotations: UploadFile):
"""
# Gets performance metrics of a set of text extractions againts a ground truth annotations file.
## Example:
```python
files = {
"extractions_file": ("paper_variable_extractions.json", open("paper_variable_extractions.json", 'rb')),
"gt_annotations": ("paper_gt_annotations.json", open("paper_gt_annotations.json", 'rb')),
}
response = requests.post(f"{endpoint}/text-reading/eval", files=files)
```
"""

gt_data = json.load(gt_annotations.file)

# Support both Attribute Collections serialized and within the envelop of this rest API
extractions_json = json.load(extractions_file.file)
try:
extractions = AttributeCollection.from_json(extractions_json)
except KeyError:
extractions_file.file.seek(0)
service_output = json.load(extractions_file.file)
collections = list()
for collection in service_output['outputs']:
collection = AttributeCollection.from_json(collection['data'])
collections.append(collection)

extractions = AttributeCollection(
attributes=list(it.chain.from_iterable(c.attributes for c in collections)))

return compute_text_reading_evaluation(gt_data, extractions)


app = FastAPI()
app.include_router(router)
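A short sketch (server URL assumed; file names hypothetical) of consuming the new endpoint's response, whose fields are declared in TextReadingEvaluationResults in skema/rest/schema.py below:

```python
import requests

files = {
    "extractions_file": ("extractions.json", open("extractions.json", "rb")),
    "gt_annotations": ("annotations.json", open("annotations.json", "rb")),
}
resp = requests.post("http://localhost:8000/text-reading/eval", files=files)
metrics = resp.json()
print(metrics["correct_extractions"], metrics["precision"])
```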
24 changes: 23 additions & 1 deletion skema/rest/metal_proxy.py
@@ -8,7 +8,8 @@

from skema.metal.model_linker.skema_model_linker.linkers import PetriNetLinker, RegNetLinker
from skema.metal.model_linker.skema_model_linker.link_amr import replace_xml_codepoints
from skema.rest.schema import TextReadingAnnotationsOutput
from skema.rest.schema import TextReadingAnnotationsOutput, TextReadingEvaluationResults, AMRLinkingEvaluationResults
from skema.rest.utils import compute_amr_linking_evaluation

router = APIRouter()

@@ -88,6 +89,27 @@ def link_amr(amr_type: str,
def healthcheck():
    return 200

@router.post("/eval", response_model=AMRLinkingEvaluationResults, status_code=200)
def quantitative_eval(linked_amr_file: UploadFile, gt_linked_amr_file: UploadFile) -> AMRLinkingEvaluationResults:
"""
# Gets performance metrics of a linked amr with variable extractions against a ground truth linked amr.
## Example:
```python
files = {
"linked_amr": ("linked_amr_file.json", open("linked_amr_file.json", 'rb')),
"gt_linked_amr_file": ("gt_linked_amr_file.json", open("gt_linked_amr_file.json", 'rb')),
}
response = requests.post(f"{endpoint}/metal/eval", files=files)
```
"""

linked_amr = json.load(linked_amr_file.file)
gt_linked_amr_file = json.load(gt_linked_amr_file.file)

return compute_amr_linking_evaluation(linked_amr, gt_linked_amr_file)

app = FastAPI()
app.include_router(router)
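Likewise, a sketch (URL and file names assumed) of reading the AMRLinkingEvaluationResults payload returned by /metal/eval:

```python
import requests

files = {
    "linked_amr_file": ("linked_amr.json", open("linked_amr.json", "rb")),
    "gt_linked_amr_file": ("gt_linked_amr.json", open("gt_linked_amr.json", "rb")),
}
resp = requests.post("http://localhost:8000/metal/eval", files=files)
metrics = resp.json()
print(metrics["num_gt_elems_with_metadata"], metrics["precision"], metrics["recall"], metrics["f1"])
```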
13 changes: 10 additions & 3 deletions skema/rest/schema.py
@@ -191,21 +191,28 @@ class TextReadingEvaluationResults(BaseModel):
        description="Total number of extractions detected in the current document",
    )
    correct_extractions: int = Field(
        name="correct_extractions",
        description="Number of extractions matched in the ground-truth annotations"
    )
    precision: float = Field(
        description="Ratio of correct extractions against manual annotations"
    )


class AMRLinkingEvaluationResults(BaseModel):
    """ Evaluation results of the AMR Linking procedure """
    num_gt_elems_with_metadata: int
    precision: float
    recall: float
    f1: float


class TextReadingAnnotationsOutput(BaseModel):
    """Contains the TR document results for all the documents submitted for annotation"""

    outputs: List[TextReadingDocumentResults] = Field(
        name="outputs",
        description="Contains the results of TR annotations for each input document. There is one entry per input and "
                    "inputs and outputs are matched by the same index in the list",
        examples=[[
            TextReadingDocumentResults(
                data=AttributeCollection(attributes=[]), errors=None
@@ -221,4 +228,4 @@
        None, name="generalized_errors",
        description="Any pipeline-wide errors, not specific to a particular input",
        examples=[[TextReadingError(pipeline="MIT", message="API quota exceeded")]],
    )
148 changes: 146 additions & 2 deletions skema/rest/utils.py
@@ -1,10 +1,11 @@
from collections import defaultdict
from typing import Any, Dict
import itertools as it

from askem_extractions.data_model import AttributeCollection, AttributeType, AnchoredEntity, Mention
from askem_extractions.data_model import AttributeCollection, AttributeType, Mention
from bs4 import BeautifulSoup, Comment

from skema.rest.schema import TextReadingEvaluationResults
from skema.rest.schema import TextReadingEvaluationResults, AMRLinkingEvaluationResults


def clean_mml(mml: str) -> str:
@@ -70,3 +71,146 @@ def compute_text_reading_evaluation(gt_data: list, attributes: AttributeCollecti
        correct_extractions=num_matches,
        precision=num_matches / len(gt_data)
    )


greek_alphabet = {
    'Α': 'alpha',
    'α': 'alpha',
    'Β': 'beta',
    'β': 'beta',
    'Γ': 'gamma',
    'γ': 'gamma',
    'Δ': 'delta',
    'δ': 'delta',
    'Ε': 'epsilon',
    'ε': 'epsilon',
    'Ζ': 'zeta',
    'ζ': 'zeta',
    'Η': 'eta',
    'η': 'eta',
    'Θ': 'theta',
    'θ': 'theta',
    'Ι': 'iota',
    'ι': 'iota',
    'Κ': 'kappa',
    'κ': 'kappa',
    'Λ': 'lambda',
    'λ': 'lambda',
    'Μ': 'mu',
    'μ': 'mu',
    'Ν': 'nu',
    'ν': 'nu',
    'Ξ': 'xi',
    'ξ': 'xi',
    'Ο': 'omicron',
    'ο': 'omicron',
    'Π': 'pi',
    'π': 'pi',
    'Ρ': 'rho',
    'ρ': 'rho',
    'Σ': 'sigma',
    'σ': 'sigma',
    'ς': 'sigma',
    'Τ': 'tau',
    'τ': 'tau',
    'Υ': 'upsilon',
    'υ': 'upsilon',
    'Φ': 'phi',
    'φ': 'phi',
    'Χ': 'chi',
    'χ': 'chi',
    'Ψ': 'psi',
    'ψ': 'psi',
    'Ω': 'omega',
    'ω': 'omega'
}
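# Quick illustration (not part of the commit) of how this table is used below:
# each ground-truth text is re-added once per Greek letter it contains, with
# that letter spelled out, so an extraction saying "beta" can match an
# annotation containing "β", e.g.
#   {"R0 = β/γ".replace(k, v) for k, v in greek_alphabet.items() if k in "R0 = β/γ"}
#   -> {"R0 = beta/γ", "R0 = β/gamma"}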

def compute_amr_linking_evaluation(linked_amr, gt_linked_amr) -> AMRLinkingEvaluationResults:

    # Find the amr elements with metadata in the GT
    gt_amr_ids = {m['amr_element_id'] for m in gt_linked_amr['metadata'] if m['amr_element_id'] is not None}

    # Fetch the relevant elements from both amrs
    def get_elem_by_id(data, ids):
        ret = list()
        if isinstance(data, list):
            ret.extend(it.chain.from_iterable(get_elem_by_id(a, ids) for a in data))
        elif isinstance(data, dict):
            if "id" in data and data["id"] in ids:
                ret.append(data)
            else:
                ret.extend(it.chain.from_iterable(get_elem_by_id(v, ids) for k, v in data.items() if k != "metadata"))
        return ret

    gt_elems = get_elem_by_id(gt_linked_amr, gt_amr_ids)
    runtime_elems = get_elem_by_id(linked_amr, gt_amr_ids)

    # Generate metadata dictionaries
    gt_metadata = defaultdict(list)
    for m in gt_linked_amr['metadata']:
        gt_metadata[m['amr_element_id']].append(m)

    runtime_metadata = defaultdict(list)
    for m in linked_amr['metadata']['attributes']:
        runtime_metadata[m['amr_element_id']].append(m)

    # Compute the numbers
    tp, tn, fp, fn = 0, 0, 0, 0

    for amr_id in gt_amr_ids:
        gt = gt_metadata[amr_id]
        rt = runtime_metadata[amr_id]

        # Get the text from the ground truth
        gt_texts = {e['text'] for e in gt}
        expanded_gt_texts = set()
        for t in gt_texts:
            for k, v in greek_alphabet.items():
                if k in t:
                    expanded_gt_texts.add(t.replace(k, v))
        gt_texts |= expanded_gt_texts

        # Get the text from the automated extractions
        rt_texts = set()
        for e in rt:
            e = e['payload']
            for m in e['mentions']:
                name = m['name']
                for d in e['text_descriptions']:
                    desc = d['description']
                    rt_texts.add((name, desc))
                for v in e['value_descriptions']:
                    val = v['value']['amount']
                    rt_texts.add((name, val))

        # Compute hits and misses
        if len(gt_texts) > 0:
            hit = False
            for gtt in gt_texts:
                if not hit:
                    for (a, b) in rt_texts:
                        # Both the name and the desc have to be present in the
                        # annotation in order to be a "hit"
                        if a in gtt and b in gtt:
                            tp += 1
                            hit = True
                            break
            # If we made it to this point and no extraction matched,
            # then this is a false negative
            if not hit:
                fn += 1
        elif len(rt_texts) > 0:
            fp += 1
        else:
            tn += 1

    precision = tp / ((tp + fp) + 1e-9)
    recall = tp / ((tp + fn) + 1e-9)

    f1 = (2 * precision * recall) / ((precision + recall) + 1e-9)

    return AMRLinkingEvaluationResults(
        num_gt_elems_with_metadata=len(gt_amr_ids),
        precision=precision,
        recall=recall,
        f1=f1
    )
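To make the bookkeeping concrete, a self-contained sketch (input shapes inferred from the field accesses above; data invented) of calling the evaluation directly:

```python
# Hypothetical inputs: the GT carries flat annotation metadata, while the
# runtime AMR nests its extractions under metadata["attributes"].
gt_linked_amr = {
    "model": {"states": [{"id": "S", "name": "Susceptible"}]},
    "metadata": [{"amr_element_id": "S", "text": "S, the susceptible population"}],
}
linked_amr = {
    "model": {"states": [{"id": "S", "name": "Susceptible"}]},
    "metadata": {"attributes": [{
        "amr_element_id": "S",
        "payload": {
            "mentions": [{"name": "S"}],
            "text_descriptions": [{"description": "susceptible population"}],
            "value_descriptions": [],
        },
    }]},
}

results = compute_amr_linking_evaluation(linked_amr, gt_linked_amr)
# tp=1, fp=0, fn=0, so precision, recall and f1 all evaluate to ~1.0
```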
