Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Nltk #5

Draft
wants to merge 20 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
1.0rc6
1.1a1
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def from_file(filename):
'jsonrpcserver>=4.0.1',
'gunicorn>=19.9.0',
'docutils>=0.14',
'nltk>=3.4.1',
'editdistance>=0.5.3',
],
extras_require={
Expand Down
3 changes: 2 additions & 1 deletion src/benchmarkstt/api/templates/api-explorer.html
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,8 @@ <h4>{{ item.name }}</h4>
{% else %}
{% set example = '' %}
{% endif %}
{% if param.name in ('text', 'hyp', 'ref') or '\n' in example %}
{% if example is string and '\n' in example %}

<textarea class="form-input {{ param.name }}"{% if param.is_required %}required{% endif %} rows="6" name="{{ param.name }}" id="{{ item.id }}_{{ param.name }}">{{ example }}</textarea>
{% else %}
<input class="form-input {{ param.name }}"{% if param.is_required %}required{% endif %} type="text" id="{{ item.id }}_{{ param.name }}" name="{{ param.name }}" value="{{ example }}">
Expand Down
63 changes: 37 additions & 26 deletions src/benchmarkstt/benchmark/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from io import StringIO
from benchmarkstt.input.core import PlainText
from benchmarkstt.normalization.core import Config
from benchmarkstt.normalization import NormalizationComposite
from benchmarkstt.normalization.logger import LogCapturer

factory = metrics.factory
Expand All @@ -14,50 +15,60 @@ def callback(cls, ref: str, hyp: str, config: str = None, return_logs: bool = No
:param config: The config to use
:param bool return_logs: Return normalization logs

:example ref: 'Hello darkness my OLD friend'
:example hyp: 'Hello darkness my old foe'
:example ref:

.. code-block:: text

Brave Sir Robin ran away. Bravely ran away away. When danger
reared it’s ugly head, he bravely turned his tail and
fled. Brave Sir Robin turned about and gallantly he chickened out...

:example hyp:

.. code-block:: text

Brave Sir Robin ran away. Bravely ran away away. When danger
reared it’s wicked head, he bravely turned his tail and
fled. Brave Sir Chicken turned about and chickened out... Innit?

:example config:

.. code-block:: text

[normalization]
# using a simple config file
Lowercase

:example result: ""
"""

normalizer = None
normalizer_ref = None
normalizer_hyp = None

if config is not None and len(config.strip()):
normalizer = Config(StringIO(config), section='normalization')
normalizer_ref = NormalizationComposite(title='Reference')
normalizer_ref.add(normalizer)
normalizer_hyp = NormalizationComposite(title='Hypothesis')
normalizer_hyp.add(normalizer)

ref = PlainText(ref, normalizer=normalizer)
hyp = PlainText(hyp, normalizer=normalizer)
ref = PlainText(ref, normalizer=normalizer_ref)
hyp = PlainText(hyp, normalizer=normalizer_hyp)

metric = cls(*args, **kwargs)
cls_name = cls.__name__.lower()

if not return_logs:
result = metric.compare(list(ref), list(hyp))
def get_result():
result = metric.compare(ref, hyp)
if isinstance(result, tuple) and hasattr(result, '_asdict'):
result = result._asdict()
return result

if not return_logs:
return {
cls_name: result
cls_name: get_result()
}

with LogCapturer(dialect='html', diff_formatter_dialect='dict', title='Reference') as logcap:
ref = list(ref)
logs_ref = logcap.logs

with LogCapturer(dialect='html', diff_formatter_dialect='dict', title='Hypothesis') as logcap:
hyp = list(hyp)
logs_hyp = logcap.logs

result = metric.compare(ref, hyp)
if isinstance(result, tuple) and hasattr(result, '_asdict'):
result = result._asdict()

return {
cls_name: result,
"logs": logs_ref + logs_hyp
}
with LogCapturer(dialect='html', diff_formatter_dialect='dict') as logcap:
return {
cls_name: get_result(),
"logs": logcap.logs
}
5 changes: 3 additions & 2 deletions src/benchmarkstt/benchmark/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,6 @@ def argparser(parser: argparse.ArgumentParser):


def main(parser, args):
normalizer = get_normalizer_from_args(args)
metrics_cli.main(parser, args, normalizer)
normalizer_ref = get_normalizer_from_args(args, 'Reference')
normalizer_hyp = get_normalizer_from_args(args, 'Hypothesis')
metrics_cli.main(parser, args, normalizer_ref, normalizer_hyp)
11 changes: 2 additions & 9 deletions src/benchmarkstt/docblock.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import textwrap
import inspect
import re
import ast
from collections import namedtuple
import logging
from docutils.core import publish_string
Expand Down Expand Up @@ -64,7 +63,7 @@ def _(match):
if no_name:
regex = r'^[ \t]*:%s[ \t]*([a-z_]+)?:[ \t]+(.*)$'
else:
regex = r'^[ \t]*:%s[ \t]+(?:([^:]+)[ \t]+)?([a-z_]+):[ \t]*(.+$|(?:$[ \t]*\n)*([ \t]+)([^\n]*)$(?:\4.*\n|\n)+)'
regex = r'^[ \t]*:%s[ \t]+(?:([^:]+)[ \t]+)?([a-z_]+):[ \t]*(.+$|(?:$[ \t]*\n)*([ \t]+)([^\n]*)$(?:\4.*|\n)+)'

docs = re.sub(
regex % (re.escape(key),), _, docstring, flags=re.MULTILINE
Expand All @@ -76,12 +75,7 @@ def _(match):
def decode_literal(txt: str):
if txt is None:
return ''

try:
return ast.literal_eval(txt)
except (ValueError, SyntaxError) as e:
logger.warning('%s "%s" for: %s', type(e), e, txt)
return txt
return txt


def parse(func):
Expand Down Expand Up @@ -123,7 +117,6 @@ def decode_examples(match, param):
idx < defaults_idx,
description,
examples)

params.append(param)

# quick hack to remove this
Expand Down
3 changes: 2 additions & 1 deletion src/benchmarkstt/input/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@


class Base:
def __iter__(self):
def segmented(self, segmenter):
"""
Each input class should be accessible as iterator, each iteration should
return a Item, so the input format is essentially usable and can be easily
converted to a :py:class:`benchmarkstt.schema.Schema`
"""

raise NotImplementedError()


Expand Down
33 changes: 21 additions & 12 deletions src/benchmarkstt/input/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

"""

import benchmarkstt.segmentation.core as segmenters
from benchmarkstt import input, settings
# from benchmarkstt.modules import LoadObjectProxy

Expand All @@ -12,15 +11,15 @@ class PlainText(input.Base):
"""
plain text
"""
def __init__(self, text, segmenter=None, normalizer=None):
if segmenter is None:
segmenter = segmenters.Simple
def __init__(self, text, normalizer=None):
self._text = text
self._segmenter = segmenter
self._normalizer = normalizer

def __iter__(self):
return iter(self._segmenter(self._text, normalizer=self._normalizer))
def segmented(self, segmenter):
return iter(segmenter(self._text, normalizer=self._normalizer))

def __str__(self):
return self._text


class File(input.Base):
Expand Down Expand Up @@ -62,13 +61,23 @@ def __init__(self, file, input_type=None, normalizer=None):
input_type = input.factory[input_type]

self._input_class = input_type
self._text = None

def __iter__(self):
encoding = settings.default_encoding
with open(self._file, encoding=encoding) as f:
text = f.read()
@property
def text(self):
if self._text is None:
encoding = settings.default_encoding
with open(self._file, encoding=encoding) as f:
self._text = f.read()

return self._text

def segmented(self, segmenter):
return self._input_class(self.text, normalizer=self._normalizer).segmented(segmenter)

def __str__(self):
return self.text

return iter(self._input_class(text, normalizer=self._normalizer))

# For future versions
# class ExternalInput(LoadObjectProxy, input.Base):
Expand Down
4 changes: 2 additions & 2 deletions src/benchmarkstt/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from benchmarkstt.schema import Schema
from benchmarkstt.factory import Factory
from benchmarkstt.input import Base as InputBase


class Base:
"""
Base class for metrics
"""
def compare(self, ref: Schema, hyp: Schema):
def compare(self, ref: InputBase, hyp: InputBase):
raise NotImplementedError()


Expand Down
18 changes: 18 additions & 0 deletions src/benchmarkstt/metrics/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,26 @@

def callback(cls, ref: str, hyp: str, *args, **kwargs):
"""
:example ref:

.. code-block:: text

Brave Sir Robin ran away. Bravely ran away away. When danger
reared it’s ugly head, he bravely turned his tail and
fled.
Brave Sir Robin turned about and gallantly he chickened out...

:example hyp:

.. code-block:: text

Brave Sir Robin ran away. Bravely ran away away. When danger
reared it’s wicked head, he bravely turned his tail and
fled. Brave Sir Chicken turned about and chickened out... Innit?

:param ref: Reference text
:param hyp: Hypothesis text

"""
ref = PlainText(ref)
hyp = PlainText(hyp)
Expand Down
28 changes: 13 additions & 15 deletions src/benchmarkstt/metrics/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,10 @@
from benchmarkstt.output import factory as output_factory
from benchmarkstt.metrics import factory
from benchmarkstt.cli import args_from_factory
from benchmarkstt.normalization.logger import Logger
import argparse
from inspect import signature, Parameter
import logging
from collections import OrderedDict
from functools import partial


def argparser(parser: argparse.ArgumentParser):
Expand Down Expand Up @@ -47,20 +46,18 @@ def argparser(parser: argparse.ArgumentParser):
return parser


def file_to_iterable(file, type_, normalizer=None):
if type_ == 'argument':
return core.PlainText(file, normalizer=normalizer)
return core.File(file, type_, normalizer=normalizer)
def main(parser, args, normalizer_ref=None, normalizer_hyp=None):
def file_to_inputclass(name, normalizer):
arg = partial(getattr, args)
file = arg(name)
type_ = arg('%s_type' % name)

if type_ == 'argument':
return core.PlainText(file, normalizer=normalizer)
return core.File(file, type_, normalizer=normalizer)

def main(parser, args, normalizer=None):
logging.getLogger()
prev_title = Logger.title
Logger.title = 'Reference'
ref = list(file_to_iterable(args.reference, args.reference_type, normalizer=normalizer))
Logger.title = 'Hypothesis'
hyp = list(file_to_iterable(args.hypothesis, args.hypothesis_type, normalizer=normalizer))
Logger.title = prev_title
ref = file_to_inputclass('reference', normalizer_ref)
hyp = file_to_inputclass('hypothesis', normalizer_hyp)

if 'metrics' not in args or not len(args.metrics):
parser.error("need at least one metric")
Expand Down Expand Up @@ -89,4 +86,5 @@ def main(parser, args, normalizer=None):

metric = cls(*item, **kwargs)
result = metric.compare(ref, hyp)
out.result(metric_name, result)
out.title(metric_name)
out.result(result)
Loading