Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding coreference resolution before chunking text for improved text-to-KG #9875

Open
wants to merge 6 commits into
base: rebase-txt2kg
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions torch_geometric/nn/nlp/txt2kg.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
import os
import re
import time
from typing import List, Optional, Tuple

# spaCy is an optional dependency: record whether it is importable so that
# callers can fail with a helpful message instead of at import time.
try:
    import spacy
    from spacy.language import Language
    WITH_SPACY = True
except ImportError:
    # Only a missing/broken install is expected here; anything else (e.g.
    # KeyboardInterrupt) should propagate rather than be silently swallowed.
    WITH_SPACY = False
    # Keep the name bound so annotations referring to it still evaluate.
    Language = None
import torch
import torch.multiprocessing as mp

Expand Down Expand Up @@ -46,6 +54,12 @@ def __init__(
self.relevant_triples = {}
self.total_chars_parsed = 0
self.time_to_parse = 0.0
# initializing for 'en' right now. We can add multi-lingual capabilities in future.
if WITH_SPACY:
self.nlp = spacy.load("en_core_web_lg")
self.nlp.add_pipe('coreferee')
else:
raise Exception("`pip install spacy` to use TXT2KG")

def save_kg(self, path: str) -> None:
torch.save(self.relevant_triples, path)
Expand All @@ -71,6 +85,7 @@ def add_doc_2_KG(
txt: str,
QA_pair: Optional[Tuple[str, str]],
) -> None:
resolved_txt = corefernce_resolution(txt, pipeline=self.nlp)
chunks = chunk_text(txt, chunk_size=self.chunk_size)
if QA_pair:
# QA_pairs should be unique keys
Expand Down Expand Up @@ -234,3 +249,43 @@ def chunk_text(text: str, chunk_size: int = 512) -> list[str]:
text = text[best_split:].lstrip()

return chunks


def corefernce_resolution(
    text: str,
    nlp_language_pipeline: "Language",
) -> str:
    """Performs coreference resolution on input text using spaCy's coreferee.

    Every token for which ``doc._.coref_chains.resolve`` returns referents is
    replaced by the text of those referents (several referents are joined
    with ``" and "``); all other tokens are kept as they are. The tokens are
    then re-joined with single spaces, so the original inter-token
    whitespace is not preserved, and the result is cleaned up:

    * runs of newlines become a single period,
    * bracketed numeric references such as ``[1]`` are removed,
    * whitespace in front of any punctuation character is removed,
    * leading and trailing whitespace is stripped.

    Args:
        text (str): Input text to perform coreference resolution on.
        nlp_language_pipeline (spacy.language.Language): Initialized spaCy
            pipeline. It must populate ``doc._.coref_chains`` (i.e. have the
            ``coreferee`` component added).

    Returns:
        str: Text with coreferences resolved and formatting cleaned up.
    """
    # Process the text with the NLP pipeline
    doc = nlp_language_pipeline(text)

    # Resolve coreferences token by token. Pieces are collected in a list and
    # joined once, which avoids repeated string concatenation and the stray
    # leading space that prefixing every token with " " would produce.
    pieces = []
    for token in doc:
        referents = doc._.coref_chains.resolve(token)
        if referents:
            pieces.append(" and ".join(t.text for t in referents))
        else:
            pieces.append(token.text)
    resolved_text = " ".join(pieces)

    # Replace multiple newlines with a period
    resolved_text = re.sub(r'\n+', '.', resolved_text)

    # Remove references like [1], [2], etc.
    resolved_text = re.sub(r'\[\d+\]', ' ', resolved_text)

    # Remove spaces before all punctuation characters (this also undoes the
    # space inserted in front of punctuation tokens by the join above)
    resolved_text = re.sub(r'\s+([^\w\s])', r'\1', resolved_text)

    return resolved_text.strip()
Loading