Enhancement/partial match #82

Open · wants to merge 10 commits into base: enhancement/partialMatch
1 change: 1 addition & 0 deletions .github/FUNDING.yml
@@ -0,0 +1 @@
github: Hironsan
26 changes: 26 additions & 0 deletions .github/workflows/pip.yml
@@ -0,0 +1,26 @@
name: test package installation

on:
  schedule:
    - cron: "0 0 * * *"

jobs:
  build:
    if: contains(github.event.head_commit.message, '[skip ci]') == false
    runs-on: ${{ matrix.os }}
    strategy:
      matrix:
        python-version: [3.6, 3.7, 3.8]
        os: [ubuntu-latest, macos-latest]

    steps:
      - uses: actions/checkout@v2
      - name: Set up Python
        uses: actions/setup-python@v2
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        run: |
          pip install --upgrade pip
          pip install -U setuptools
      - run: pip install seqeval
1 change: 1 addition & 0 deletions Pipfile
@@ -9,6 +9,7 @@ autopep8 = "*"
flake8 = "*"
pytest-cov = "*"
isort = "*"
atomicwrites = "*"

[packages]
numpy = "*"
352 changes: 201 additions & 151 deletions Pipfile.lock

Large diffs are not rendered by default.

15 changes: 15 additions & 0 deletions README.md
@@ -70,6 +70,21 @@ In strict mode, the inputs are evaluated according to the specified schema. The
weighted avg       0.50      0.50      0.50         2
```

With partial match, the inputs are evaluated by the number of matching tags rather than by exact entity spans. It is not compatible with strict mode.
```python
print(classification_report(y_true, y_pred, partial_match=True))

              precision    recall  f1-score   support

        MISC       0.75      1.00      0.86         3
         PER       1.00      1.00      1.00         2

   micro avg       0.83      1.00      0.91         5
   macro avg       0.88      1.00      0.93         5
weighted avg       0.85      1.00      0.91         5
```
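These numbers come from tag-level counts. A minimal, self-contained sketch, assuming the same `y_true`/`y_pred` defined earlier in this README:

```python
from seqeval.metrics import classification_report

# Assumed inputs: the y_true/y_pred used in the examples above.
y_true = [['O', 'O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'O'], ['B-PER', 'I-PER', 'O']]
y_pred = [['O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'I-MISC', 'O'], ['B-PER', 'I-PER', 'O']]

# MISC: 4 predicted tags overlap the 3 true tags at 3 positions,
# so precision = 3/4 = 0.75 and recall = 3/3 = 1.00.
print(classification_report(y_true, y_pred, partial_match=True))
```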


A minimal case to explain the differences between the default and strict modes:

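The example block itself is collapsed in this diff view. A minimal sketch of such a case, assuming IOB2-style inputs:

```python
from seqeval.metrics import f1_score
from seqeval.scheme import IOB2

y_true = [['B-NP', 'I-NP', 'O']]
y_pred = [['I-NP', 'I-NP', 'O']]  # chunk starts with I-NP instead of B-NP

# Default mode accepts the I-NP run as a chunk -> perfect match.
print(f1_score(y_true, y_pred))                              # 1.0

# Strict IOB2 requires chunks to start with B-, so the prediction
# contains no valid chunk at all.
print(f1_score(y_true, y_pred, mode='strict', scheme=IOB2))  # 0.0
```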
72 changes: 58 additions & 14 deletions seqeval/metrics/sequence_labeling.py
@@ -27,7 +27,8 @@ def precision_recall_fscore_support(y_true: List[List[str]],
                                    beta: float = 1.0,
                                    sample_weight: Optional[List[int]] = None,
                                    zero_division: str = 'warn',
                                    suffix: bool = False) -> SCORES:
                                    suffix: bool = False,
                                    partial_match: bool = False) -> SCORES:
"""Compute precision, recall, F-measure and support for each class.

Args:
Expand Down Expand Up @@ -70,6 +71,8 @@ def precision_recall_fscore_support(y_true: List[List[str]],

suffix : bool, False by default.

partial_match : bool, False by default.

Returns:
precision : float (if average is not None) or array of float, shape = [n_unique_labels]

@@ -121,9 +124,32 @@ def extract_tp_actual_correct(y_true, y_pred, suffix, *args):
        for type_name in target_names:
            entities_true_type = entities_true.get(type_name, set())
            entities_pred_type = entities_pred.get(type_name, set())
            tp_sum = np.append(tp_sum, len(entities_true_type & entities_pred_type))
            pred_sum = np.append(pred_sum, len(entities_pred_type))
            true_sum = np.append(true_sum, len(entities_true_type))
            if partial_match:
                # Project the true and predicted entities of this type onto
                # boolean tag vectors, then score by position-wise overlap.
                n_sublist = len(y_true)
                vector_size = 0
                if entities_true_type:
                    # Entities of one type do not overlap, so the
                    # lexicographically largest one also ends last.
                    vector_size = max(entities_true_type)[1]
                if entities_pred_type:
                    vector_size = max(max(entities_pred_type)[1], vector_size)

                # Pad so the inclusive end offsets always fit.
                vector_size += n_sublist
                entities_true_vector = np.zeros(vector_size, dtype=bool)
                # fill true values
                for start, end in entities_true_type:
                    entities_true_vector[start:end + 1] = True
                # fill predicted values
                entities_pred_vector = np.zeros(vector_size, dtype=bool)
                for start, end in entities_pred_type:
                    entities_pred_vector[start:end + 1] = True

                # True positives: positions tagged with this type in both.
                tp_sum = np.append(tp_sum, (entities_true_vector & entities_pred_vector).sum())
                pred_sum = np.append(pred_sum, entities_pred_vector.sum())
                true_sum = np.append(true_sum, entities_true_vector.sum())

            else:
                tp_sum = np.append(tp_sum, len(entities_true_type & entities_pred_type))
                pred_sum = np.append(pred_sum, len(entities_pred_type))
                true_sum = np.append(true_sum, len(entities_true_type))

        return pred_sum, tp_sum, true_sum
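For illustration (not part of the diff), the overlap counting above reduces to the following standalone sketch with hypothetical offsets:

```python
import numpy as np

# Hypothetical (start, end) token offsets, inclusive, for one entity type.
entities_true_type = {(3, 5)}  # true entity covers tokens 3..5
entities_pred_type = {(2, 5)}  # prediction starts one token early

size = 8  # anything past the largest end offset works
true_vec = np.zeros(size, dtype=bool)
pred_vec = np.zeros(size, dtype=bool)
for start, end in entities_true_type:
    true_vec[start:end + 1] = True
for start, end in entities_pred_type:
    pred_vec[start:end + 1] = True

tp = (true_vec & pred_vec).sum()
print(tp, pred_vec.sum(), true_vec.sum())  # 3 4 3 -> precision 0.75, recall 1.00
```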

@@ -281,7 +307,8 @@ def f1_score(y_true: List[List[str]], y_pred: List[List[str]],
             mode: Optional[str] = None,
             sample_weight: Optional[List[int]] = None,
             zero_division: str = 'warn',
             scheme: Optional[Type[Token]] = None):
             scheme: Optional[Type[Token]] = None,
             partial_match: bool = False):
"""Compute the F1 score.

The F1 score can be interpreted as a weighted average of the precision and
Expand Down Expand Up @@ -330,6 +357,8 @@ def f1_score(y_true: List[List[str]], y_pred: List[List[str]],

suffix : bool, False by default.

partial_match : bool, False by default.

Returns:
score : float or array of float, shape = [n_unique_labels].

@@ -354,15 +383,17 @@ def f1_score(y_true: List[List[str]], y_pred: List[List[str]],
                                                     sample_weight=sample_weight,
                                                     zero_division=zero_division,
                                                     scheme=scheme,
                                                     suffix=suffix)
                                                     suffix=suffix
                                                     )
    else:
        _, _, f, _ = precision_recall_fscore_support(y_true, y_pred,
                                                     average=average,
                                                     warn_for=('f-score',),
                                                     beta=1,
                                                     sample_weight=sample_weight,
                                                     zero_division=zero_division,
                                                     suffix=suffix)
                                                     suffix=suffix,
                                                     partial_match=partial_match)
    return f


@@ -406,7 +437,8 @@ def precision_score(y_true: List[List[str]], y_pred: List[List[str]],
                    mode: Optional[str] = None,
                    sample_weight: Optional[List[int]] = None,
                    zero_division: str = 'warn',
                    scheme: Optional[Type[Token]] = None):
                    scheme: Optional[Type[Token]] = None,
                    partial_match: bool = False):
"""Compute the precision.

The precision is the ratio ``tp / (tp + fp)`` where ``tp`` is the number of
Expand Down Expand Up @@ -454,6 +486,8 @@ def precision_score(y_true: List[List[str]], y_pred: List[List[str]],

suffix : bool, False by default.

partial_match : bool, False by default.

Returns:
score : float or array of float, shape = [n_unique_labels].

@@ -484,7 +518,8 @@ def precision_score(y_true: List[List[str]], y_pred: List[List[str]],
                                                     warn_for=('precision',),
                                                     sample_weight=sample_weight,
                                                     zero_division=zero_division,
                                                     suffix=suffix)
                                                     suffix=suffix,
                                                     partial_match=partial_match)
    return p


@@ -495,7 +530,8 @@ def recall_score(y_true: List[List[str]], y_pred: List[List[str]],
                 mode: Optional[str] = None,
                 sample_weight: Optional[List[int]] = None,
                 zero_division: str = 'warn',
                 scheme: Optional[Type[Token]] = None):
                 scheme: Optional[Type[Token]] = None,
                 partial_match: bool = False):
"""Compute the recall.

The recall is the ratio ``tp / (tp + fn)`` where ``tp`` is the number of
Expand Down Expand Up @@ -543,6 +579,8 @@ def recall_score(y_true: List[List[str]], y_pred: List[List[str]],

suffix : bool, False by default.

partial_match : bool, False by default.

Returns:
score : float.

@@ -573,7 +611,8 @@ def recall_score(y_true: List[List[str]], y_pred: List[List[str]],
                                                     warn_for=('recall',),
                                                     sample_weight=sample_weight,
                                                     zero_division=zero_division,
                                                     suffix=suffix)
                                                     suffix=suffix,
                                                     partial_match=partial_match)
    return r


@@ -617,7 +656,8 @@ def classification_report(y_true, y_pred,
                          mode=None,
                          sample_weight=None,
                          zero_division='warn',
                          scheme=None):
                          scheme=None,
                          partial_match: bool = False):
"""Build a text report showing the main classification metrics.

Args:
Expand Down Expand Up @@ -648,6 +688,8 @@ def classification_report(y_true, y_pred,

suffix : bool, False by default.

partial_match : bool, False by default.

Returns:
report : string/dict. Summary of the precision, recall, F1 score for each class.

@@ -694,7 +736,8 @@ def classification_report(y_true, y_pred,
            average=None,
            sample_weight=sample_weight,
            zero_division=zero_division,
            suffix=suffix
            suffix=suffix,
            partial_match=partial_match,
        )
        for row in zip(target_names, p, r, f1, s):
            reporter.write(*row)
@@ -708,7 +751,8 @@ def classification_report(y_true, y_pred,
            average=average,
            sample_weight=sample_weight,
            zero_division=zero_division,
            suffix=suffix
            suffix=suffix,
            partial_match=partial_match
        )
        reporter.write('{} avg'.format(average), avg_p, avg_r, avg_f1, support)
        reporter.write_blank()
10 changes: 3 additions & 7 deletions setup.py
@@ -22,7 +22,7 @@
    os.system('python setup.py sdist bdist_wheel upload')
    sys.exit()

required = ['numpy==1.19.2', 'scikit-learn==0.23.2']
required = ['numpy>=1.14.0', 'scikit-learn>=0.21.3']

setup(
    name=NAME,
@@ -45,13 +45,9 @@
    classifiers=[
        'License :: OSI Approved :: MIT License',
        'Programming Language :: Python',
        'Programming Language :: Python :: 2.6',
        'Programming Language :: Python :: 2.7',
        'Programming Language :: Python :: 3',
        'Programming Language :: Python :: 3.3',
        'Programming Language :: Python :: 3.4',
        'Programming Language :: Python :: 3.5',
        'Programming Language :: Python :: 3.6',
        'Programming Language :: Python :: 3.7',
        'Programming Language :: Python :: 3.8',
        'Programming Language :: Python :: Implementation :: CPython',
        'Programming Language :: Python :: Implementation :: PyPy'
    ],
6 changes: 6 additions & 0 deletions tests/test_metrics.py
@@ -123,9 +123,15 @@ def test_performance_measure(self):
    def test_classification_report(self):
        print(classification_report(self.y_true, self.y_pred))

    def test_classification_report_partial_match(self):
        print(classification_report(self.y_true, self.y_pred, partial_match=True))

    def test_inv_classification_report(self):
        print(classification_report(self.y_true_inv, self.y_pred_inv, suffix=True))

    def test_inv_classification_report_partial_match(self):
        print(classification_report(self.y_true_inv, self.y_pred_inv, suffix=True, partial_match=True))

    def test_by_ground_truth(self):
        with open(self.file_name) as f:
            output = subprocess.check_output(['perl', 'conlleval.pl'], stdin=f).decode('utf-8')