diff --git a/.gitignore b/.gitignore index 66637be2..3b4f2025 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,7 @@ # sphinx documentation docs/build/ +docs/modules/*.rst +docs/api-methods.rst # python caches and build __pycache__ diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index ac8758a2..00000000 --- a/.travis.yml +++ /dev/null @@ -1,20 +0,0 @@ -language: python -os: - - windows - - linux - - osx -python: - - "3.5" - - "3.6" -install: - - pip install '.[api,test]' -script: pytest -branches: - only: - - master -matrix: - include: - - { python: "3.7", dist: "xenial", os: "linux" } - allow_failures: - - os: windows - - os: osx diff --git a/Dockerfile b/Dockerfile index 646cd8ba..b439c64f 100644 --- a/Dockerfile +++ b/Dockerfile @@ -8,7 +8,7 @@ RUN pip install --upgrade pip WORKDIR /home/benchmarkstt COPY . /home/benchmarkstt/ -RUN pip install '.[api]' +RUN pip install '.[test]' RUN chown -R benchmarkstt:benchmarkstt ./ USER benchmarkstt diff --git a/Makefile b/Makefile index 1d1ea652..7fd9c7bd 100644 --- a/Makefile +++ b/Makefile @@ -1,7 +1,10 @@ .PHONY: docs test clean gh-pages test: - pytest --verbose + pytest src --doctest-modules --verbose + pytest tests --verbose + pycodestyle tests + pycodestyle src docs: html @@ -9,7 +12,7 @@ html: apidocs echo ".. Note, this was autogenerated, all changes will vanish...\n" > docs/api-methods.rst echo "Available JSON-RPC methods\n==========================\n\n" >> docs/api-methods.rst benchmarkstt api --list-methods >> docs/api-methods.rst - cd docs/ && make clean html && touch build/html/.nojekyll + cd docs/ && make clean html man: build-man cp resources/manpage/*.1 /usr/local/share/man/man1 @@ -34,4 +37,5 @@ gh-pages: # docs git add -A && git commit -a -m 'update docs' && git push --set-upstream origin gh-pages apidocs: + ls docs/modules/|grep '.rst$$' && rm docs/modules/*.rst || echo "no .rst files to clean" sphinx-apidoc -f -e -o docs/modules/ src/benchmarkstt/ && rm docs/modules/modules.rst diff --git a/README.md b/README.md deleted file mode 100644 index 442919b6..00000000 --- a/README.md +++ /dev/null @@ -1,73 +0,0 @@ - -# AI Benchmarking STT - -[![Azure Build](https://img.shields.io/azure-devops/build/danielthepope/ai-benchmarking-stt/3.svg?logo=azure-devops)](https://dev.azure.com/danielthepope/ai-benchmarking-stt/_build/latest?definitionId=3&branchName=master) -[![Azure Tests](https://img.shields.io/azure-devops/tests/danielthepope/ai-benchmarking-stt/3.svg?logo=azure-devops)](https://dev.azure.com/danielthepope/ai-benchmarking-stt/_build/latest?definitionId=3&branchName=master) -[![License: MIT](https://img.shields.io/github/license/ebu/ai-benchmarking-stt.svg)](https://opensource.org/licenses/MIT) - - - - - -This is a collaborative project to create a library for benchmarking AI/ML applications. It evolved out of conversations among broadcasters and providers of Access Services to media organisations, but anyone is welcome to contribute. The group behind this project is the EBU's [Media Information Management & AI group](https://tech.ebu.ch/groups/mim). Currently the group is focussing on Speech-to-Text, but it will consider creating benchmarking tools for other AI/ML services. - -For general information about this project, including the [motivations](https://github.com/ebu/ai-benchmarking-stt/wiki) and [guiding principles](https://github.com/ebu/ai-benchmarking-stt/wiki/Principles), please see the project [wiki](https://github.com/ebu/ai-benchmarking-stt/wiki) and [documentation](https://benchmarkstt.mikesmith.eu). - - - - - - - - diff --git a/README.rst b/README.rst new file mode 100644 index 00000000..a3557490 --- /dev/null +++ b/README.rst @@ -0,0 +1,20 @@ +AI Benchmarking STT +=================== + +.. image:: https://img.shields.io/github/license/ebu/ai-benchmarking-stt.svg + :target: https://github.com/ebu/ai-benchmarking-stt/blob/master/LICENCE.md + +.. image:: https://img.shields.io/azure-devops/build/danielthepope/ai-benchmarking-stt/3.svg?logo=azure-devops + :target: https://dev.azure.com/danielthepope/ai-benchmarking-stt/_build/latest?definitionId=3&branchName=master + +.. image:: https://img.shields.io/azure-devops/tests/danielthepope/ai-benchmarking-stt/3.svg?logo=azure-devops + :target: https://dev.azure.com/danielthepope/ai-benchmarking-stt/_build/latest?definitionId=3&branchName=master + +.. image:: https://img.shields.io/azure-devops/coverage/danielthepope/ai-benchmarking-stt/3.svg?logo=azure-devops + :target: https://dev.azure.com/danielthepope/ai-benchmarking-stt/_build + +This is a collaborative project to create a library for benchmarking AI/ML applications. It evolved out of conversations among broadcasters and providers of Access Services to media organisations, but anyone is welcome to contribute. The group behind this project is the EBU's `Media Information Management & AI group `_. Currently the group is focussing on Speech-to-Text, but it will consider creating benchmarking tools for other AI/ML services. + +For general information about this project, including the `motivations `_ and `guiding principles `_, please see the project `wiki `_ and `documentation `_. + + diff --git a/VERSION b/VERSION index 8a9ecc2e..7bcd0e36 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -0.0.1 \ No newline at end of file +0.0.2 \ No newline at end of file diff --git a/azure-pipelines.yml b/azure-pipelines.yml index faa60d86..11466295 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -36,20 +36,42 @@ jobs: versionSpec: '$(python.version)' architecture: 'x64' - - script: python -m pip install --upgrade pip && pip install ".[api,test]" + - script: python -m pip install --upgrade pip && pip install ".[test]" displayName: 'Install dependencies' - script: | - pip install pytest - pytest tests --doctest-modules --junitxml=junit/test-results.xml + pytest tests --junitxml=junit/test-results-unit.xml -vv displayName: 'pytest' - task: PublishTestResults@2 inputs: - testResultsFiles: '**/test-results.xml' + testResultsFiles: '**/test-results*.xml' testRunTitle: 'Python $(python.version)' condition: succeededOrFailed() + +- job: 'CodeCoverage' + pool: + vmImage: 'Ubuntu-16.04' + + steps: + - task: UsePythonVersion@0 + inputs: + versionSpec: '3.7' + architecture: 'x64' + + - script: | + python -m pip install --upgrade pip && pip install -e ".[test]" + pytest tests --cov=src --cov-report xml + displayName: 'Check code coverage' + + - task: PublishCodeCoverageResults@1 + displayName: 'Publish code coverage from $(System.DefaultWorkingDirectory)/**/coverage.xml' + inputs: + codeCoverageTool: Cobertura + summaryFileLocation: '$(System.DefaultWorkingDirectory)/**/coverage.xml' + + - job: 'Lint' pool: vmImage: 'Ubuntu-16.04' @@ -63,5 +85,7 @@ jobs: - script: python -m pip install --upgrade pip && pip install pycodestyle displayName: 'Install pycodestyle' - - script: pycodestyle + - script: | + pycodestyle tests + pycodestyle src displayName: 'pycodestyle' diff --git a/docs/INSTALL.rst b/docs/INSTALL.rst index 569141fb..5d0c591e 100644 --- a/docs/INSTALL.rst +++ b/docs/INSTALL.rst @@ -23,29 +23,13 @@ From source - `Installing Python 3 on Windows `_ - `Installing Python 3 on Linux `_ -2. Get the benchmarkstt source code from github (assumes :code:`git` is installed on the system) - - .. git clone https://github.com/ebu/ai-benchmarking-stt.git - - .. code-block:: bash - - git clone https://github.com/MikeSmithEU/ai-benchmarking-stt.git - -3. Install the package using :code:`pip`, this will also install all requirements - - .. code-block:: bash - - cd ai-benchmarking-stt - pip install '.[api]' - - Once this is done you may remove the git repository (optional). +2. Install the package using :code:`pip`, this will also install all requirements .. code-block:: bash - cd .. - rm -fr ai-benchmarking-stt + pip install https://github.com/ebu/ai-benchmarking-stt/archive/master.zip -4. Test and use +3. Test and use BenchmarkSTT should now be installed and usable. diff --git a/docs/LICENSE.rst b/docs/LICENSE.rst index 1c9a45a6..debe6c52 100644 --- a/docs/LICENSE.rst +++ b/docs/LICENSE.rst @@ -1,10 +1,5 @@ License ------- -Copyright 2019 EBU +.. include:: ../LICENCE.md -Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/docs/api-methods.rst b/docs/api-methods.rst deleted file mode 100644 index df3ec7ce..00000000 --- a/docs/api-methods.rst +++ /dev/null @@ -1,251 +0,0 @@ -.. Note, this was autogenerated, all changes will vanish... - -Available JSON-RPC methods -========================== - - -version -------- - -Get the version of benchmarkstt - -:return str: BenchmarkSTT version - -list.normalizers ----------------- - -Get a list of available core normalizers - -:return object: With key being the normalizer name, and value its description - -normalization.alphanumeric --------------------------- - -Simple alphanumeric filter - -:example text: "He's a lumberjack, and he's okay!" -:example return: "Hesalumberjackandhesokay" - -:param str text: The text to normalize -:param bool return_logs: Return normalizer logs - -normalization.alphanumericunicode ---------------------------------- - -Simple alphanumeric filter, takes into account all unicode alphanumeric -characters. - -:example text: "Das, öder die Flipper-Wåld Gespütt!" -:example return: "DasöderdieFlipperWåldGespütt" - -:param str text: The text to normalize -:param bool return_logs: Return normalizer logs - -normalization.config --------------------- - -Use config notation to define normalization rules. This notation is a -list of normalizers, one per line, with optional arguments (separated by a -space). - -The normalizers can be any of the core normalizers, or you can refer to your -own normalizer class (like you would use in a python import, eg. -`my.own.package.MyNormalizerClass`). - -Additional rules: - - Normalizer names are case-insensitive. - - Arguments MAY be wrapped in double quotes. - - If an argument contains a space, newline or double quote, it MUST be - wrapped in double quotes. - - A double quote itself is represented in this quoted argument as two - double quotes: `""`. - -The normalization rules are applied top-to-bottom and follow this format: - -.. code-block:: text - - Normalizer1 arg1 "arg 2" - # This is a comment - - Normalizer2 - # (Normalizer2 has no arguments) - Normalizer3 "This is argument 1 - Spanning multiple lines - " "argument 2" - Normalizer4 "argument with double quote ("")" - -:param str config: configuration text - -:example text: "He bravely turned his tail and fled" -:example config: '''# using a simple config file\nLowercase \n -# it even supports comments -# If there is a space in the argument, make sure you quote it though! -regexreplace "y t" "Y T" -\n\n -# extraneous whitespaces are ignored -replace e a\n''' -:example return: "ha bravalY Turnad his tail and flad" - -:param str text: The text to normalize -:param bool return_logs: Return normalizer logs - -normalization.configfile ------------------------- - -Load config from a file, see :py:class:`Config` for information about config -notation - -:param typing.io.TextIO file: The file -:param str encoding: The file encoding - -:example text: "He bravely turned his tail and fled" -:example file: "./resources/test/normalizers/configfile.conf" -:example encoding: "UTF-8" -:example return: "ha bravalY Turnad his tail and flad" - -:param str text: The text to normalize -:param bool return_logs: Return normalizer logs - -normalization.file ------------------- - -Read one per line and pass it to the given normalizer - -:param str|class normalizer: Normalizer name (or class) -:param str file: The file to read rules from -:param str encoding: The file encoding - -:example text: "This is an Ex-Parakeet" -:example normalizer: "regexreplace" -:example file: "./resources/test/normalizers/regexreplace/en_US" -:example encoding: "UTF-8" -:example return: "This is an Ex Parrot" - -:param str text: The text to normalize -:param bool return_logs: Return normalizer logs - -normalization.localizedfile ---------------------------- - -Reads and applies normalization rules from a locale-based file, it will -automatically determine the "best fit" for a given locale, if one is -available. - -:param str|class normalizer: Normalizer name (or class) -:param str locale: Which locale to search for -:param PathLike path: Location of available locale files -:param str encoding: The file encoding - -:example text: "This is an Ex-Parakeet" -:example normalizer: "regexreplace" -:example path: "./resources/test/normalizers/regexreplace" -:example locale: "en" -:example encoding: "UTF-8" -:example return: "This is an Ex Parrot" - -:param str text: The text to normalize -:param bool return_logs: Return normalizer logs - -normalization.lowercase ------------------------ - -Lowercase the text - - -:example text: "Easy, Mungo, easy... Mungo..." -:example return: "easy, mungo, easy... mungo..." - -:param str text: The text to normalize -:param bool return_logs: Return normalizer logs - -normalization.regexreplace --------------------------- - -Simple regex replace. By default the pattern is interpreted -case-sensitive. - -Case-insensitivity is supported by adding inline modifiers. - -You might want to use capturing groups to preserve the case. When replacing -a character not captured, the information about its case is lost... - -Eg. would replace "HAHA! Hahaha!" to "HeHe! Hehehe!": - - +------------------+-------------+ - | search | replace | - +==================+=============+ - | :code:`(?i)(h)a` | :code:`\1e` | - +------------------+-------------+ - - -No regex flags are set by default, you can set them yourself though in the -regex, and combine them at will, eg. multiline, dotall and ignorecase. - -Eg. would replace "Newline" to "newline": - - +------------------------+------------------+ - | search | replace | - +========================+==================+ - | :code:`(?msi)new.line` | :code:`newline` | - +------------------------+------------------+ - -:example text: "HAHA! Hahaha!" -:example search: '(?i)(h)a' -:example replace: r'\1e' -:example return: "HeHe! Hehehe!" - -:param str text: The text to normalize -:param bool return_logs: Return normalizer logs - -normalization.replace ---------------------- - -Simple search replace - -:param str search: Text to search for -:param str replace: Text to replace with - -:example text: "Nudge nudge!" -:example search: "nudge" -:example replace: "wink" -:example return: "Nudge wink!" - -:param str text: The text to normalize -:param bool return_logs: Return normalizer logs - -normalization.replacewords --------------------------- - -Simple search replace that only replaces "words", the first letter will be -checked case insensitive as well with preservation of case.. - -:param str search: Word to search for -:param str replace: Replace with - -:example text: "She has a heart of formica" -:example search: "a" -:example replace: "the" -:example return: "She has the heart of formica" - -:param str text: The text to normalize -:param bool return_logs: Return normalizer logs - -normalization.unidecode ------------------------ - -Unidecode characters to ASCII form, see `Python's Unidecode package -`_ for more info. - -:example text: "𝖂𝖊𝖓𝖓 𝖎𝖘𝖙 𝖉𝖆𝖘 𝕹𝖚𝖓𝖘𝖙ü𝖈𝖐 𝖌𝖎𝖙 𝖚𝖓𝖉 𝕾𝖑𝖔𝖙𝖊𝖗𝖒𝖊𝖞𝖊𝖗?" -:example return: "Wenn ist das Nunstuck git und Slotermeyer?" - -:param str text: The text to normalize -:param bool return_logs: Return normalizer logs - -help ----- - -Returns available api methods - -:return object: With key being the method name, and value its description - diff --git a/docs/api.rst b/docs/api.rst index b6fde458..8a7ef8e3 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -3,6 +3,9 @@ API BenchmarkSTT exposes its functionality through a JSON-RPC_ api. +.. attention:: + Only supported for Python versions 3.6 and above! + Starting the server ------------------- @@ -34,6 +37,10 @@ Using curl, for example: If you started the service with parameter `--with-explorer` (see :doc:`cli/api`), you can easily test the available JSON-RPC_ api calls by visiting the api url (eg. `http://localhost:8080/api` in the above example). +.. important:: + The API explorer is provided as-is, without any tests or code reviews. This + is marked as a low-priority feature. + .. toctree:: :maxdepth: 2 diff --git a/docs/cli.rst b/docs/cli.rst index e061cb6b..a793ab0f 100644 --- a/docs/cli.rst +++ b/docs/cli.rst @@ -15,6 +15,7 @@ Subcommands cli/normalization cli/api + cli/metrics Bash completion --------------- @@ -23,4 +24,4 @@ Bash completion is supported through ``argcomplete``. .. toctree:: - bash-completion \ No newline at end of file + bash-completion diff --git a/docs/cli/metrics.rst b/docs/cli/metrics.rst new file mode 100644 index 00000000..2101c95b --- /dev/null +++ b/docs/cli/metrics.rst @@ -0,0 +1,9 @@ +Subcommand metrics +================== + +.. argparse:: + :module: benchmarkstt.cli + :func: _parser + :prog: benchmarkstt + :path: metrics + diff --git a/docs/development.rst b/docs/development.rst index 181da28e..0f8d73fb 100644 --- a/docs/development.rst +++ b/docs/development.rst @@ -9,7 +9,7 @@ This assumes :code:`git` and :code:`Python` 3.5 or above are already installed o 1. Fork the `repository source code `_ from github to your own account. 2. Clone the repository from github to your local development environment (replace :code:`[YOURUSERNAME]` with your - github username. + github username). .. code-block:: bash diff --git a/docs/docker.rst b/docs/docker.rst index 5eef0b28..e9530f55 100644 --- a/docs/docker.rst +++ b/docs/docker.rst @@ -4,7 +4,13 @@ Running as a docker image Build the image --------------- -Inside the benchmarkstt folder (see :doc:`INSTALL`) run: +This assumes docker is already installed on your system. + +1. Download the code from github at https://github.com/ebu/ai-benchmarking-stt/archive/master.zip + +2. Unzip the file + +3. Inside the benchmarkstt folder run: .. code-block:: bash diff --git a/docs/index.rst b/docs/index.rst index 0475c00b..c27fe79c 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -1,15 +1,5 @@ -Welcome to BenchmarkSTT's documentation! -======================================== +.. include:: ../README.rst -.. image:: https://img.shields.io/azure-devops/build/danielthepope/ai-benchmarking-stt/3.svg?logo=azure-devops - :target: https://dev.azure.com/danielthepope/ai-benchmarking-stt/_build/latest?definitionId=3&branchName=master - -.. image:: https://img.shields.io/github/license/ebu/ai-benchmarking-stt.svg - :target: https://opensource.org/licenses/MIT - - - -"BenchmarkSTT: had to call it something..." .. toctree:: :maxdepth: 3 diff --git a/docs/modules/.placeholder b/docs/modules/.placeholder new file mode 100644 index 00000000..2789987f --- /dev/null +++ b/docs/modules/.placeholder @@ -0,0 +1 @@ +placeholder file to make sure directory exists in git diff --git a/docs/modules/benchmarkstt.api.cli.rst b/docs/modules/benchmarkstt.api.cli.rst deleted file mode 100644 index 17bda2f3..00000000 --- a/docs/modules/benchmarkstt.api.cli.rst +++ /dev/null @@ -1,7 +0,0 @@ -benchmarkstt.api.cli module -=========================== - -.. automodule:: benchmarkstt.api.cli - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/modules/benchmarkstt.api.gunicorn.rst b/docs/modules/benchmarkstt.api.gunicorn.rst deleted file mode 100644 index a5240cac..00000000 --- a/docs/modules/benchmarkstt.api.gunicorn.rst +++ /dev/null @@ -1,7 +0,0 @@ -benchmarkstt.api.gunicorn module -================================ - -.. automodule:: benchmarkstt.api.gunicorn - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/modules/benchmarkstt.api.jsonrpc.rst b/docs/modules/benchmarkstt.api.jsonrpc.rst deleted file mode 100644 index b5c0ce9b..00000000 --- a/docs/modules/benchmarkstt.api.jsonrpc.rst +++ /dev/null @@ -1,7 +0,0 @@ -benchmarkstt.api.jsonrpc module -=============================== - -.. automodule:: benchmarkstt.api.jsonrpc - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/modules/benchmarkstt.api.rst b/docs/modules/benchmarkstt.api.rst deleted file mode 100644 index d63fe0b5..00000000 --- a/docs/modules/benchmarkstt.api.rst +++ /dev/null @@ -1,19 +0,0 @@ -benchmarkstt.api package -======================== - -Submodules ----------- - -.. toctree:: - - benchmarkstt.api.cli - benchmarkstt.api.gunicorn - benchmarkstt.api.jsonrpc - -Module contents ---------------- - -.. automodule:: benchmarkstt.api - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/modules/benchmarkstt.cli.rst b/docs/modules/benchmarkstt.cli.rst deleted file mode 100644 index 3bcc9477..00000000 --- a/docs/modules/benchmarkstt.cli.rst +++ /dev/null @@ -1,7 +0,0 @@ -benchmarkstt.cli module -======================= - -.. automodule:: benchmarkstt.cli - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/modules/benchmarkstt.csv.rst b/docs/modules/benchmarkstt.csv.rst deleted file mode 100644 index 8afd7dbf..00000000 --- a/docs/modules/benchmarkstt.csv.rst +++ /dev/null @@ -1,7 +0,0 @@ -benchmarkstt.csv module -======================= - -.. automodule:: benchmarkstt.csv - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/modules/benchmarkstt.decorators.rst b/docs/modules/benchmarkstt.decorators.rst deleted file mode 100644 index 3febc16c..00000000 --- a/docs/modules/benchmarkstt.decorators.rst +++ /dev/null @@ -1,7 +0,0 @@ -benchmarkstt.decorators module -============================== - -.. automodule:: benchmarkstt.decorators - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/modules/benchmarkstt.docblock.rst b/docs/modules/benchmarkstt.docblock.rst deleted file mode 100644 index 102d8a2b..00000000 --- a/docs/modules/benchmarkstt.docblock.rst +++ /dev/null @@ -1,7 +0,0 @@ -benchmarkstt.docblock module -============================ - -.. automodule:: benchmarkstt.docblock - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/modules/benchmarkstt.normalization.cli.rst b/docs/modules/benchmarkstt.normalization.cli.rst deleted file mode 100644 index 288d2c4e..00000000 --- a/docs/modules/benchmarkstt.normalization.cli.rst +++ /dev/null @@ -1,7 +0,0 @@ -benchmarkstt.normalization.cli module -===================================== - -.. automodule:: benchmarkstt.normalization.cli - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/modules/benchmarkstt.normalization.core.rst b/docs/modules/benchmarkstt.normalization.core.rst deleted file mode 100644 index 739cfe7a..00000000 --- a/docs/modules/benchmarkstt.normalization.core.rst +++ /dev/null @@ -1,7 +0,0 @@ -benchmarkstt.normalization.core module -====================================== - -.. automodule:: benchmarkstt.normalization.core - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/modules/benchmarkstt.normalization.logger.rst b/docs/modules/benchmarkstt.normalization.logger.rst deleted file mode 100644 index f5411fcd..00000000 --- a/docs/modules/benchmarkstt.normalization.logger.rst +++ /dev/null @@ -1,7 +0,0 @@ -benchmarkstt.normalization.logger module -======================================== - -.. automodule:: benchmarkstt.normalization.logger - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/modules/benchmarkstt.normalization.rst b/docs/modules/benchmarkstt.normalization.rst deleted file mode 100644 index cff84c04..00000000 --- a/docs/modules/benchmarkstt.normalization.rst +++ /dev/null @@ -1,19 +0,0 @@ -benchmarkstt.normalization package -================================== - -Submodules ----------- - -.. toctree:: - - benchmarkstt.normalization.cli - benchmarkstt.normalization.core - benchmarkstt.normalization.logger - -Module contents ---------------- - -.. automodule:: benchmarkstt.normalization - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/modules/benchmarkstt.rst b/docs/modules/benchmarkstt.rst deleted file mode 100644 index 4137ef9d..00000000 --- a/docs/modules/benchmarkstt.rst +++ /dev/null @@ -1,28 +0,0 @@ -benchmarkstt package -==================== - -Subpackages ------------ - -.. toctree:: - - benchmarkstt.api - benchmarkstt.normalization - -Submodules ----------- - -.. toctree:: - - benchmarkstt.cli - benchmarkstt.csv - benchmarkstt.decorators - benchmarkstt.docblock - -Module contents ---------------- - -.. automodule:: benchmarkstt - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/requirements.txt b/docs/requirements.txt new file mode 100644 index 00000000..683a1363 --- /dev/null +++ b/docs/requirements.txt @@ -0,0 +1 @@ +sphinx-rtd-theme>=0.4.2 diff --git a/olddocs/adr/hld/README.md b/olddocs/adr/hld/README.md index 2d39e311..3e893c98 100644 --- a/olddocs/adr/hld/README.md +++ b/olddocs/adr/hld/README.md @@ -31,7 +31,7 @@ Each of these connects to an STT vendor in order to create transcripion jobs. Implement the _transcript format guidelines_ to create consistently formated transcripts prior to analysis. ## Metric analyser -Each of these socres the transcripts returned by the providers using a single metric (for example Word Error Rate). +Each of these socres the transcripts returned by the providers using a single metric (for example Item Error Rate). ## Results metrics Lists the scores for the selected metrics for the selected providers. diff --git a/resources/manpage/benchmarkstt.1 b/resources/manpage/benchmarkstt.1 index e8b9d5ff..a0597be6 100644 --- a/resources/manpage/benchmarkstt.1 +++ b/resources/manpage/benchmarkstt.1 @@ -209,16 +209,19 @@ configuration text "He bravely turned his tail and fled" .TP .B example config -\(aq\(aq\(aq# using a simple config filenLowercase n -.UNINDENT -.sp +# using a simple config file +Lowercase n # it even supports comments -# If there is a space in the argument, make sure you quote it though! +# If there is a space in the argument, make sure you quote it +# though! regexreplace "y t" "Y T" nn # extraneous whitespaces are ignored -replace e an\(aq\(aq\(aq -:example return: "ha bravalY Turnad his tail and flad" +replace e a +.TP +.B example return +"ha bravalY Turnad his tail and flad" +.UNINDENT .TP .B\-\-configfile Load config from a file, see \fBConfig\fP for information about config @@ -422,7 +425,7 @@ checked case insensitive as well with preservation of case.. .INDENT 7.0 .TP .B param str search -Word to search for +Item to search for .TP .B param str replace Replace with diff --git a/setup.py b/setup.py index 9b090399..6224a603 100644 --- a/setup.py +++ b/setup.py @@ -20,7 +20,7 @@ __author__ = %s ''' % (repr(__version__), repr(__author__))) -with open('README.md') as f: +with open('README.rst') as f: long_description = f.read() @@ -30,7 +30,7 @@ version=__version__, author=__author__, author_email='temp@example.com', - description='', + description='A library for benchmarking AI/ML applications.', long_description=long_description, classifiers=[ 'Programming Language :: Python', @@ -42,18 +42,15 @@ package_data={'benchmarkstt': ['api/templates/*.html']}, include_package_data=True, install_requires=[ - 'MarkupSafe>=1.0', - 'Unidecode>=1.0.22', - 'langcodes>=1.4.1' + 'MarkupSafe>=1.0', + 'Unidecode>=1.0.22', + 'langcodes>=1.4.1', + 'Flask>=1.0.2', + 'jsonrpcserver>=4.0.1', + 'gunicorn>=19.9.0', + 'docutils>=0.14', ], extras_require={ - 'api': [ - 'Flask>=1.0.2', - 'jsonrpcserver>=4.0.1', - 'gunicorn>=19.9.0', - 'docutils>=0.14', - # 'Pygments>=2.2.0', - ], 'docs': [ "sphinx==1.8.3", "sphinx_rtd_theme==0.4.2", @@ -61,7 +58,8 @@ ], 'test': [ "pytest>=4.2.0", - "pycodestyle==2.5.0" + "pycodestyle==2.5.0", + "pytest-cov>=2.5.1" ] }, platforms='any', diff --git a/src/benchmarkstt/__init__.py b/src/benchmarkstt/__init__.py index 58aab393..a2e6bd7c 100644 --- a/src/benchmarkstt/__init__.py +++ b/src/benchmarkstt/__init__.py @@ -3,20 +3,6 @@ """ from .__meta__ import __author__, __version__ -from importlib import import_module -import textwrap - -modules = ('normalization', 'api') - - -def get_modules(sub_module=None): - postfix = '' if sub_module is None else '.' + sub_module - for module in modules: - yield module, import_module('benchmarkstt.%s%s' % (module, postfix)) - - -def get_modules_dict(sub_module=None): - return {module: cli for module, cli in get_modules(sub_module)} class DeferredStr: @@ -44,6 +30,21 @@ def __repr__(self): return self.__str__() +class DeferredList: + def __init__(self, cb): + self._cb = cb + self._list = None + + @property + def list(self): + if self._list is None: + self._list = self._cb() + return self._list + + def __getitem__(self, item): + return self.list[item] + + def make_printable(char): """ Return printable representation of ascii/utf-8 control characters diff --git a/src/benchmarkstt/__meta__.py b/src/benchmarkstt/__meta__.py index 29311fd5..148d04f2 100644 --- a/src/benchmarkstt/__meta__.py +++ b/src/benchmarkstt/__meta__.py @@ -1,3 +1,3 @@ # Automatically created. DO NOT EDIT -__version__ = '0.0.1' +__version__ = '0.0.2' __author__ = 'EBU' diff --git a/src/benchmarkstt/api/cli.py b/src/benchmarkstt/api/cli.py index ca0d1506..75dff9a1 100644 --- a/src/benchmarkstt/api/cli.py +++ b/src/benchmarkstt/api/cli.py @@ -1,5 +1,7 @@ """ -Make benchmarkstt available through a rudimentary JSON-RPC_ interface +Make benchmarkstt available through a rudimentary JSON-RPC_ interfacea + +Only supported for Python versions 3.6 and above .. _JSON-RPC: https://www.jsonrpc.org @@ -7,7 +9,7 @@ import jsonrpcserver from flask import Flask, request, Response, render_template -from benchmarkstt.docblock import format_docs, parse, rst_to_html +from benchmarkstt.docblock import format_docs, parse, process_rst from .jsonrpc import get_methods @@ -24,12 +26,14 @@ def argparser(parser): help='port used by the server') parser.add_argument('--entrypoint', default='/api', help='the jsonrpc api address') + parser.add_argument('--list-methods', action='store_true', + help='list the available jsonrpc methods') parser.add_argument('--with-explorer', action='store_true', help='also create the explorer to test api calls with, ' 'this is a rudimentary feature currently ' - 'only meant for testing and debugging') - parser.add_argument('--list-methods', action='store_true', - help='list the available jsonrpc methods') + 'only meant for testing and debugging.\n' + 'Warning: the API explorer is provided as-is, without any tests ' + 'or code reviews. This is marked as a low-priority feature.') return parser @@ -55,8 +59,8 @@ def jsonrpc(): response = jsonrpcserver.dispatch(req, methods=methods, debug=True, convert_camel_case=False) return Response(str(response), response.http_status, mimetype="application/json") - if with_explorer: - app.template_filter('parse_rst')(rst_to_html) + if with_explorer: # pragma: nocover + app.template_filter('parse_rst')(process_rst) @app.route(entrypoint, methods=['GET']) def explorer(): @@ -89,5 +93,6 @@ def main(parser, args): print(format_docs(func.__doc__)) print('') else: + app = create_app(args.entrypoint, args.with_explorer) app.run(host=args.host, port=args.port, debug=args.debug) diff --git a/src/benchmarkstt/api/gunicorn.py b/src/benchmarkstt/api/gunicorn.py index 92b9f1ab..8ecc9f64 100644 --- a/src/benchmarkstt/api/gunicorn.py +++ b/src/benchmarkstt/api/gunicorn.py @@ -2,6 +2,5 @@ Entry point for a gunicorn server, serves at /api """ -from .cli import create_app - -application = create_app('/api', with_explorer=True) +from .cli import create_app # pragma: no cover +application = create_app('/api', with_explorer=True) # pragma: no cover diff --git a/src/benchmarkstt/api/jsonrpc.py b/src/benchmarkstt/api/jsonrpc.py index 701033a0..eab888df 100644 --- a/src/benchmarkstt/api/jsonrpc.py +++ b/src/benchmarkstt/api/jsonrpc.py @@ -1,6 +1,9 @@ """ Make benchmarkstt available through a rudimentary JSON-RPC_ interface +.. warning:: + Only supported for Python versions 3.6 and above! + .. _JSON-RPC: https://www.jsonrpc.org """ @@ -8,51 +11,24 @@ import jsonrpcserver import json from benchmarkstt import __meta__ -from benchmarkstt.normalization import available_normalizers, logger from functools import wraps from benchmarkstt.docblock import format_docs -import inspect +from benchmarkstt.modules import Modules +from inspect import _empty, Parameter, signature import os -import benchmarkstt.csv as csv - - -def get_methods() -> jsonrpcserver.methods.Methods: - """ - Returns the available JSON-RPC api methods - - :return: jsonrpcserver.methods.Methods - """ - - methods = jsonrpcserver.methods.Methods() - - def method(f, name=None): - if name is None: - name = f.__name__.lstrip('_').replace('_', '.') - - methods.add(**{name: f}) - - @method - def version(): - """ - Get the version of benchmarkstt - :return str: BenchmarkSTT version - """ - return __meta__.__version__ +class SecurityError(ValueError): + """Trying to do or access something that isn't allowed""" - normalizers = available_normalizers() - @method - def list_normalizers(): - """ - Get a list of available core normalizers +class MagicMethods: + possible_path_args = ['file', 'path'] - :return object: With key being the normalizer name, and value its description - """ - return {name: conf.docs - for name, conf in normalizers.items()} + def __init__(self): + self.methods = jsonrpcserver.methods.Methods() + @staticmethod def is_safe_path(path): """ Determines whether the file or path is within the current working directory @@ -61,24 +37,26 @@ def is_safe_path(path): """ return os.path.abspath(path).startswith(os.path.abspath(os.getcwd())) - class SecurityError(ValueError): - """Trying to do or access something that isn't allowed""" + def serve(self, config, callback): + """ + Responsible for creating a callback with proper documentation and arguments + signature that can be registered as an api call. - def serve_normalizer(config): + :param config: + :param callback: + :return: callable + """ cls = config.cls @wraps(cls) - def _(text, return_logs=None, *args, **kwargs): + def _(*args, **kwargs): # only allow files from cwd to be used... try: - if 'file' in kwargs: - if not is_safe_path(kwargs['file']): - raise SecurityError("Access to unallowed file attempted", 'file') - - if 'path' in kwargs: - if not is_safe_path(kwargs['path']): - raise SecurityError("Access to unallowed directory attempted", 'path') - + # todo (?) add available files and folders as select options + for name in self.possible_path_args: + if name in kwargs: + if not self.is_safe_path(kwargs[name]): + raise SecurityError("Access to unallowed file attempted", name) except SecurityError as e: data = { "message": e.args[0], @@ -86,64 +64,98 @@ def _(text, return_logs=None, *args, **kwargs): } raise AssertionError(json.dumps(data)) - if return_logs: - handler = logger.ListHandler() - handler.setFormatter(logger.DiffLoggingFormatter(dialect='html')) - logger.normalize_logger.addHandler(handler) - - try: - result = { - "text": cls(*args, **kwargs).normalize(text) - } - if return_logs: - logs = handler.flush() - result['logs'] = [] - for log in logs: - result['logs'].append(dict(names=log[0], message=log[1])) - return result - except csv.CSVParserError as e: - message = 'on line %d, character %d' % (e.line, e.char) - message = '\n'.join([e.__doc__, e.message, message]) - data = { - "message": message, - "line": e.line, - "char": e.char, - "index": e.index, - "field": "config" - } - raise AssertionError(json.dumps(data)) - finally: - if return_logs: - logger.normalize_logger.removeHandler(handler) - - # copy signature from original normalizer, and add text param - sig = inspect.signature(cls) - params = list() - params.append(inspect.Parameter('text', kind=inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=str)) - params.extend(sig.parameters.values()) - params.append(inspect.Parameter('return_logs', kind=inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=bool, - default=None)) - sig = sig.replace(parameters=params) + result = callback(cls, *args, **kwargs) + if isinstance(result, tuple) and hasattr(result, '_asdict'): + result = result._asdict() + return result + + # copy signature from original + sig = signature(cls) + + cb_params = signature(callback).parameters.values() + extra_params = [parameter for parameter in cb_params + if parameter.name != 'cls' and + parameter.kind not in (Parameter.VAR_KEYWORD, + Parameter.VAR_POSITIONAL)] + if len(extra_params): + params = list(filter(lambda x: x.default is _empty, extra_params)) + params.extend(sig.parameters.values()) + params.extend(list(filter(lambda x: x.default is not _empty, extra_params))) + sig = sig.replace(parameters=params) + + _.__doc__ += callback.__doc__ _.__signature__ = sig - _.__doc__ += '\n :param str text: The text to normalize' - _.__doc__ += '\n :param bool return_logs: Return normalizer logs' - - # todo (?) add available files and folders as select options return _ - # add each normalizer as its own api call - for conf in normalizers.values(): - method(serve_normalizer(conf), name='normalization.%s' % (conf.name,)) + def load(self, name, module): + """ + Load all possible callbacks for a given module - @method - def _help(): + :param str name: + :param Module module: """ - Returns available api methods + factory = module.factory + callables = list(factory) + + def lister(): + """ + Get a list of available core %s + + :return object: With key being the %s name, and value its description + """ + return {config.name: config.docs for config in callables} + + lister.__doc__ = lister.__doc__ % (name, name) + + self.register("list.%s" % (name,), lister) + + # add each callable as its own api call + for conf in callables: + apicallname = '%s.%s' % (name, conf.name,) + self.register(apicallname, self.serve(conf, module.callback)) - :return object: With key being the method name, and value its description + def register(self, name, callback): """ + Register a callback as an api call + :param str name: + :param callable callback: + """ + self.methods.add(**{name: callback}) + + +class DefaultMethods: + @staticmethod + def version(): + """ + Get the version of benchmarkstt + + :return str: BenchmarkSTT version + """ + return __meta__.__version__ + + @staticmethod + def help(methods): + def _(): + """ + Returns available api methods + + :return object: With key being the method name, and value its description + """ + return {name: format_docs(func.__doc__) for name, func in methods.items.items()} + return _ + + +def get_methods() -> jsonrpcserver.methods.Methods: + """ + Returns the available JSON-RPC api methods + + :return: jsonrpcserver.methods.Methods + """ - return {name: format_docs(func.__doc__) - for name, func in methods.items.items()} + methods = MagicMethods() + methods.register('version', DefaultMethods.version) + for name, module in Modules('api'): + methods.load(name, module) - return methods + methods.register('help', DefaultMethods.help(methods.methods)) + return methods.methods diff --git a/src/benchmarkstt/cli.py b/src/benchmarkstt/cli.py index e6bac7b8..fb390fe4 100644 --- a/src/benchmarkstt/cli.py +++ b/src/benchmarkstt/cli.py @@ -1,16 +1,18 @@ import argparse import logging -from . import get_modules_dict -import textwrap from . import __meta__ - -modules = get_modules_dict('cli') +import textwrap +import itertools +from benchmarkstt.modules import Modules def _parser_no_sub(dont_add_submodule=False): - parser = argparse.ArgumentParser(prog='benchmarkstt', add_help=__name__ != '__main__', + parser = argparse.ArgumentParser(prog='benchmarkstt', add_help=False, description='BenchmarkSTT main command line script') + parser.add_argument('--help', action='help', default=argparse.SUPPRESS, + help=argparse._('show this help message and exit')) + parser.add_argument('--log-level', type=str.lower, default='warning', dest='log_level', choices=list(map(str.lower, logging._nameToLevel.keys())), help='set the logging output level') @@ -18,9 +20,9 @@ def _parser_no_sub(dont_add_submodule=False): parser.add_argument('--version', action='store_true', help='output benchmarkstt version number') - # this is for argpars autodoc purposes - if not dont_add_submodule: - parser.add_argument('subcommand', choices=modules.keys()) + # this is for argparse autodoc purposes + if not dont_add_submodule: # pragma: no cover + parser.add_argument('subcommand', choices=Modules('cli').keys()) return parser @@ -29,20 +31,102 @@ def _parser() -> argparse.ArgumentParser: parser = _parser_no_sub(True) subparsers = parser.add_subparsers(dest='subcommand') - for module, cli in modules.items(): - if not hasattr(cli, 'argparser'): - subparsers.add_parser(module) - continue + for module, cli in Modules('cli'): kwargs = dict() if hasattr(cli, 'Formatter'): kwargs['formatter_class'] = cli.Formatter - kwargs['description'] = textwrap.dedent(cli.__doc__) - subparser = subparsers.add_parser(module, **kwargs) + else: + kwargs['formatter_class'] = ActionWithArgumentsFormatter + + if cli.__doc__ is None: + docs = 'TODO: add description to benchmarkstt.%s.cli' % (module,) + else: + docs = cli.__doc__ + kwargs['description'] = textwrap.dedent(docs) + subparser = subparsers.add_parser(module, add_help=False, **kwargs) + + subparser.add_argument('--help', action='help', default=argparse.SUPPRESS, + help=argparse._('show this help message and exit')) cli.argparser(subparser) return parser +def args_from_factory(action, factory, parser): + for conf in factory: + name = conf.name + docs = conf.docs + + arguments = dict() + arguments['help'] = docs + arguments['nargs'] = 0 + + if len(conf.required_args) or len(conf.optional_args): + arguments['nargs'] = '+' if len(conf.required_args) else '*' + optionals = list(map(lambda x: '[%s]' % x, conf.optional_args)) + arguments['metavar'] = tuple(conf.required_args + optionals) + + arguments['action'] = action_with_arguments(action, + conf.required_args, + conf.optional_args) + + parser.add_argument('--%s' % (name,), **arguments) + + +class _ActionWithArguments: + """ + Placeholder class to recognize an argument is a NormalizerAction in argparse + """ + + +def action_with_arguments(action, required_args, optional_args): + """ + Custom argparse action to support a variable amount of arguments + :param str action: name of the action + :param list required_args: required arguments + :param list optional_args: optional arguments + :rtype: ActionWithArguments + """ + + minlen = len(required_args) + maxlen = minlen + len(optional_args) + + class ActionWithArguments(argparse.Action, _ActionWithArguments): + def __call__(self, parser, args, values, option_string=None): + if len(values) < minlen or len(values) > maxlen: + raise argparse.ArgumentTypeError('argument "%s" requires between %d and %d arguments (got %d)' % + (self.dest, minlen, maxlen, len(values))) + + if not hasattr(args, action): + setattr(args, action, []) + + getattr(args, action).append([self.dest] + values) + + return ActionWithArguments + + +class ActionWithArgumentsFormatter(argparse.HelpFormatter): + """ + Custom formatter for argparse that allows us to properly display _ActionWithArguments and docblock documentation + """ + + def _format_args(self, action, default_metavar): + if isinstance(action, _ActionWithArguments): + return ' '.join(action.metavar) + + return super()._format_args(action, default_metavar) + + def _split_lines(self, text, width): + def wrap(txt): + if txt == '': + return [''] + return textwrap.wrap(txt, width=width) + + text = text.splitlines() + text = list(itertools.chain.from_iterable(map(wrap, text))) + return text + + def main(): parser = _parser() @@ -50,20 +134,22 @@ def main(): # support argument completion if package is installed import argcomplete argcomplete.autocomplete(parser) - except ImportError: + except ImportError: # pragma: no cover pass + args = parser.parse_args() if args.version: print("benchmarkstt: %s" % (__meta__.__version__,)) parser.exit(0) logging.basicConfig(level=args.log_level.upper()) - logger = logging.getLogger().setLevel(args.log_level.upper()) + logging.getLogger().setLevel(args.log_level.upper()) if not args.subcommand: parser.error("expects at least 1 argument") - modules[args.subcommand].main(parser, args) + Modules('cli')[args.subcommand].main(parser, args) + exit(0) -if __name__ == '__main__': +if __name__ == '__main__': # pragma: nocover main() diff --git a/src/benchmarkstt/csv.py b/src/benchmarkstt/csv.py index 01d6409e..b326a519 100644 --- a/src/benchmarkstt/csv.py +++ b/src/benchmarkstt/csv.py @@ -5,7 +5,7 @@ import typing import sys from functools import partial -from benchmarkstt import make_printable +from benchmarkstt import DeferredList, make_printable class InvalidDialectError(ValueError): @@ -113,11 +113,6 @@ def _is_ignore_left(self, char: str): return False return char in self._dialect.trimleft - def _is_ignore_right(self, char: str): - if self._dialect.trimright is None: - return False - return char in self._dialect.trimright - def _is_comment(self, char: str): if self._dialect.commentchar is None: return False @@ -135,13 +130,34 @@ def __iter__(self): readchar = iter(partial(self._file.read, 1), '') cur_line = 1 + if self._debug: + current_module = sys.modules[__name__] + # print the color key the different modes + print('MODES: ', end='') + print(' '.join(['\033[1;%d;40m%s\033[0;0m' % (32 + getattr(current_module, name), name[5:]) + for name in dir(current_module) + if name.startswith('MODE_') + ])) + + def debug(txt='', args=tuple(), **kwargs): + if type(args) is not tuple: + args = tuple(DeferredList(args)) + print(txt % args, **kwargs) + pass + else: + def debug(*args, **kwargs): + pass + newlinechars = '\n\r' mode = MODE_FIRST field = [] line = Line() - delimiter_is_whitespace = self._dialect.delimiter in self._dialect.trimright + if self._dialect.trimright is not None: + delimiter_is_whitespace = self._dialect.delimiter in self._dialect.trimright + else: + delimiter_is_whitespace = False def yield_line(): nonlocal line, field, mode, delimiter_is_whitespace, is_newline, cur_line @@ -166,16 +182,6 @@ def next_field(): field = [] mode = MODE_OUTSIDE - debug = self._debug - if debug: - # print the color key the different modes - current_module = sys.modules[__name__] - print('MODES: ', end='') - print(' '.join(['\033[1;%d;40m%s\033[0;0m' % (32 + getattr(current_module, name), name[5:]) - for name in dir(current_module) - if name.startswith('MODE_') - ])) - cur_char = 0 last_quote_line = None last_quote_char = None @@ -185,9 +191,8 @@ def next_field(): cur_char += 1 idx += 1 - if debug: - # print char to stdout with color defining mode - print('\033[1;%d;40m%s\033[0;0m' % (32+mode, make_printable(char)), end='') + # print char to stdout with color defining mode + debug('\033[1;%d;40m%s\033[0;0m', lambda: (32+mode, make_printable(char)), end='') is_newline = char in newlinechars if is_newline: @@ -266,8 +271,7 @@ def next_field(): field.append(char) continue - if debug: - print() + debug() if mode == MODE_INSIDE_QUOTED: raise UnclosedQuoteError("Unexpected end", last_quote_line, last_quote_char, last_quote_idx) @@ -276,7 +280,7 @@ def next_field(): yield yield_line() -def reader(file: typing.io.TextIO, dialect: typing.Union[None, str, Dialect] = None) -> Reader: +def reader(file: typing.io.TextIO, dialect: typing.Union[None, str, Dialect] = None, **kwargs) -> Reader: if dialect is None: dialect = DefaultDialect elif type(dialect) is str: @@ -284,4 +288,4 @@ def reader(file: typing.io.TextIO, dialect: typing.Union[None, str, Dialect] = N raise UnknownDialectError("Dialect not known", dialect) dialect = known_dialects[dialect] - return Reader(file, dialect) + return Reader(file, dialect, **kwargs) diff --git a/src/benchmarkstt/decorators.py b/src/benchmarkstt/decorators.py index ca9a350a..7f7ad1c1 100644 --- a/src/benchmarkstt/decorators.py +++ b/src/benchmarkstt/decorators.py @@ -4,42 +4,59 @@ def log_call(logger: logging.Logger, log_level=None, result=None): """ - Decorator to log all calls to decorated function to a given logger + Decorator to log all calls to decorated function to given logger - >>> import logging, sys - >>> logging.basicConfig(stream=sys.stdout, format='%(levelname)s:%(name)s: %(message)s') - >>> logger = logging.getLogger('logger_name') - >>> logger.setLevel(logging.DEBUG) + >>> import logging, sys, io + >>> + >>> def get_logger(): + ... logger = logging.getLogger('logger_name') + ... logger.setLevel(logging.DEBUG) + ... stream = io.StringIO() + ... ch = logging.StreamHandler(stream) + ... ch.setLevel(logging.DEBUG) + ... ch.setFormatter(logging.Formatter('%(levelname)s:%(name)s: %(message)s')) + ... logger.addHandler(ch) + ... return logger, stream + >>> + >>> logger, stream = get_logger() >>> @log_call(logger, logging.WARNING) ... def test(*args, **kwargs): ... return 'result' >>> test('arg1', arg2='someval', arg3='someotherval') - WARNING:logger_name: CALL test('arg1', arg2='someval', arg3='someotherval') 'result' + >>> print(stream.getvalue().strip()) + WARNING:logger_name: test('arg1', arg2='someval', arg3='someotherval') + >>> logger, stream = get_logger() >>> @log_call(logger, result=True) ... def test(*args, **kwargs): ... return 'result' - >>> test(arg2='someval', arg3='someotherval', arg4=None) - DEBUG:logger_name: CALL test(arg2='someval', arg3='someotherval', arg4=None) - DEBUG:logger_name: RESULT test = 'result' + >>> test(arg2='someval', arg3='someotherval') 'result' + >>> print(stream.getvalue().strip()) + DEBUG:logger_name: test(arg2='someval', arg3='someotherval') + DEBUG:logger_name: test returned: result """ if log_level is None: log_level = logging.DEBUG - def _log_call(func): - nonlocal log_level, result - + def _log_call(func: callable): def _(*args, **kwargs): - nonlocal log_level, result - logger.log(log_level, 'CALL %s(%s)', - func.__name__, - DeferredStr(lambda: ', '.join([repr(a) for a in args] + - [k + '=' + repr(v) for k, v in kwargs.items()]))) + arguments_format = [] + arguments_list = [] + if len(args): + arguments_format.append('%s') + arguments_list.append(DeferredStr(lambda: ', '.join([repr(a) for a in args]))) + if len(kwargs): + arguments_format.append('%s') + arguments_list.append(DeferredStr(lambda: ', '.join([k + '=' + repr(kwargs[k]) for k in kwargs]))) + + arguments_format = '%s(%s)' % (func.__name__, ', '.join(arguments_format)) + + logger.log(log_level, arguments_format, *arguments_list) result_ = func(*args, **kwargs) if result: - logger.log(log_level, 'RESULT %s = %s', func.__name__, DeferredStr(lambda: repr(result_))) + logger.log(log_level, '%s returned: %s', func.__name__, result_) return result_ return _ return _log_call diff --git a/src/benchmarkstt/diff/__init__.py b/src/benchmarkstt/diff/__init__.py new file mode 100644 index 00000000..1253541b --- /dev/null +++ b/src/benchmarkstt/diff/__init__.py @@ -0,0 +1,26 @@ +from benchmarkstt.factory import Factory + + +class Base: + def __init__(self, a='', b=''): + raise NotImplementedError() + + def get_opcodes(self): + """ + Return list of 5-tuples describing how to turn `a` into `b`. + + Each tuple is of the form `(tag, i1, i2, j1, j2)`. The first tuple has + `i1 == j1 == 0`, and remaining tuples have `i1` equals the `i2` from the + tuple preceding it, and likewise for `j1` equals the previous `j2`. + + The tags are strings, with these meanings: + + - 'replace': `a[i1:i2]` should be replaced by `b[j1:j2]` + - 'delete': `a[i1:i2]` should be deleted. Note that `j1==j2` in this case. + - 'insert': `b[j1:j2]` should be inserted at `a[i1:i1]`. Note that `i1==i2` in this case. + - 'equal': `a[i1:i2] == b[j1:j2]` + """ + raise NotImplementedError() + + +factory = Factory(Base) diff --git a/src/benchmarkstt/diff/core.py b/src/benchmarkstt/diff/core.py new file mode 100644 index 00000000..40001efa --- /dev/null +++ b/src/benchmarkstt/diff/core.py @@ -0,0 +1,46 @@ +from difflib import SequenceMatcher +from benchmarkstt.diff import Base + + +# class HuntMcIlroy: +# """ +# Implements the Hunt–McIlroy algorithm. +# +# More information available at https://en.wikipedia.org/wiki/Hunt%E2%80%93McIlroy_algorithm +# +# Mimics structure of difflib.SequenceMatcher +# +# Status: TODO make a proper diff implementing Hunt–McIlroy algorithm +# (see https://github.com/ebu/ai-benchmarking-stt/issues/30) +# +# """ +# +# def __init__(self, a='', b=''): +# self.a = a +# self.b = b +# self.opcodes = None +# self.matching_blocks = None +# +# def set_seqs(self, a, b): +# self.a = a +# self.b = b +# +# def set_seq1(self, a): +# self.a = a +# +# def set_seq2(self, b): +# self.b = b +# +# def find_longest_match(self, alo, ahi, blo, bhi): +# raise NotImplementedError() +# +# # re-use get_matching_blocks and get_opcodes from difflib +# get_matching_blocks = SequenceMatcher.get_matching_blocks +# get_opcodes = SequenceMatcher.get_opcodes + + +class RatcliffObershelp(SequenceMatcher, Base): + def __init__(self, a, b, *args, **kwargs): + if 'autojunk' not in kwargs: + kwargs['autojunk'] = False + super().__init__(a=a, b=b, *args, **kwargs) diff --git a/src/benchmarkstt/diff/formatter.py b/src/benchmarkstt/diff/formatter.py new file mode 100644 index 00000000..431bc89f --- /dev/null +++ b/src/benchmarkstt/diff/formatter.py @@ -0,0 +1,123 @@ +import logging +from benchmarkstt import make_printable +import difflib +from markupsafe import escape + +logger = logging.getLogger(__name__) + + +class Dialect: + preprocessor = None + delete_format = '%s' + insert_format = '%s' + equal_format = '%s' + replace_format = None + + @staticmethod + def format(names, diff): + raise NotImplementedError() + + +class CLIDiffDialect(Dialect): + preprocessor = make_printable + delete_format = '\033[31m%s\033[0m' + insert_format = '\033[32m%s\033[0m' + + @staticmethod + def format(names, diff): + return '|'.join(names) + ': ' + diff + + +class UTF8Dialect(Dialect): + preprocessor = make_printable + + @staticmethod + def delete_format(txt): + return ''.join(c + '\u0338' for c in txt) + + @staticmethod + def insert_format(txt): + return ''.join(c + '\u0359' for c in txt) + + @staticmethod + def format(names, diff): + return '|'.join(names) + ': ' + diff + + +class HTMLDiffDialect(Dialect): + preprocessor = escape + delete_format = '%s' + insert_format = '%s' + + @staticmethod + def format(names, diff): + return names, diff + + +class DiffFormatter: + diff_dialects = { + "cli": CLIDiffDialect, + "html": HTMLDiffDialect, + "text": UTF8Dialect + } + + def __init__(self, dialect=None): + if dialect is None: + dialect = 'text' + if dialect not in self.diff_dialects: + raise ValueError("Unknown diff dialect", dialect) + self._dialect = self.diff_dialects[dialect] + + def format(self, record): + return self._dialect.format(record.args[0], self.diff(record.args[1], record.args[2])) + + def diff(self, a, b, opcodes=None, preprocessor=None): + dialect = self._dialect + + def format_string(formatting): + def _(*args): + return formatting % args + return _ + + formats = dict(insert=None, delete=None, equal=None, replace=None) + + for f in formats.keys(): + formatter = getattr(dialect, f + '_format') + if type(formatter) is str: + formats[f] = format_string(formatter) + else: + formats[f] = formatter + + if formats['replace'] is None: + def _(deleted, inserted): + return formats['delete'](deleted) + formats['insert'](inserted) + formats['replace'] = _ + + if preprocessor is not None: + def _pre(txt): + return dialect.preprocessor(preprocessor(txt)) + else: + _pre = dialect.preprocessor + if opcodes is None: + opcodes = difflib.SequenceMatcher(None, a, b).get_opcodes() + + res = [] + for tag, alo, ahi, blo, bhi in opcodes: + a_ = _pre(a[alo:ahi]) + + if tag in ('equal', 'delete'): + res.append(formats[tag](_pre(a[alo:ahi]))) + continue + + b_ = _pre(b[blo:bhi]) + if tag == 'insert': + res.append(formats[tag](b_)) + continue + + res.append(formats[tag](a_, b_)) + return ''.join(res) + + +def format_diff(a, b, opcodes=None, dialect=None, preprocessor=None): + formatter = DiffFormatter(dialect) + return formatter.diff(a, b, opcodes, preprocessor) diff --git a/src/benchmarkstt/docblock.py b/src/benchmarkstt/docblock.py index f9d840dc..c38a58c3 100644 --- a/src/benchmarkstt/docblock.py +++ b/src/benchmarkstt/docblock.py @@ -6,6 +6,7 @@ import logging from docutils.core import publish_string from docutils.writers import html5_polyglot +import docutils logger = logging.getLogger(__name__) @@ -26,7 +27,7 @@ def doc_param_parser(docstring, key, no_name=None, allow_multiple=None, if replace_strat is None: def replace_strat(match, param): - return match[0] + return match.group(0) elif type(replace_strat) is str: _replace_strat = replace_strat @@ -37,32 +38,33 @@ def replace_strat(match, param): def _(match): nonlocal results, key, no_name, replace_strat if no_name: - param = dict(name=key, type=match[1], value=match[2]) + param = dict(name=key, type=match.group(1), value=match.group(2)) return_val = replace_strat(match, param) results.append(DocblockParam(**param)) else: - param = dict(name=match[2], type=match[1], value=match[3]) + value = textwrap.dedent(match.group(3)).strip() + param = dict(name=match.group(2), type=match.group(1), value=value) return_val = replace_strat(match, param) param = DocblockParam(**param) if allow_multiple: # check if it already exists, if not create a new object idx = [idx for idx, val in enumerate(results) - if match[2] not in val] + if match.group(2) not in val] if not len(idx): idx = len(results) results.append({}) else: idx = idx[0] - results[idx][match[2]] = param + results[idx][match.group(2)] = param else: - results[match[2]] = param + results[match.group(2)] = param return return_val if no_name: regex = r'^[ \t]*:%s[ \t]*([a-z_]+)?:[ \t]+(.*)$' else: - regex = r'^[ \t]*:%s[ \t]+(?:([^:]+)[ \t]+)?([a-z_]+):(?:[ \t]+(.*))?$' + regex = r'^[ \t]*:%s[ \t]+(?:([^:]+)[ \t]+)?([a-z_]+):[ \t]*(.+$|(?:$[ \t]*\n)*([ \t]+)([^\n]*)$(?:\4.*\n|\n)+)' docs = re.sub( regex % (re.escape(key),), _, docstring, flags=re.MULTILINE @@ -97,7 +99,10 @@ def parse(func): docs, doc_result = doc_param_parser(docs, 'return', no_name=True) def decode_examples(match, param): - param['value'] = decode_literal(param['value']) + if match.group(5) is None: + param['value'] = decode_literal(param['value']) + else: + param['value'] = process_rst(param['value'], 'text') return '' docs, examples = doc_param_parser(docs, 'example', allow_multiple=True, @@ -136,8 +141,32 @@ def apply_template(self): return subs['body'] -def rst_to_html(text): - writer = HTML5Writer() +class TextWriter(docutils.writers.Writer): + class TextVisitor(docutils.nodes.SparseNodeVisitor): + _text = '' + + def visit_Text(self, node): + self._text += node.astext() + + def visit_paragraph(self, node): + self._text += '\n\n' + + def text(self): + return self._text + + def translate(self): + visitor = self.TextVisitor(self.document) + self.document.walkabout(visitor) + self.output = visitor.text() + + +def process_rst(text, writer=None): + if writer is None or writer == 'html': + writer = HTML5Writer() + elif writer == 'text': + writer = TextWriter() + elif type(writer) is str: + raise ValueError("Unknown writer %s", str) settings = {'output_encoding': 'unicode', 'table_style': 'table'} return publish_string(text, writer=writer, writer_name='html5', settings_overrides=settings) diff --git a/src/benchmarkstt/factory.py b/src/benchmarkstt/factory.py new file mode 100644 index 00000000..f2a0e554 --- /dev/null +++ b/src/benchmarkstt/factory.py @@ -0,0 +1,148 @@ +import inspect +from benchmarkstt import DeferredRepr +import logging +from importlib import import_module +from benchmarkstt.docblock import format_docs +from collections import namedtuple +from typing import Dict + +logger = logging.getLogger(__name__) + +ClassConfig = namedtuple('ClassConfig', ['name', 'cls', 'docs', 'optional_args', 'required_args']) + + +class Factory: + """ + Factory class with auto-loading of namespaces according to a base class. + """ + + def __init__(self, base_class, namespaces=None): + self.base_class = base_class + if namespaces is None: + self.namespaces = [base_class.__module__ + '.core'] + else: + self.namespaces = namespaces + + self._registry = {} + + for namespace in self.namespaces: + self.register_namespace(namespace) + + def create(self, alias, *args, **kwargs): + return self.get_class(alias)(*args, **kwargs) + + @staticmethod + def normalize_class_name(clsname): + """ + Normalizes the class name for automatic lookup of a class, by default + this means lowercasing the class name, but may be overrided by a child + class. + + :param str clsname: The class name + :return: The normalized class name + :rtype: str + """ + return clsname.lower() + + def get_class(self, name): + """ + Loads the proper class based on a name + + :param str name: Case-insensitive name of the class + :return: The class + :rtype: class + """ + name = self.normalize_class_name(name) + if name not in self._registry: + raise ImportError("Could not find class '%s'" % (name,)) + + return self._registry[name] + + def is_valid(self, tocheck): + """ + Checks that tocheck is a valid class extending base_class + + :param class tocheck: The class to check + :rtype: bool + """ + + if tocheck is self.base_class: + return False + if not inspect.isclass(tocheck): + return False + if issubclass(tocheck, self.base_class): + return True + logger.info('Not a valid class (must inherit from Base class): "%s"', DeferredRepr(tocheck)) + return False + + def register_namespace(self, namespace): + """ + Registers all valid classes from a given namespace + + :param str|module namespace: + """ + + module = '.'.join(filter(len, namespace.split('.'))) + if module == '': + module = globals() + else: + module = import_module(module) + + for clsname in dir(module): + cls = getattr(module, clsname) + if not self.is_valid(cls): + continue + clsname = self.normalize_class_name(clsname) + if clsname in self._registry: + raise ValueError("Conflict: alias '%s' is already registered" % (clsname,)) + self._registry[clsname] = cls + + def register(self, cls, alias=None): + """ + Register an alias for a class + + :param self.base_class cls: + :param str|None alias: The alias to use when trying to get the class back, + by default will use normalized class name. + :return: None + """ + if not self.is_valid(cls): + raise ValueError('Invalid class (must inherit from Base class)"') + + if alias is None: + alias = cls.__name__ + + alias = self.normalize_class_name(alias) + if alias in self._registry: + raise ValueError("Conflict: alias '%s' is already registered" % (alias,)) + self._registry[alias] = cls + + def __iter__(self): + """ + Get available classes with a proper ClassConfig + + :return: A dictionary of registered classes + :rtype: Dict[str, ClassConfig] + """ + + for clsname, cls in self._registry.items(): + if cls.__doc__ is None: + docs = '' + logger.warning("No docstring for '%s'", cls.__name__) + else: + docs = cls.__doc__ + docs = format_docs(docs) + + argspec = inspect.getfullargspec(cls.__init__) + args = list(argspec.args)[1:] + defaults = [] + if argspec.defaults: + defaults = list(argspec.defaults) + + defaults_idx = len(args) - len(defaults) + required_args = args[0:defaults_idx] + optional_args = args[defaults_idx:] + + yield ClassConfig(name=clsname, cls=cls, docs=docs, + optional_args=optional_args, + required_args=required_args) diff --git a/src/benchmarkstt/input/__init__.py b/src/benchmarkstt/input/__init__.py new file mode 100644 index 00000000..4d483c7e --- /dev/null +++ b/src/benchmarkstt/input/__init__.py @@ -0,0 +1,18 @@ +""" +Subpackage responsible for dealing with input formats and converting them to benchmarkstt native schema +""" + +from benchmarkstt.factory import Factory + + +class Base: + def __iter__(self): + """ + Each input class should be accessible as iterator, each iteration should + return a Item, so the input format is essentially usable and can be easily + converted to a :py:class:`benchmarkstt.schema.Schema` + """ + raise NotImplementedError() + + +factory = Factory(Base) diff --git a/src/benchmarkstt/input/core.py b/src/benchmarkstt/input/core.py new file mode 100644 index 00000000..b193aeb2 --- /dev/null +++ b/src/benchmarkstt/input/core.py @@ -0,0 +1,56 @@ +""" +Default input formats + +""" + +import benchmarkstt.segmentation.core as segmenters +from benchmarkstt import input + + +class PlainText(input.Base): + def __init__(self, text, segmenter=None): + if segmenter is None: + segmenter = segmenters.Simple + self._text = text + self._segmenter = segmenter + + def __iter__(self): + return iter(self._segmenter(self._text)) + + +class File(input.Base): + """ + Load the input class based on a file + """ + + _extension_to_class = { + "txt": PlainText, + "json": None + } + + def __init__(self, file, input_type=None): + if input_type is None or input_type == 'infer': + if '.' not in file: + raise ValueError('Cannot infer input file type of files without an extension') + + extension = file.rsplit('.', 1)[1].lower() + if extension not in self._extension_to_class: + raise ValueError('Cannot infer input file type for files of extension %s' % (extension,)) + + input_type = self._extension_to_class[extension] + + with open(file): + """Just checks that file is readable...""" + + self._file = file + + if type(input_type) is str: + input_type = input.factory.get_class(input_type) + + self._input_class = input_type + + def __iter__(self): + with open(self._file) as f: + text = f.read() + + return iter(self._input_class(text)) diff --git a/src/benchmarkstt/metrics/__init__.py b/src/benchmarkstt/metrics/__init__.py new file mode 100644 index 00000000..f96f9413 --- /dev/null +++ b/src/benchmarkstt/metrics/__init__.py @@ -0,0 +1,13 @@ +from benchmarkstt.schema import Schema +from benchmarkstt.factory import Factory + + +class Base: + """ + Base class for metrics + """ + def compare(self, ref: Schema, hyp: Schema): + raise NotImplementedError() + + +factory = Factory(Base) diff --git a/src/benchmarkstt/metrics/api.py b/src/benchmarkstt/metrics/api.py new file mode 100644 index 00000000..4edb25a4 --- /dev/null +++ b/src/benchmarkstt/metrics/api.py @@ -0,0 +1,14 @@ +from benchmarkstt.input.core import PlainText +import benchmarkstt.metrics as metrics + +factory = metrics.factory + + +def callback(cls, ref: str, hyp: str, *args, **kwargs): + """ + :param ref: Reference text + :param hyp: Hypothesis text + """ + ref = PlainText(ref) + hyp = PlainText(hyp) + return cls(*args, **kwargs).compare(ref, hyp) diff --git a/src/benchmarkstt/metrics/cli.py b/src/benchmarkstt/metrics/cli.py new file mode 100644 index 00000000..33210afe --- /dev/null +++ b/src/benchmarkstt/metrics/cli.py @@ -0,0 +1,62 @@ +""" +Calculate metrics based on the comparison of a hypothesis with a reference. +""" + +from benchmarkstt.input import core +from benchmarkstt.metrics import factory +from benchmarkstt.cli import args_from_factory +import argparse + + +def argparser(parser: argparse.ArgumentParser): + # steps: input normalize[pre?] segmentation normalize[post?] compare + parser.add_argument('-r', '--reference', required=True, + help='The file to use as reference') + parser.add_argument('-h', '--hypothesis', required=True, + help='The file to use as hypothesis') + + parser.add_argument('-rt', '--reference-type', default='infer', + help='Type of reference file') + parser.add_argument('-ht', '--hypothesis-type', default='infer', + help='Type of hypothesis file') + + # parser.add_argument('-m', '--metric', default='wer', nargs='+', + # help='The type of metric(s) to run') + + metrics_desc = " A list of metrics to calculate. At least one metric needs to be provided." + + subparser = parser.add_argument_group('available metrics', description=metrics_desc) + args_from_factory('metrics', factory, subparser) + return parser + + +def main(parser, args): + if args.reference_type == 'argument': + ref = core.PlainText(args.reference) + else: + ref = core.File(args.reference, args.reference_type) + + if args.hypothesis_type == 'argument': + hyp = core.PlainText(args.hypothesis) + else: + hyp = core.File(args.hypothesis, args.hypothesis_type) + + ref = list(ref) + hyp = list(hyp) + + if 'metrics' not in args or not len(args.metrics): + parser.error("need at least one metric") + + for item in args.metrics: + metric_name = item.pop(0).replace('-', '.') + print(metric_name) + print('=' * len(metric_name)) + print() + metric = factory.create(metric_name, *item) + # todo: different output options + result = metric.compare(ref, hyp) + if type(result) is float: + print("%.6f" % (result,)) + else: + print(result) + print() diff --git a/src/benchmarkstt/metrics/core.py b/src/benchmarkstt/metrics/core.py new file mode 100644 index 00000000..8fa9c7fc --- /dev/null +++ b/src/benchmarkstt/metrics/core.py @@ -0,0 +1,109 @@ +from benchmarkstt.schema import Schema +import logging +from benchmarkstt.diff.core import RatcliffObershelp +from benchmarkstt.diff.formatter import format_diff +from benchmarkstt.metrics import Base +from collections import namedtuple + +logger = logging.getLogger(__name__) + +OpcodeCounts = namedtuple('OpcodeCounts', + ('equal', 'replace', 'insert', 'delete')) + + +def traversible(schema, key=None): + if key is None: + key = 'item' + return [word[key] for word in schema] + + +def get_opcode_counts(opcodes): + counts = OpcodeCounts(0, 0, 0, 0)._asdict() + for tag, alo, ahi, blo, bhi in opcodes: + if tag in ['equal', 'replace', 'delete']: + counts[tag] += ahi - alo + elif tag == 'insert': + counts[tag] += bhi - blo + return OpcodeCounts(counts['equal'], counts['replace'], counts['insert'], counts['delete']) + + +def get_differ(a, b, differ_class): + if differ_class is None: + # differ_class = HuntMcIlroy + differ_class = RatcliffObershelp + return differ_class(traversible(a), traversible(b)) + + +class WordDiffs(Base): + """ + Calculate the differences on a per-word basis + """ + + def __init__(self, differ_class=None, dialect=None): + self._differ_class = differ_class + if dialect is None: + dialect = 'cli' + self._dialect = dialect + + def compare(self, ref: Schema, hyp: Schema): + differ = get_differ(ref, hyp, differ_class=self._differ_class) + a = traversible(ref) + b = traversible(hyp) + return format_diff(a, b, differ.get_opcodes(), + dialect=self._dialect, + preprocessor=lambda x: ' %s' % (' '.join(x),)) + + +class WER(Base): + """ + Word Error Rate, basically defined as: + + .. code-block:: text + + insertions + deletions + substitions + ------------------------------------ + number of reference words + + See: https://en.wikipedia.org/wiki/Word_error_rate + """ + + # TODO: proper documenting of different modes + MODE_STRICT = 0 + MODE_HUNT = 1 + + DEL_PENALTY = 1 + INS_PENALTY = 1 + SUB_PENALTY = 1 + + def __init__(self, mode=None, differ_class=None): + if differ_class is None: + differ_class = RatcliffObershelp + self._differ_class = differ_class + if mode is self.MODE_HUNT: + self.DEL_PENALTY = self.INS_PENALTY = .5 + + def compare(self, ref: Schema, hyp: Schema): + diffs = get_differ(ref, hyp, differ_class=self._differ_class) + + counts = get_opcode_counts(diffs.get_opcodes()) + + changes = counts.replace * self.SUB_PENALTY + \ + counts.delete * self.DEL_PENALTY + \ + counts.insert * self.INS_PENALTY + + return changes / (counts.equal + changes) + + +class DiffCounts(Base): + """ + Get the amount of differences between reference and hypothesis + """ + + def __init__(self, differ_class=None): + if differ_class is None: + differ_class = RatcliffObershelp + self._differ_class = differ_class + + def compare(self, ref: Schema, hyp: Schema): + diffs = get_differ(ref, hyp, differ_class=self._differ_class) + return get_opcode_counts(diffs.get_opcodes()) diff --git a/src/benchmarkstt/modules.py b/src/benchmarkstt/modules.py new file mode 100644 index 00000000..d3c10f66 --- /dev/null +++ b/src/benchmarkstt/modules.py @@ -0,0 +1,32 @@ +import sys +from importlib import import_module + +_modules = ['normalization', 'metrics'] + +if sys.version_info >= (3, 6): + _modules.append('api') + + +class Modules: + def __init__(self, sub_module): + self._postfix = '' if sub_module is None else '.' + sub_module + + def __iter__(self): + for module in _modules: + try: + yield (module, self[module]) + except IndexError: + pass + + def __getattr__(self, name): + return self[name] + + def __getitem__(self, key): + name = 'benchmarkstt.%s%s' % (key, self._postfix) + try: + return import_module(name) + except ImportError: + raise IndexError('Module not found', key) + + def keys(self): + return [key for key, value in iter(self)] diff --git a/src/benchmarkstt/normalization/__init__.py b/src/benchmarkstt/normalization/__init__.py index ad013ae3..e2873351 100644 --- a/src/benchmarkstt/normalization/__init__.py +++ b/src/benchmarkstt/normalization/__init__.py @@ -1,99 +1,34 @@ -from collections import namedtuple -import inspect -from importlib import import_module -from typing import Dict -from benchmarkstt.docblock import format_docs +from benchmarkstt.normalization.logger import log +import logging +from benchmarkstt.factory import Factory _normalizer_namespaces = ( "benchmarkstt.normalization.core", "" ) -NormalizerConfig = namedtuple('NormalizerConfig', ['name', 'cls', 'docs', 'optional_args', 'required_args']) +logger = logging.getLogger(__name__) -def is_normalizer(cls): - return inspect.isclass(cls) and hasattr(cls, 'normalize') +class Base: + @log + def normalize(self, text: str) -> str: + """ + Returns normalized text with rules supplied by the called class. + """ + return self._normalize(text) -def available_normalizers() -> Dict[str, NormalizerConfig]: - normalizers = {} - core = import_module('benchmarkstt.normalization.core') - for cls in dir(core): - name = cls.lower() - cls = getattr(core, cls) - if not is_normalizer(cls): - continue - - docs = format_docs(cls.__doc__) - # docs = docs.split(':param', 1)[0] - # remove rst blocks - # docs = re.sub(r'^\s*\.\. [a-z-]+::\s+[a-z]+\s*$', '', docs, flags=re.MULTILINE) - - argspec = inspect.getfullargspec(cls.__init__) - args = list(argspec.args)[1:] - defaults = [] - if argspec.defaults: - defaults = list(argspec.defaults) - - defaults_idx = len(args) - len(defaults) - required_args = args[0:defaults_idx] - optional_args = args[defaults_idx:] - - normalizers[name] = NormalizerConfig(name=name, cls=cls, docs=docs, - optional_args=optional_args, required_args=required_args) - - return normalizers - - -def name_to_normalizer(name): - """ - Loads the proper normalizer based on a name - - :param str name: Case-insensitive name of the normalizer - :return: The normalization class - :rtype: class - """ - requested = name.split('.') - requested_module = [] - - if len(requested) > 1: - requested_module = requested[:-1] - - requested_class = requested[-1] - lname = requested_class.lower() - for lookup in _normalizer_namespaces: - try: - module = '.'.join(filter(len, lookup.split('.') + requested_module)) - if module == '': - continue - module = import_module(module) - - if hasattr(module, requested_class): - cls = getattr(module, requested_class) - if inspect.isclass(cls) and hasattr(cls, 'normalize'): - return cls - - # fallback, check case-insensitive matches - realname = [class_name for class_name in dir(module) - if class_name.lower() == lname and - is_normalizer(getattr(module, class_name))] + def _normalize(self, text: str) -> str: + raise NotImplementedError() - if len(realname) > 1: - raise ImportError("Cannot determine which class to use for '$s': %s" % - (lname, repr(realname))) - elif len(realname): - return getattr(module, realname[0]) - except ModuleNotFoundError: - pass - raise ImportError("Could not find normalizer '%s'" % (name,)) +factory = Factory(Base, _normalizer_namespaces) -class NormalizationComposite: +class NormalizationComposite(Base): """ Combining normalizers - """ def __init__(self): @@ -104,7 +39,7 @@ def add(self, normalizer): """ self._normalizers.append(normalizer) - def normalize(self, text: str) -> str: + def _normalize(self, text: str) -> str: # allow for an empty file if not self._normalizers: return text diff --git a/src/benchmarkstt/normalization/api.py b/src/benchmarkstt/normalization/api.py new file mode 100644 index 00000000..770f39de --- /dev/null +++ b/src/benchmarkstt/normalization/api.py @@ -0,0 +1,42 @@ +from benchmarkstt.normalization.logger import ListHandler, DiffLoggingFormatter, normalize_logger +import json +import benchmarkstt.csv as csv +import benchmarkstt.normalization as normalization + +factory = normalization.factory + + +def callback(cls, text: str, return_logs: bool = None, *args, **kwargs): + """ + :param str text: The text to normalize + :param bool return_logs: Return normalizer logs + """ + if return_logs: + handler = ListHandler() + handler.setFormatter(DiffLoggingFormatter(dialect='html')) + normalize_logger.addHandler(handler) + + try: + result = { + "text": cls(*args, **kwargs).normalize(text) + } + if return_logs: + logs = handler.flush() + result['logs'] = [] + for log in logs: + result['logs'].append(dict(names=log[0], message=log[1])) + return result + except csv.CSVParserError as e: + message = 'on line %d, character %d' % (e.line, e.char) + message = '\n'.join([e.__doc__, e.message, message]) + data = { + "message": message, + "line": e.line, + "char": e.char, + "index": e.index, + "field": "config" + } + raise AssertionError(json.dumps(data)) + finally: + if return_logs: + normalize_logger.removeHandler(handler) diff --git a/src/benchmarkstt/normalization/cli.py b/src/benchmarkstt/normalization/cli.py index dbf59990..8ad84be8 100644 --- a/src/benchmarkstt/normalization/cli.py +++ b/src/benchmarkstt/normalization/cli.py @@ -3,66 +3,18 @@ """ import sys -from . import core, NormalizationComposite +from . import NormalizationComposite import argparse -from . import available_normalizers, name_to_normalizer -import textwrap -import itertools -from . import logger +from . import factory +from .logger import DiffLoggingFormatter, normalize_logger import logging +from benchmarkstt.cli import args_from_factory -class _NormalizerAction: - """ - Placeholder class to recognize an argument is a NormalizerAction in argparse - """ - - -def normalizer_action(required_args, optional_args): - """ - Custom argparse action to support a variable amount of arguments - :param list required_args: required arguments - :param list optional_args: optional arguments - :rtype: NormalizerAction - """ - - minlen = len(required_args) - maxlen = minlen + len(optional_args) - - class NormalizerAction(argparse.Action, _NormalizerAction): - def __call__(self, parser, args, values, option_string=None): - if len(values) < minlen or len(values) > maxlen: - raise argparse.ArgumentTypeError('argument "%s" requires between %d and %d arguments (got %d)' % - (self.dest, minlen, maxlen, len(values))) - - if 'normalizers' not in args: - args.normalizers = [] - - args.normalizers.append([self.dest] + values) - - return NormalizerAction - - -class Formatter(argparse.HelpFormatter): - """ - Custom formatter for argparse that allows us to properly display _NormalizerActions and docblock documentation - """ - - def _format_args(self, action, default_metavar): - if isinstance(action, _NormalizerAction): - return ' '.join(action.metavar) - - return super()._format_args(action, default_metavar) - - def _split_lines(self, text, width): - def wrap(txt): - if txt == '': - return [''] - return textwrap.wrap(txt, width=width) - - text = text.splitlines() - text = list(itertools.chain.from_iterable(map(wrap, text))) - return text +def args_inputfile(parser): + parser.add_argument('-i', '--inputfile', action='append', nargs=1, + help='read input from this file, defaults to STDIN', + metavar='file') def argparser(parser: argparse.ArgumentParser): @@ -70,6 +22,9 @@ def argparser(parser: argparse.ArgumentParser): Adds the help and arguments specific to this module """ + parser.add_argument('--log', action='store_true', + help='show normalizer logs') + files_desc = """ You can provide multiple input and output files, each preceded by -i and -o respectively. @@ -79,15 +34,10 @@ def argparser(parser: argparse.ArgumentParser): output file.""" files = parser.add_argument_group('input and output files', description=files_desc) - - files.add_argument('-i', '--inputfile', action='append', nargs=1, - help='read input from this file, defaults to STDIN', - metavar='file') + args_inputfile(files) files.add_argument('-o', '--outputfile', action='append', nargs=1, help='write output to this file, defaults to STDOUT', metavar='file') - files.add_argument('--log', action='store_true', - help='show normalizer logs') normalizers_desc = """ A list of normalizers to execute on the input, can be one or more normalizers @@ -97,23 +47,7 @@ def argparser(parser: argparse.ArgumentParser): At least one normalizer needs to be provided.""" normalizers = parser.add_argument_group('available normalizers', description=normalizers_desc) - - for name, conf in available_normalizers().items(): - docs = conf.docs - - arguments = dict() - arguments['help'] = docs - arguments['nargs'] = 0 - - if len(conf.required_args) or len(conf.optional_args): - arguments['nargs'] = '+' - optionals = list(map(lambda x: '[%s]' % x, conf.optional_args)) - arguments['metavar'] = tuple(conf.required_args + optionals) - - arguments['action'] = normalizer_action(conf.required_args, conf.optional_args) - - normalizers.add_argument('--%s' % (name,), **arguments) - + args_from_factory('normalizers', factory, normalizers) return parser @@ -133,15 +67,15 @@ def main(parser, args): if args.log: handler = logging.StreamHandler() - handler.setFormatter(logger.DiffLoggingFormatter('cli')) + handler.setFormatter(DiffLoggingFormatter('cli')) handler.setLevel(logging.INFO) - logger.normalize_logger.addHandler(handler) + normalize_logger.addHandler(handler) composite = NormalizationComposite() for item in args.normalizers: normalizer_name = item.pop(0).replace('-', '.') - cls = name_to_normalizer(normalizer_name) - composite.add(cls(*item)) + normalizer = factory.create(normalizer_name, *item) + composite.add(normalizer) if output_files is not None: # pre-open the output files before doing the grunt work diff --git a/src/benchmarkstt/normalization/core.py b/src/benchmarkstt/normalization/core.py index 6fb07288..a592cd68 100644 --- a/src/benchmarkstt/normalization/core.py +++ b/src/benchmarkstt/normalization/core.py @@ -1,15 +1,6 @@ """ Some basic/simple normalization classes - -Each normalization class has a method called `normalize`: - -.. code-block:: python - - def normalize(text: str) -> str: - "\""Returns normalized text with rules supplied by the called class. - "\"" - """ import re @@ -19,12 +10,12 @@ def normalize(text: str) -> str: import inspect from langcodes import best_match, standardize_tag from benchmarkstt import csv, normalization -from benchmarkstt.normalization.logger import log +from benchmarkstt.normalization import Base default_encoding = 'UTF-8' -class LocalizedFile: +class LocalizedFile(Base): """ Reads and applies normalization rules from a locale-based file, it will automatically determine the "best fit" for a given locale, if one is @@ -67,12 +58,11 @@ def __init__(self, normalizer, locale: str, path: str, encoding=None): self._normalizer = File(normalizer, file, encoding=encoding) - @log - def normalize(self, text: str) -> str: + def _normalize(self, text: str) -> str: return self._normalizer.normalize(text) -class Replace: +class Replace(Base): """ Simple search replace @@ -89,12 +79,11 @@ def __init__(self, search: str, replace: str = ''): self._search = search self._replace = replace - @log - def normalize(self, text: str) -> str: + def _normalize(self, text: str) -> str: return text.replace(self._search, self._replace) -class ReplaceWords: +class ReplaceWords(Base): """ Simple search replace that only replaces "words", the first letter will be checked case insensitive as well with preservation of case.. @@ -127,12 +116,11 @@ def _replacement_callback(self, matches): return ''.join([self._replace[0].lower(), self._replace[1:]]) - @log - def normalize(self, text: str) -> str: + def _normalize(self, text: str) -> str: return self._pattern.sub(self._replacement_callback, text) -class File: +class File(Base): """ Read one per line and pass it to the given normalizer @@ -149,8 +137,10 @@ class File: def __init__(self, normalizer, file, encoding=None): try: - cls = normalizer if inspect.isclass(normalizer) else \ - normalization.name_to_normalizer(normalizer) + if inspect.isclass(normalizer): + cls = normalizer + else: + cls = normalization.factory.get_class(normalizer) except ValueError: raise ValueError("Unknown normalizer %s" % (repr(normalizer))) @@ -167,12 +157,11 @@ def __init__(self, normalizer, file, encoding=None): except TypeError as e: raise ValueError("Line %d: %s" % (line.lineno, str(e))) - @log - def normalize(self, text: str) -> str: + def _normalize(self, text: str) -> str: return self._normalizer.normalize(text) -class RegexReplace: +class RegexReplace(Base): r""" Simple regex replace. By default the pattern is interpreted case-sensitive. @@ -212,8 +201,7 @@ def __init__(self, search: str, replace: str = None): self._pattern = re.compile(search) self._substitution = replace if replace is not None else '' - @log - def normalize(self, text: str) -> str: + def _normalize(self, text: str) -> str: return self._pattern.sub(self._substitution, text) @@ -242,7 +230,7 @@ def __init__(self): super().__init__(r'[^\w]+') -class Lowercase: +class Lowercase(Base): """ Lowercase the text @@ -251,12 +239,11 @@ class Lowercase: :example return: "easy, mungo, easy... mungo..." """ - @log - def normalize(self, text: str) -> str: + def _normalize(self, text: str) -> str: return text.lower() -class Unidecode: +class Unidecode(Base): """ Unidecode characters to ASCII form, see `Python's Unidecode package `_ for more info. @@ -265,12 +252,11 @@ class Unidecode: :example return: "Wenn ist das Nunstuck git und Slotermeyer?" """ - @log - def normalize(self, text: str) -> str: + def _normalize(self, text: str) -> str: return unidecode(text) -class Config: +class Config(Base): r""" Use config notation to define normalization rules. This notation is a list of normalizers, one per line, with optional arguments (separated by a @@ -305,13 +291,22 @@ class Config: :param str config: configuration text :example text: "He bravely turned his tail and fled" - :example config: '''# using a simple config file\nLowercase \n - # it even supports comments - # If there is a space in the argument, make sure you quote it though! - regexreplace "y t" "Y T" - \n\n - # extraneous whitespaces are ignored - replace e a\n''' + :example config: + + .. code-block:: text + + # using a simple config file + Lowercase + + # it even supports comments + # If there is a space in the argument, + # make sure you quote it though! + + regexreplace "y t" "Y T" + + # extraneous whitespaces are ignored + replace e a + :example return: "ha bravalY Turnad his tail and flad" """ @@ -322,14 +317,13 @@ def _parse_config(self, file): self._normalizer = normalization.NormalizationComposite() for line in csv.reader(file, dialect='whitespace'): try: - normalizer = normalization.name_to_normalizer(line[0]) + normalizer = normalization.factory.get_class(line[0]) except ValueError: raise ValueError("Unknown normalizer %s on line %d: %s" % (repr(line[0]), line.lineno, repr(' '.join(line)))) self._normalizer.add(normalizer(*line[1:])) - @log - def normalize(self, text: str) -> str: + def _normalize(self, text: str) -> str: return self._normalizer.normalize(text) diff --git a/src/benchmarkstt/normalization/logger.py b/src/benchmarkstt/normalization/logger.py index 37acb22f..3b71646d 100644 --- a/src/benchmarkstt/normalization/logger.py +++ b/src/benchmarkstt/normalization/logger.py @@ -1,9 +1,6 @@ -import difflib -from benchmarkstt import make_printable import logging -from markupsafe import escape import os -from benchmarkstt import DeferredRepr +from benchmarkstt.diff.formatter import DiffFormatter normalize_logger = logging.getLogger('benchmarkstt.normalize') normalize_logger.setLevel(logging.INFO) @@ -11,28 +8,6 @@ normalize_stack = [] -class CLIDiffDialect: - preprocessor = make_printable - delete_format = '\033[31m%s\033[0m' - insert_format = '\033[32m%s\033[0m' - formats = None - - @staticmethod - def format(names, diff): - return '|'.join(names) + ': ' + diff - - -class HTMLDiffDialect: - preprocessor = escape - delete_format = '%s' - insert_format = '%s' - formats = None - - @staticmethod - def format(names, diff): - return names, diff - - class ListHandler(logging.StreamHandler): def __init__(self): self._logs = [] @@ -50,55 +25,13 @@ def flush(self): class DiffLoggingFormatter(logging.Formatter): def __init__(self, dialect=None): - self._differ = Differ(dialect) + self._differ = DiffFormatter(dialect) super().__init__() def format(self, record): return self._differ.format(record) -class Differ: - diff_dialects = { - "cli": CLIDiffDialect, - "html": HTMLDiffDialect - } - - def __init__(self, dialect=None): - if dialect is None: - dialect = 'cli' - if dialect not in self.diff_dialects: - raise ValueError("Unknown diff dialect", dialect) - self._dialect = self.diff_dialects[dialect] - - def format(self, record): - return self._dialect.format(record.args[0], self.diff(record.args[1], record.args[2])) - - def diff(self, a, b): - dialect = self._dialect - preprocessor = dialect.preprocessor - cruncher = difflib.SequenceMatcher(None, a, b) - - if dialect.formats is None: - formats = { - 'replace': dialect.delete_format + dialect.insert_format, - 'delete': dialect.delete_format + '%s', - 'insert': '%s' + dialect.insert_format, - 'equal': '%s', - } - - res = [] - for tag, alo, ahi, blo, bhi in cruncher.get_opcodes(): - a_ = preprocessor(a[alo:ahi]) - - if tag == 'equal': - res.append(formats['equal'] % (preprocessor(a[alo:ahi]),)) - continue - - b_ = preprocessor(b[blo:bhi]) - res.append(formats[tag] % (a_, b_)) - return ''.join(res) - - def log(func): """ Log decorator for normalization classes diff --git a/src/benchmarkstt/schema.py b/src/benchmarkstt/schema.py new file mode 100644 index 00000000..d2ac311e --- /dev/null +++ b/src/benchmarkstt/schema.py @@ -0,0 +1,177 @@ +""" +Defines the main schema for comparison and implements json serialization +""" +import json +from collections.abc import Mapping +from typing import Union +from collections import defaultdict + + +class SchemaError(ValueError): + """Top Error class for all schema related exceptions""" + + +class SchemaJSONError(SchemaError): + """When loading incompatible JSON""" + + +class SchemaInvalidItemError(SchemaError): + """Attempting to add an invalid item""" + + +class Item(Mapping): + """ + Basic structure of each field to compare + + :raises: ValueError, SchemaInvalidItemError + """ + + def __init__(self, *args, **kwargs): + if len(args) > 1: + raise ValueError('Expected max 1 argument') + if len(args) and len(kwargs): + raise ValueError("Cannot combine both a positional and keyword arguments") + if len(args): + if not isinstance(args[0], dict): + raise SchemaInvalidItemError("Expected a dict object", args[0]) + self._val = args[0] + else: + self._val = dict(**kwargs) + self.meta = Meta() + + def __getitem__(self, k): + return self._val[k] + + def __len__(self) -> int: + return len(self._val) + + def __iter__(self): + return iter(self._val) + + def __repr__(self): + return 'Item(%s)' % (self.json(),) + + def json(self, **kwargs): + return Schema.dumps(self, **kwargs) + + def _asdict(self): + return self._val + + def __eq__(self, other): + if type(other) is Item: + other = other._asdict() + return self._val == other + + def __ne__(self, other): + return self._val != other + + +class Meta(defaultdict): + """Containing metadata for an item, such as skipped""" + + +class Schema: + """ + Basically a list of :py:class:`Item` + """ + + def __init__(self, data=None): + # make Schema.dump/dumps methods available as instance methods + self.dump = self.__dump + self.dumps = self.__dumps + if data is None: + self._data = [] + else: + self._data = [item if type(item) is Item else Item(item) for item in data] + + def __repr__(self): + return 'Schema(%s)' % (self.json(),) + + def __len__(self): + return len(self._data) + + def __iter__(self): + return iter(self._data) + + def __getitem__(self, item): + return self._data[item] + + """ + :raises: SchemaJSONError + """ + @staticmethod + def load(*args, **kwargs): + return json.load(*args, **kwargs, cls=JSONDecoder) + + """ + :raises: SchemaJSONError + """ + @staticmethod + def loads(*args, **kwargs): + return json.loads(*args, **kwargs, cls=JSONDecoder) + + @staticmethod + def dump(cls, *args, **kwargs): + return json.dump(cls, *args, **kwargs, cls=JSONEncoder) + + @staticmethod + def dumps(cls, *args, **kwargs): + return json.dumps(cls, *args, **kwargs, cls=JSONEncoder) + + def __dump(self, *args, **kwargs): + return Schema.dump(self, *args, **kwargs) + + def __dumps(self, *args, **kwargs): + return Schema.dumps(self, *args, **kwargs) + + def json(self, **kwargs): + return self.dumps(**kwargs) + + def append(self, obj: Union[Item, dict]): + if isinstance(obj, dict): + obj = Item(obj) + elif type(obj) is not Item: + raise SchemaError("Wrong type", type(obj)) + self._data.append(obj) + + def extend(self, iterable): + self._data.extend((item if type(item) is Item else Item(item) for item in iterable)) + + def _aslist(self): + return self._data + + def __eq__(self, other): + if type(other) is Schema: + other = other._aslist() + return self._data == other + + def __ne__(self, other): + return self._data != other + + +class JSONEncoder(json.JSONEncoder): + """Custom JSON encoding for schema""" + + def default(self, obj): + if isinstance(obj, Schema): + return obj._aslist() + if isinstance(obj, Item): + return obj._asdict() + return super().default(obj) + + +class JSONDecoder(json.JSONDecoder): + """Custom JSON decoding for schema""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs, object_hook=self.object_hook) + + def decode(self, *args, **kwargs): + result = super().decode(*args, **kwargs) + if type(result) is not list: + raise SchemaJSONError("Expected a list") + return Schema(result) + + @staticmethod + def object_hook(obj): + return Item(obj) diff --git a/src/benchmarkstt/segmentation/__init__.py b/src/benchmarkstt/segmentation/__init__.py new file mode 100644 index 00000000..f5ceed87 --- /dev/null +++ b/src/benchmarkstt/segmentation/__init__.py @@ -0,0 +1,14 @@ +from benchmarkstt.factory import Factory + + +class Base: + def __iter__(self): + """ + Each segmentation class should be accessible as iterator, each iteration should + return a Item, so the input format is essentially usable and can be easily + converted to a :py:class:`benchmarkstt.schema.Schema` + """ + raise NotImplementedError() + + +factory = Factory(Base) diff --git a/src/benchmarkstt/segmentation/core.py b/src/benchmarkstt/segmentation/core.py new file mode 100644 index 00000000..56fb54a1 --- /dev/null +++ b/src/benchmarkstt/segmentation/core.py @@ -0,0 +1,38 @@ +""" +Core segmenters, each segmenter must be Iterable returning a Item +""" + +import re +from benchmarkstt.schema import Item +from benchmarkstt.segmentation import Base + + +class Simple(Base): + """ + Simplest case, split into words by white space + """ + + def __init__(self, text: str, pattern=r'[\n\t\s]+'): + self._text = text + self._re = re.compile('(%s)' % (pattern,)) + + def __iter__(self): + start_match = self._re.match(self._text) + iterable = self._re.split(self._text) + if iterable[0] == '': + iterable.pop(0) + + pos = 0 + length = len(iterable) + + # special case, starts with word break, add it to first word + if start_match is not None: + matches = iterable[0:3] + pos = 3 + yield Item({"item": matches[1], "type": "word", "@raw": ''.join(matches)}) + + while pos < length: + raw = ''.join(iterable[pos:pos+2]) + if raw != '': + yield Item({"item": iterable[pos], "type": "word", "@raw": raw}) + pos += 2 diff --git a/tests/_data/a.txt b/tests/_data/a.txt new file mode 100644 index 00000000..495577c9 --- /dev/null +++ b/tests/_data/a.txt @@ -0,0 +1 @@ +TEST my data should be one difference diff --git a/tests/_data/b.txt b/tests/_data/b.txt new file mode 100644 index 00000000..e24d2ae9 --- /dev/null +++ b/tests/_data/b.txt @@ -0,0 +1 @@ +TEST my data should be ONE difference diff --git a/tests/_data/candide.txt b/tests/_data/candide.txt new file mode 100644 index 00000000..73cd5e33 --- /dev/null +++ b/tests/_data/candide.txt @@ -0,0 +1,10 @@ + +"There is a concatenation of events in this best of all possible worlds: +for if you had not been kicked out of a magnificent castle for love of +Miss Cunegonde: if you had not been put into the Inquisition: if you had +not walked over America: if you had not stabbed the Baron: if you had +not lost all your sheep from the fine country of El Dorado: you would +not be here eating preserved citrons and pistachio-nuts." + +"All that is very well," answered Candide, "but let us cultivate our +garden." diff --git a/tests/_data/csv.debugging.output.txt b/tests/_data/csv.debugging.output.txt new file mode 100644 index 00000000..7a003fda --- /dev/null +++ b/tests/_data/csv.debugging.output.txt @@ -0,0 +1,2 @@ +MODES: COMMENT FIRST INSIDE INSIDE_QUOTED INSIDE_QUOTED_QUOTE OUTSIDE +␊Some·line,·some·other·␉··␉␊dsfgdsg␊␊·␊·········␉··␍␊·␊·␍␊␊"stay","togther··"␊␊#·commented·out␊fsdss␊ diff --git a/tests/benchmarkstt/normalization/test_init.py b/tests/benchmarkstt/normalization/test_init.py deleted file mode 100644 index 8a45f3e3..00000000 --- a/tests/benchmarkstt/normalization/test_init.py +++ /dev/null @@ -1,35 +0,0 @@ -from benchmarkstt.normalization import core -from benchmarkstt.normalization import NormalizationComposite, name_to_normalizer, is_normalizer -from benchmarkstt.normalization import NormalizerConfig, available_normalizers - - -def test_name_to_normalizer(): - assert name_to_normalizer('Replace') is core.Replace - assert name_to_normalizer('replace') is core.Replace - - -def test_available_normalizers(): - normalizers = available_normalizers() - assert type(normalizers) is dict - - for name, conf in normalizers.items(): - # normalizer = name_to_normalizer(name) - assert type(conf) is NormalizerConfig - assert is_normalizer(conf.cls) - assert name_to_normalizer(name.upper()) is conf.cls - - -def test_is_normalizer(): - nope = [ - True, - False, - None, - "replace", - is_normalizer, - NormalizerConfig - ] - - for not_normalizer in nope: - assert is_normalizer(not_normalizer) is False - - assert is_normalizer(NormalizationComposite) is True diff --git a/tests/benchmarkstt/test_baseclasses.py b/tests/benchmarkstt/test_baseclasses.py new file mode 100644 index 00000000..f0eee5da --- /dev/null +++ b/tests/benchmarkstt/test_baseclasses.py @@ -0,0 +1,25 @@ +from benchmarkstt.normalization import Base as NormalizationBase +from benchmarkstt.metrics import Base as MetricsBase +from benchmarkstt.diff import Base as DiffBase +from benchmarkstt.segmentation import Base as SegmentationBase +from benchmarkstt.input import Base as InputBase +import pytest +from inspect import signature + + +@pytest.mark.parametrize('base_class,methods', [ + [NormalizationBase, '_normalize'], + [MetricsBase, 'compare'], + [DiffBase, ['__init__', 'get_opcodes']], + [SegmentationBase, '__iter__'], + [InputBase, '__iter__'], +]) +def test_baseclasses(base_class, methods): + if type(methods) is str: + methods = [methods] + for method in methods: + to_call = getattr(base_class, method) + sig = signature(to_call) + with pytest.raises(NotImplementedError): + args = [None] * len(sig.parameters) + to_call(*args) diff --git a/tests/benchmarkstt/test_benchmarkstt.py b/tests/benchmarkstt/test_benchmarkstt.py new file mode 100644 index 00000000..45a09dac --- /dev/null +++ b/tests/benchmarkstt/test_benchmarkstt.py @@ -0,0 +1,50 @@ +from benchmarkstt import DeferredRepr, DeferredStr, make_printable +import pytest + + +def cb(txt): + def _(): + _.cb_count += 1 + return '[%s]' % (txt,) + _.cb_count = 0 + return _ + + +class ToDefer: + def __init__(self, value): + self.value = value + self.cb_count = 0 + + def __repr__(self): + self.cb_count += 1 + return '' % (repr(self.value),) + + +def test_deferred_str(): + callback = cb('test') + deferred = DeferredStr(callback) + assert callback.cb_count == 0 + assert str(deferred) == '[test]' + assert callback.cb_count == 1 + assert repr(deferred) == '[test]' + assert callback.cb_count == 2 + + +def test_deferred_repr(): + callback = ToDefer(None) + deferred = DeferredRepr(callback) + assert callback.cb_count == 0 + assert str(deferred) == '' + assert callback.cb_count == 1 + assert repr(deferred) == '' + assert callback.cb_count == 2 + + +@pytest.mark.parametrize('orig,printable', [ + ['', ''], + [' ', '·'], + ['I\'m afraid I\thave no choice\nbut to sell you all for\tscientific experiments.', + 'I\'m·afraid·I␉have·no·choice␊but·to·sell·you·all·for␉scientific·experiments.'] +]) +def test_make_printable(orig, printable): + assert make_printable(orig) == printable diff --git a/tests/benchmarkstt/test_cli.py b/tests/benchmarkstt/test_cli.py new file mode 100644 index 00000000..5adc888b --- /dev/null +++ b/tests/benchmarkstt/test_cli.py @@ -0,0 +1,70 @@ +import pytest +import sys +from textwrap import dedent +from benchmarkstt.cli import main +from unittest import mock + +from benchmarkstt.__meta__ import __version__ + + +candide_lowercase = """ +"there is a concatenation of events in this best of all possible worlds: +for if you had not been kicked out of a magnificent castle for love of +miss cunegonde: if you had not been put into the inquisition: if you had +not walked over america: if you had not stabbed the baron: if you had +not lost all your sheep from the fine country of el dorado: you would +not be here eating preserved citrons and pistachio-nuts." + +"all that is very well," answered candide, "but let us cultivate our +garden." +""" + + +@pytest.mark.parametrize('argv,result', [ + [[], 2], + ['--version', 'benchmarkstt: %s\n' % (__version__,)], + ['invalidsubmodule', 2], + ['normalization', 2], + ['--help', 0], + ['normalization -i tests/_data/candide.txt --lowercase', candide_lowercase], + ['normalization -i tests/_data/candide.txt --file', 2], + ['metrics -r tests/_data/a.txt -h tests/_data/b.txt', 2], + ['metrics -r "HI" -h "HELLO" -rt argument -ht argument --wer', "wer\n===\n\n1.000000\n\n"], + ['metrics -r tests/_data/a.txt -h tests/_data/b.txt --wer --worddiffs --diffcounts', + dedent(''' + wer + === + + 0.142857 + + worddiffs + ========= + + ·TEST·my·data·should·be\033[31m·one\033[0m\033[32m·ONE\033[0m·difference + + diffcounts + ========== + + OpcodeCounts(equal=6, replace=1, insert=0, delete=0) + + ''').lstrip()] +]) +def test_cli(argv, result, capsys): + if type(argv) is str: + argv = argv.split() + with mock.patch('sys.argv', ['benchmarkstt'] + argv): + if type(result) is int: + with pytest.raises(SystemExit) as err: + main() + assert str(err).endswith(': %d' % (result,)) + else: + with pytest.raises(SystemExit) as err: + main() + assert str(err).endswith(': 0') + + captured = capsys.readouterr() + if type(result) is list: + assert captured.out == result[0] + assert captured.err == result[1] + else: + assert captured.out == result diff --git a/tests/benchmarkstt/test_csv.py b/tests/benchmarkstt/test_csv.py index b97a4395..47829631 100644 --- a/tests/benchmarkstt/test_csv.py +++ b/tests/benchmarkstt/test_csv.py @@ -2,11 +2,28 @@ import pytest from io import StringIO +example1 = ''' +Some line, some other \t \t +dsfgdsg + + \n \t \r + \n \r + +"stay","togther " + +# commented out +fsdss +''' + +expected1 = [['Some line', 'some other'], ['dsfgdsg'], ['stay', 'togther '], ['fsdss']] -def test_csv(): - def _reader(text): - return list(reader(StringIO(text))) +def get_reader(text, *args, **kwargs): + return list(reader(StringIO(text), *args, **kwargs)) + + +def test_csv(): + _reader = get_reader assert _reader('replace," ","\n"') == [['replace', ' ', '\n']] assert type(reader(StringIO(''))) is Reader assert type(Reader(StringIO(''), DefaultDialect)) is Reader @@ -15,34 +32,12 @@ def _reader(text): assert _reader('') == [] - expected = [['Some line', 'some other'], ['dsfgdsg'], ['stay', 'togther '], - ['fsdss']] - assert _reader(''' - Some line, some other \t \t - dsfgdsg - - \n \t \r - \n \r - - "stay","togther " - - # commented out - fsdss - ''') == expected - - with pytest.raises(CSVParserError): - _reader('stray"quote') + assert _reader(example1) == expected1 assert _reader('"","test"," quiot"""') == [['', 'test', ' quiot"']] assert _reader(' val1 ,\t val2 \n') == [['val1', 'val2']] - with pytest.raises(UnclosedQuoteError): - _reader(' s ,"') - - with pytest.raises(UnallowedQuoteError): - _reader(' fsd","') - assert _reader(' ","') == [[',']] assert _reader('""') == [['']] @@ -58,6 +53,7 @@ def _reader(text): assert _reader('\t t ') == [['t']] assert _reader('t') == [['t']] assert _reader('replace," ","\n"') == [['replace', ' ', '\n']] + assert _reader(',') == [['', '']] def test_conf(): @@ -111,3 +107,49 @@ def _reader(text): assert _reader('test "stuff\n\t"\n\t \t YEs \t \n') == \ [['test', 'stuff\n\t'], ['YEs']] assert _reader("\n\n\n\nline5")[0].lineno == 5 + + +def test_exceptions(): + _reader = get_reader + + with pytest.raises(InvalidDialectError): + Reader(StringIO(''), dialect=InvalidDialectError) + + with pytest.raises(UnknownDialectError): + _reader('', dialect='notknown') + + with pytest.raises(UnallowedQuoteError) as exc: + _reader('test "') + + assert "Quote not allowed here" in str(exc) + + with pytest.raises(CSVParserError): + _reader('stray"quote') + + with pytest.raises(UnclosedQuoteError) as exc: + _reader(' s ,"') + + assert "Unexpected end" in str(exc) + + with pytest.raises(UnallowedQuoteError): + _reader(' fsd","') + + with pytest.raises(UnallowedQuoteError) as exc: + _reader('""test,') + assert "Single quote inside quoted field" in str(exc) + + +def test_own_dialect(): + class OwnDialect(Dialect): + delimiter = ';' + + assert get_reader("Tester \n No Trim ", dialect=OwnDialect) == [['Tester '], [' No Trim ']] + + +def test_debugger(capsys): + gotten = get_reader(example1, debug=True) + + assert gotten == expected1 + with open('tests/_data/csv.debugging.output.txt', encoding='UTF-8') as f: + expected_debug = f.read() + assert capsys.readouterr().out == expected_debug diff --git a/tests/benchmarkstt/test_decorators.py b/tests/benchmarkstt/test_decorators.py new file mode 100644 index 00000000..7984cf74 --- /dev/null +++ b/tests/benchmarkstt/test_decorators.py @@ -0,0 +1,44 @@ +import logging +from benchmarkstt.decorators import log_call + + +def test_log_call(caplog): + logger = logging.getLogger() + + @log_call(logger, logging.WARNING) + def test(*args, **kwargs): + return 'result' + + test('arg1', arg2='someval') + assert caplog.record_tuples == [ + ('root', logging.WARNING, "test('arg1', arg2='someval')") + ] + + +def test_log_call2(caplog): + logger = logging.getLogger('testname') + caplog.set_level(logging.INFO) + + @log_call(logger, result=True, log_level=logging.INFO) + def test(*args, **kwargs): + return 'result' + + test(arg2='someval') + assert caplog.record_tuples == [ + ('testname', logging.INFO, "test(arg2='someval')"), + ('testname', logging.INFO, 'test returned: result') + ] + + +def test_log_call3(caplog): + logger = logging.getLogger('testname') + caplog.set_level(logging.DEBUG) + + @log_call(logger) + def funcname(): + return None + + funcname() + assert caplog.record_tuples == [ + ('testname', logging.DEBUG, "funcname()"), + ] diff --git a/tests/benchmarkstt/test_diff.py b/tests/benchmarkstt/test_diff.py new file mode 100644 index 00000000..6df3b705 --- /dev/null +++ b/tests/benchmarkstt/test_diff.py @@ -0,0 +1,24 @@ +from benchmarkstt import diff +import pytest + +differs = [differ.cls for differ in diff.factory] +differs_decorator = pytest.mark.parametrize('differ', differs) + + +@differs_decorator +def test_one_insert(differ): + sm = differ('b' * 100, 'a' + 'b' * 100) + assert list(sm.get_opcodes()) == [('insert', 0, 0, 0, 1), + ('equal', 0, 100, 1, 101)] + sm = differ('b' * 100, 'b' * 50 + 'a' + 'b' * 50) + assert list(sm.get_opcodes()) == [('equal', 0, 50, 0, 50), + ('insert', 50, 50, 50, 51), + ('equal', 50, 100, 51, 101)] + + +@differs_decorator +def test_one_delete(differ): + sm = differ('a' * 40 + 'c' + 'b' * 40, 'a' * 40 + 'b' * 40) + assert list(sm.get_opcodes()) == [('equal', 0, 40, 0, 40), + ('delete', 40, 41, 40, 40), + ('equal', 41, 81, 40, 80)] diff --git a/tests/benchmarkstt/test_diff_formatter.py b/tests/benchmarkstt/test_diff_formatter.py new file mode 100644 index 00000000..3bdd0087 --- /dev/null +++ b/tests/benchmarkstt/test_diff_formatter.py @@ -0,0 +1,21 @@ +from benchmarkstt.diff.formatter import format_diff +import pytest + +a = 'ABCDEFGHJKLMN' +b = 'ABBCDEFHHIJKLM' + + +@pytest.mark.parametrize('dialect,expected', [ + ['text', 'AB\u0359BCDEFG\u0338HH\u0359I\u0359JKLMN\u0338'], + ['cli', 'A\033[32mB\033[0mBCDEF\033[31mG\033[0mH\033[32mHI\033[0mJKLM\033[31mN\033[0m'], + ['html', 'ABBCDEFG' + 'HHIJKLMN'] +]) +def test_format_diff(dialect, expected): + gotten = format_diff(a, b, dialect=dialect) + assert gotten == expected + assert gotten == expected + + +def test_no_diff(): + assert format_diff(a, a, dialect='cli') == a diff --git a/tests/benchmarkstt/test_docblock.py b/tests/benchmarkstt/test_docblock.py new file mode 100644 index 00000000..72175f42 --- /dev/null +++ b/tests/benchmarkstt/test_docblock.py @@ -0,0 +1,113 @@ +from benchmarkstt.docblock import * +from textwrap import dedent + + +def test_text(): + txt = " \t \t\n\n\t " + + assert process_rst(txt, 'text') == '' + + txt = ''' + .. code-block:: text + + Some block + In samem block + + Still included + + + Not anymore +''' + print(process_rst(txt, 'text')) + + assert process_rst(txt, 'text') == """Some block +In samem block + +Still included + +Not anymore""" + + +def test_parse(): + def dummy_func(config): + """ + The normalization rules are applied top-to-bottom and follow this format: + + .. code-block:: text + + Normalizer1 arg1 "arg 2" + # This is a comment + + Normalizer2 + # (Normalizer2 has no arguments) + Normalizer3 "This is argument 1 + Spanning multiple lines + " "argument 2" + Normalizer4 "argument with double quote ("")" + + :param str config: configuration text + + :example text: "He bravely turned his tail and fled" + :example config: + + .. code-block:: text + + # using a simple config file + Lowercase + + # it even supports comments + # If there is a space in the argument, + # make sure you quote it though! + + regexreplace "y t" "Y T" + + # extraneous whitespaces are ignored + replace e a + + :example return: "ha bravalY Turnad his tail and flad" + """ + + expected = Docblock( + docs=dedent(''' + The normalization rules are applied top-to-bottom and follow this format: + + .. code-block:: text + + Normalizer1 arg1 "arg 2" + # This is a comment + + Normalizer2 + # (Normalizer2 has no arguments) + Normalizer3 "This is argument 1 + Spanning multiple lines + " "argument 2" + Normalizer4 "argument with double quote ("")" + ''').strip(), + params=[Param(name='config', type=None, type_doc='str', is_required=True, description='configuration text', + examples=[ + {'text': DocblockParam(name='text', type=None, value='He bravely turned his tail and fled'), + 'config': DocblockParam(name='config', type=None, + value=dedent(''' + # using a simple config file + Lowercase + + # it even supports comments + # If there is a space in the argument, + # make sure you quote it though! + + regexreplace "y t" "Y T" + + # extraneous whitespaces are ignored + replace e a''').strip()), + 'return': DocblockParam(name='return', type=None, + value='ha bravalY Turnad his tail and flad') + } + ] + ) + ], + result=None, + result_type=None) + + parsed = parse(dummy_func) + assert parsed.docs == expected.docs + # todo: test the other Docblock properties as well diff --git a/tests/benchmarkstt/test_factory.py b/tests/benchmarkstt/test_factory.py new file mode 100644 index 00000000..7e6f3500 --- /dev/null +++ b/tests/benchmarkstt/test_factory.py @@ -0,0 +1,56 @@ +from benchmarkstt.factory import Factory +from pytest import raises + +module_name = __name__ + + +class Base: + pass + + +class ValidClass(Base): + pass + + +class InvalidClass: + pass + + +def test_factory_exception(): + factory = Factory(Base, [test_factory_exception.__module__]) + assert factory.get_class('validclass') == ValidClass + with raises(ValueError) as exc: + factory.register(InvalidClass) + + assert "Invalid class (must inherit from Base class)" in str(exc) + + +def test_factory(): + factory = Factory(Base, []) + factory.register(ValidClass) + assert factory.get_class('validclass') == ValidClass + factory.register(ValidClass, 'alias') + assert factory.get_class('alias') == ValidClass + + assert type(factory.create('alias')) == ValidClass + + with raises(ValueError) as exc: + factory.register(ValidClass) + + assert "Conflict: alias 'validclass' is already registered" in str(exc) + + with raises(ValueError) as exc: + factory.register_namespace(module_name) + + assert "Conflict: alias 'validclass' is already registered" in str(exc) + + +def test_nodoclog(caplog): + factory = Factory(Base, [test_nodoclog.__module__]) + + for conf in factory: + pass + + assert caplog.record_tuples == [ + ('benchmarkstt.factory', 30, "No docstring for 'ValidClass'") + ] diff --git a/tests/benchmarkstt/test_input_core.py b/tests/benchmarkstt/test_input_core.py new file mode 100644 index 00000000..e7b45e53 --- /dev/null +++ b/tests/benchmarkstt/test_input_core.py @@ -0,0 +1,120 @@ +from benchmarkstt.input.core import PlainText, File +from benchmarkstt.schema import Item, Schema +import pytest + +candide_file = 'tests/_data/candide.txt' +with open(candide_file) as f: + candide = f.read() + +candide_schema = [Item({"item": "\"There", "type": "word", "@raw": "\n\"There "}), + Item({"item": "is", "type": "word", "@raw": "is "}), + Item({"item": "a", "type": "word", "@raw": "a "}), + Item({"item": "concatenation", "type": "word", "@raw": "concatenation "}), + Item({"item": "of", "type": "word", "@raw": "of "}), + Item({"item": "events", "type": "word", "@raw": "events "}), + Item({"item": "in", "type": "word", "@raw": "in "}), + Item({"item": "this", "type": "word", "@raw": "this "}), + Item({"item": "best", "type": "word", "@raw": "best "}), + Item({"item": "of", "type": "word", "@raw": "of "}), + Item({"item": "all", "type": "word", "@raw": "all "}), + Item({"item": "possible", "type": "word", "@raw": "possible "}), + Item({"item": "worlds:", "type": "word", "@raw": "worlds:\n"}), + Item({"item": "for", "type": "word", "@raw": "for "}), + Item({"item": "if", "type": "word", "@raw": "if "}), + Item({"item": "you", "type": "word", "@raw": "you "}), + Item({"item": "had", "type": "word", "@raw": "had "}), + Item({"item": "not", "type": "word", "@raw": "not "}), + Item({"item": "been", "type": "word", "@raw": "been "}), + Item({"item": "kicked", "type": "word", "@raw": "kicked "}), + Item({"item": "out", "type": "word", "@raw": "out "}), + Item({"item": "of", "type": "word", "@raw": "of "}), + Item({"item": "a", "type": "word", "@raw": "a "}), + Item({"item": "magnificent", "type": "word", "@raw": "magnificent "}), + Item({"item": "castle", "type": "word", "@raw": "castle "}), + Item({"item": "for", "type": "word", "@raw": "for "}), + Item({"item": "love", "type": "word", "@raw": "love "}), + Item({"item": "of", "type": "word", "@raw": "of\n"}), + Item({"item": "Miss", "type": "word", "@raw": "Miss "}), + Item({"item": "Cunegonde:", "type": "word", "@raw": "Cunegonde: "}), + Item({"item": "if", "type": "word", "@raw": "if "}), + Item({"item": "you", "type": "word", "@raw": "you "}), + Item({"item": "had", "type": "word", "@raw": "had "}), + Item({"item": "not", "type": "word", "@raw": "not "}), + Item({"item": "been", "type": "word", "@raw": "been "}), + Item({"item": "put", "type": "word", "@raw": "put "}), + Item({"item": "into", "type": "word", "@raw": "into "}), + Item({"item": "the", "type": "word", "@raw": "the "}), + Item({"item": "Inquisition:", "type": "word", "@raw": "Inquisition: "}), + Item({"item": "if", "type": "word", "@raw": "if "}), + Item({"item": "you", "type": "word", "@raw": "you "}), + Item({"item": "had", "type": "word", "@raw": "had\n"}), + Item({"item": "not", "type": "word", "@raw": "not "}), + Item({"item": "walked", "type": "word", "@raw": "walked "}), + Item({"item": "over", "type": "word", "@raw": "over "}), + Item({"item": "America:", "type": "word", "@raw": "America: "}), + Item({"item": "if", "type": "word", "@raw": "if "}), + Item({"item": "you", "type": "word", "@raw": "you "}), + Item({"item": "had", "type": "word", "@raw": "had "}), + Item({"item": "not", "type": "word", "@raw": "not "}), + Item({"item": "stabbed", "type": "word", "@raw": "stabbed "}), + Item({"item": "the", "type": "word", "@raw": "the "}), + Item({"item": "Baron:", "type": "word", "@raw": "Baron: "}), + Item({"item": "if", "type": "word", "@raw": "if "}), + Item({"item": "you", "type": "word", "@raw": "you "}), + Item({"item": "had", "type": "word", "@raw": "had\n"}), + Item({"item": "not", "type": "word", "@raw": "not "}), + Item({"item": "lost", "type": "word", "@raw": "lost "}), + Item({"item": "all", "type": "word", "@raw": "all "}), + Item({"item": "your", "type": "word", "@raw": "your "}), + Item({"item": "sheep", "type": "word", "@raw": "sheep "}), + Item({"item": "from", "type": "word", "@raw": "from "}), + Item({"item": "the", "type": "word", "@raw": "the "}), + Item({"item": "fine", "type": "word", "@raw": "fine "}), + Item({"item": "country", "type": "word", "@raw": "country "}), + Item({"item": "of", "type": "word", "@raw": "of "}), + Item({"item": "El", "type": "word", "@raw": "El "}), + Item({"item": "Dorado:", "type": "word", "@raw": "Dorado: "}), + Item({"item": "you", "type": "word", "@raw": "you "}), + Item({"item": "would", "type": "word", "@raw": "would\n"}), + Item({"item": "not", "type": "word", "@raw": "not "}), + Item({"item": "be", "type": "word", "@raw": "be "}), + Item({"item": "here", "type": "word", "@raw": "here "}), + Item({"item": "eating", "type": "word", "@raw": "eating "}), + Item({"item": "preserved", "type": "word", "@raw": "preserved "}), + Item({"item": "citrons", "type": "word", "@raw": "citrons "}), + Item({"item": "and", "type": "word", "@raw": "and "}), + Item({"item": "pistachio-nuts.\"", "type": "word", "@raw": "pistachio-nuts.\"\n\n"}), + Item({"item": "\"All", "type": "word", "@raw": "\"All "}), + Item({"item": "that", "type": "word", "@raw": "that "}), + Item({"item": "is", "type": "word", "@raw": "is "}), + Item({"item": "very", "type": "word", "@raw": "very "}), + Item({"item": "well,\"", "type": "word", "@raw": "well,\" "}), + Item({"item": "answered", "type": "word", "@raw": "answered "}), + Item({"item": "Candide,", "type": "word", "@raw": "Candide, "}), + Item({"item": "\"but", "type": "word", "@raw": "\"but "}), + Item({"item": "let", "type": "word", "@raw": "let "}), + Item({"item": "us", "type": "word", "@raw": "us "}), + Item({"item": "cultivate", "type": "word", "@raw": "cultivate "}), + Item({"item": "our", "type": "word", "@raw": "our\n"}), + Item({"item": "garden.\"", "type": "word", "@raw": "garden.\"\n"})] + + +@pytest.mark.parametrize('cls,args', [ + [PlainText, [candide]], + [File, [candide_file]], + [File, [candide_file, 'infer']], + [File, [candide_file, 'plaintext']], +]) +def test_file(cls, args): + assert list(cls(*args)) == candide_schema + assert Schema(cls(*args)) == candide_schema + + +def test_exceptions(): + with pytest.raises(ValueError) as e: + File('noextension') + assert 'without an extension' in str(e) + + with pytest.raises(ValueError) as e: + File('unknownextension.thisisntknowm') + assert 'thisisntknowm' in str(e) diff --git a/tests/benchmarkstt/normalization/test_core.py b/tests/benchmarkstt/test_normalization_core.py similarity index 100% rename from tests/benchmarkstt/normalization/test_core.py rename to tests/benchmarkstt/test_normalization_core.py diff --git a/tests/benchmarkstt/test_normalization_init.py b/tests/benchmarkstt/test_normalization_init.py new file mode 100644 index 00000000..1d2cb817 --- /dev/null +++ b/tests/benchmarkstt/test_normalization_init.py @@ -0,0 +1,44 @@ +from benchmarkstt.normalization import core +from benchmarkstt.normalization import NormalizationComposite +from benchmarkstt.normalization import factory +from benchmarkstt.factory import ClassConfig +from inspect import isgenerator +import pytest + + +def test_name_to_normalizer(): + assert factory.get_class('Replace') is core.Replace + assert factory.get_class('replace') is core.Replace + + +def test_available_normalizers(): + normalizers = iter(factory) + assert isgenerator(normalizers) + + for conf in normalizers: + name = conf.name + assert type(conf) is ClassConfig + assert factory.is_valid(conf.cls) + assert factory.get_class(name.upper()) is conf.cls + + +def test_not_available_normalizers(): + with pytest.raises(ImportError): + factory.get_class('SomeRandomUnavailableNormalizer') + + +def test_is_normalizer(): + nope = [ + True, + False, + None, + "replace", + factory.is_valid, + ClassConfig + ] + + for not_normalizer in nope: + assert factory.is_valid(not_normalizer) is False + + assert factory.is_valid(NormalizationComposite) is True + assert NormalizationComposite().normalize('NON-normalized') == 'NON-normalized' diff --git a/tests/benchmarkstt/test_schema.py b/tests/benchmarkstt/test_schema.py new file mode 100644 index 00000000..7d973cd4 --- /dev/null +++ b/tests/benchmarkstt/test_schema.py @@ -0,0 +1,144 @@ +from benchmarkstt.schema import Schema, Item, JSONEncoder +from benchmarkstt.schema import SchemaError, SchemaJSONError, SchemaInvalidItemError +import textwrap +from random import sample, randint +import pytest +from json.decoder import JSONDecodeError +import json +from collections import OrderedDict +from pytest import raises +import io + + +def test_equality(): + assert Schema.loads('[]') == Schema() + assert Schema([Item(item='test')]) != Schema() + assert Item(item='test') == {'item': 'test'} + assert Item({'item': 'test', 'item2': 55}) == Item(item='test', item2=55) + assert Item({'item2': 55, 'item': 'test'}) == Item(item='test', item2=55) + + +def test_encode(): + item = Item(item='word', start=12, end=23) + itemdict = item._asdict() + line = json.dumps(itemdict) + line_formatted = json.dumps(itemdict, indent=2) + + assert item.json() == line + assert item.json(indent=2) == line_formatted + + buffer = io.StringIO() + Schema.dump(Schema([item]), buffer) + assert ('[%s]' % (item.json(),)) == buffer.getvalue() + + buffer = io.StringIO() + Schema([item]).dump(buffer) + assert ('[%s]' % (item.json(),)) == buffer.getvalue() + + schema = Schema() + schema.append(item) + schema.append(item) + assert len(schema) is 2 + assert schema.json() == '[%s, %s]' % ((line,) * 2) + assert schema.json(indent=2) == '[\n%s,\n%s\n]' % ((textwrap.indent(line_formatted, ' '),) * 2) + assert repr(schema) == ('Schema(%s)' % (schema.json())) + + class T: + ok = False + + with raises(TypeError) as exc: + assert json.dumps(T(), cls=JSONEncoder) + assert "is not JSON serializable" in str(exc) + + +def test_decode(): + res = Schema.loads('[{"item": "test"}]') + + assert type(res) is Schema + assert len(res) is 1 + assert type(res[0]) is Item + + schema = Schema.load(io.StringIO(res.json())) + assert len(schema) is 1 + assert type(schema[0]) is Item + assert schema == res + + with pytest.raises(SchemaJSONError) as exc: + Schema.loads('{"test": "test"}') + assert "Expected a list" in str(exc) + + with pytest.raises(JSONDecodeError): + Schema.loads('InvalidJSON') + + with pytest.raises(SchemaJSONError) as exc: + Schema.loads('"test"') + assert "Expected a list" in str(exc) + + with pytest.raises(SchemaJSONError) as exc: + Schema.loads('24') + assert "Expected a list" in str(exc) + + with pytest.raises(SchemaInvalidItemError): + Schema.loads('["test"]') + + +def random_str(minlen=None, maxlen=None): + """ + Returns a random (printable) utf-8 string between approx. minlen and maxlen characters + :return: str + """ + if minlen is None: + minlen = 10 + if maxlen is None: + maxlen = 30 + return ''.join(chr(i) for i in sample(range(1, 0x10ffff), randint(minlen, maxlen)) if chr(i).isprintable()) + + +def test_roundtrip(): + schema = Schema() + testlen = 1 + for i in range(testlen): + schema.append(dict(item=random_str(), start=randint(0, 1e10), end=randint(0, 1e10))) + schema.append(Item(OrderedDict(item=random_str(), start=randint(0, 1e10), end=randint(0, 1e10)))) + + schema.extend(list(schema)) + assert len(schema) == testlen * 4 + + for item in schema: + assert type(item) is Item + + json_ = schema.json() + assert Schema.loads(json_) == schema + schema = Schema.loads(json_) + for item in schema: + assert type(item) is Item + + +def test_exceptions(): + with raises(ValueError) as exc: + Item(None, None) + assert 'Expected max 1 argument' in str(exc) + + with raises(ValueError) as exc: + Item(None, somekeyword=None) + assert "Cannot combine both a positional and keyword arguments" in str(exc) + + schema = Schema() + with raises(SchemaError) as exc: + schema.append(None) + assert "Wrong type" in str(exc) + + +def test_item(): + item1 = Item() + assert len(item1) == 0 + assert repr(item1) == 'Item({})' + + item = Item(a='a_', b='b_') + assert len(item) == 2 + for k in item: + assert k+'_' == item[k] + assert repr(item) in ['Item({"a": "a_", "b": "b_"})', + 'Item({"b": "b_", "a": "a_"})'] + + assert item1 != item diff --git a/tests/benchmarkstt/test_segmentation_core.py b/tests/benchmarkstt/test_segmentation_core.py new file mode 100644 index 00000000..f9fd494d --- /dev/null +++ b/tests/benchmarkstt/test_segmentation_core.py @@ -0,0 +1,28 @@ +from benchmarkstt.segmentation import core +from benchmarkstt.schema import Item +import pytest + + +@pytest.mark.parametrize('text,expected', [ + ('hello world! how are you doing?! ', ['hello ', 'world! ', 'how ', 'are ', 'you ', 'doing?! ']), + ('\nhello world! how are you doing?! ', ['\nhello ', 'world! ', 'how ', 'are ', 'you ', 'doing?! ']), + ('single-word', ['single-word']), + (' test', [' test']), + (' test', [' test']), + (' test ', [' test ']), + ('test ', ['test ']), + ('test B', ['test ', 'B']), + ('test B ', ['test ', 'B ']), + ('\n\n', ['\n\n']) +]) +def test_simple(text, expected): + result = list(core.Simple(text)) + assert ''.join([word['@raw'] for word in result]) == text + assert len(result) == len(expected) + + for i in range(0, len(expected)): + expected_raw = expected[i] + gotten = result[i] + assert type(gotten) is Item + assert expected_raw == gotten['@raw'] + assert expected_raw.strip() == gotten['item'] diff --git a/tox.ini b/tox.ini index 49020953..af4a321f 100644 --- a/tox.ini +++ b/tox.ini @@ -3,8 +3,9 @@ envlist = py35,py36,py37 [testenv] commands = pytest {posargs} -deps = .[test,api] +deps = .[test] [pycodestyle] count = False max-line-length = 120 +