diff --git a/.github/workflows/check_links.yml b/.github/workflows/check_links.yml new file mode 100644 index 000000000000..da36848cfd50 --- /dev/null +++ b/.github/workflows/check_links.yml @@ -0,0 +1,28 @@ +name: Check for Broken Links + +on: + repository_dispatch: + workflow_dispatch: + schedule: + - cron: "0 0 * * *" + +jobs: + check_for_broken_links: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + + - name: Link Checker + id: lychee + uses: lycheeverse/lychee-action@v1 + with: + args: './**/*.md' + fail: true + + - name: Create Issue From File + if: env.lychee_exit_code != 0 + uses: peter-evans/create-issue-from-file@v4 + with: + title: Link Checker Report + content-filepath: ./lychee/out.md + labels: report, automated issue \ No newline at end of file diff --git a/.github/workflows/doctests.yml b/.github/workflows/doctests.yml new file mode 100644 index 000000000000..6f2bbc173c8a --- /dev/null +++ b/.github/workflows/doctests.yml @@ -0,0 +1,79 @@ +name: Doctests + +on: + push: + branches: + - doctest* + repository_dispatch: + schedule: + - cron: "0 0 * * *" + +env: + HF_HOME: /mnt/cache + RUN_SLOW: yes + OMP_NUM_THREADS: 16 + MKL_NUM_THREADS: 16 + +jobs: + run_doctests: + runs-on: [single-gpu, nvidia-gpu, a10, ci] + container: + image: huggingface/diffusers-all-latest-gpu + options: --gpus 0 --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/ + + steps: + - name: Checkout diffusers + uses: actions/checkout@v3 + with: + fetch-depth: 2 + + - name: NVIDIA-SMI + run: | + nvidia-smi + + - name: Install dependencies + run: python3 -m pip install -e .[quality,test,training] + + - name: Environment + run: | + python3 utils/print_env.py + + - name: Get doctest files + run: | + $(python3 -c 'from utils.tests_fetcher import get_all_doctest_files; to_test = get_all_doctest_files(); to_test = " ".join(to_test); fp = open("doc_tests.txt", "w"); fp.write(to_test); fp.close()') + + - name: Run doctests + env: + HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} + run: | + python3 -m pytest -v --make-reports doc_tests_gpu --doctest-modules $(cat doc_tests.txt) -sv --doctest-continue-on-failure --doctest-glob="*.md" + + - name: Failure short reports + if: ${{ failure() }} + continue-on-error: true + run: cat reports/doc_tests_gpu/failures_short.txt + + - name: Test suite reports artifacts + if: ${{ always() }} + uses: actions/upload-artifact@v3 + with: + name: doc_tests_gpu_test_reports + path: reports/doc_tests_gpu + + send_results: + name: Send results to webhook + runs-on: ubuntu-22.04 + if: always() + needs: [run_doctests] + steps: + - uses: actions/checkout@v3 + - uses: actions/download-artifact@v3 + - name: Send message to Slack + env: + CI_SLACK_BOT_TOKEN: ${{ secrets.CI_SLACK_BOT_TOKEN }} + CI_SLACK_CHANNEL_ID: ${{ secrets.CI_SLACK_CHANNEL_ID_DAILY_DOCS }} + CI_SLACK_CHANNEL_ID_DAILY: ${{ secrets.CI_SLACK_CHANNEL_ID_DAILY_DOCS }} + CI_SLACK_CHANNEL_DUMMY_TESTS: ${{ secrets.CI_SLACK_CHANNEL_DUMMY_TESTS }} + run: | + pip install slack_sdk + python utils/notification_service_doc_tests.py diff --git a/Makefile b/Makefile index c92285b48c71..0ea26399a9ca 100644 --- a/Makefile +++ b/Makefile @@ -36,6 +36,7 @@ repo-consistency: python utils/check_dummies.py python utils/check_repo.py python utils/check_inits.py + python utils/check_doctest_list.py # this target runs checks on all files @@ -67,6 +68,7 @@ fixup: modified_only_fixup extra_style_checks autogenerate_code repo-consistency fix-copies: python utils/check_copies.py
--fix_and_overwrite python utils/check_dummies.py --fix_and_overwrite + python utils/check_doctest_list.py --fix_and_overwrite # Run tests for the library diff --git a/pyproject.toml b/pyproject.toml index 0612f2f9e059..55de12d36bfa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,3 +25,7 @@ skip-magic-trailing-comma = false # Like Black, automatically detect the appropriate line ending. line-ending = "auto" + +[tool.pytest.ini_options] +doctest_optionflags="NUMBER NORMALIZE_WHITESPACE ELLIPSIS" +doctest_glob="**/*.md" \ No newline at end of file diff --git a/src/diffusers/doctest_utils.py b/src/diffusers/doctest_utils.py new file mode 100644 index 000000000000..2157db4ac2b5 --- /dev/null +++ b/src/diffusers/doctest_utils.py @@ -0,0 +1,185 @@ +import doctest +import inspect +import os +import re +from typing import Iterable + +from .utils import is_pytest_available + + +if is_pytest_available(): + from _pytest.doctest import ( + Module, + _get_checker, + _get_continue_on_failure, + _get_runner, + _is_mocked, + _patch_unwrap_mock_aware, + get_optionflags, + import_path, + ) + from _pytest.outcomes import skip + from pytest import DoctestItem +else: + Module = object + DoctestItem = object + +""" +The following contains utils to run the documentation tests without having to overwrite any files. + +The `preprocess_string` function adds `# doctest: +IGNORE_RESULT` markers on the fly anywhere a `load_dataset` call is +made as a print would otherwise fail the corresponding line. + +To skip CUDA tests, run `SKIP_CUDA_DOCTEST=1 pytest --doctest-modules`. +""" + + +def preprocess_string(string, skip_cuda_tests): + """Prepare a docstring or a `.md` file to be run by doctest. + + The argument `string` would be the whole file content if it is a `.md` file. For a python file, it would be one of +its docstrings. In each case, it may contain multiple python code examples. If `skip_cuda_tests` is `True` and +CUDA usage is detected (via a heuristic), this method will return an empty string so no doctest will be run for +`string`. + """ + codeblock_pattern = r"(```(?:python|py)\s*\n\s*>>> )((?:(?!```)[\s\S])*?```)" + codeblocks = re.split(re.compile(codeblock_pattern, flags=re.MULTILINE | re.DOTALL), string) + is_cuda_found = False + for i, codeblock in enumerate(codeblocks): + if "load_dataset(" in codeblock and "# doctest: +IGNORE_RESULT" not in codeblock: + codeblocks[i] = re.sub(r"(>>> .*load_dataset\(.*)", r"\1 # doctest: +IGNORE_RESULT", codeblock) + if ( + (">>>" in codeblock or "..." in codeblock) + and re.search(r"cuda|to\(0\)|device=0", codeblock) + and skip_cuda_tests + ): + is_cuda_found = True + break + + modified_string = "" + if not is_cuda_found: + modified_string = "".join(codeblocks) + + return modified_string + + +class HfDocTestParser(doctest.DocTestParser): + """ + Overwrites the DocTestParser from doctest to properly parse the codeblocks that are formatted with black. This + means that there are no extra lines at the end of our snippets. The `# doctest: +IGNORE_RESULT` marker is also + added anywhere a `load_dataset` call is made as a print would otherwise fail the corresponding line. + + Tests involving cuda are skipped based on a naive pattern that should be updated if it is not enough. + """ + + # This regular expression is used to find doctest examples in a + # string.
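# The pattern mirrors the stock `doctest.DocTestParser._EXAMPLE_RE`, with an HF-specific tweak to the `want` group (marked inline below) so that black-formatted code blocks parse cleanly.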
It defines three groups: `source` is the source code (including leading indentation and prompts); `indent` is the indentation of the first (PS1) line of the source code; and `want` is the expected output (including leading indentation). + # fmt: off + _EXAMPLE_RE = re.compile(r''' + # Source consists of a PS1 line followed by zero or more PS2 lines. + (?P<source> + (?:^(?P<indent> [ ]*) >>> .*) # PS1 line + (?:\n [ ]* \.\.\. .*)*) # PS2 lines + \n? + # Want consists of any non-blank lines that do not start with PS1. + (?P<want> (?:(?![ ]*$) # Not a blank line + (?![ ]*>>>) # Not a line starting with PS1 + # !!!!!!!!!!! HF Specific !!!!!!!!!!! + (?:(?!```).)* # Match any character except '`' until a '```' is found (this is specific to HF because black removes the last line) + # !!!!!!!!!!! HF Specific !!!!!!!!!!! + (?:\n|$) # Match a new line or end of string + )*) + ''', re.MULTILINE | re.VERBOSE + ) + # fmt: on + + # !!!!!!!!!!! HF Specific !!!!!!!!!!! + skip_cuda_tests: bool = bool(os.environ.get("SKIP_CUDA_DOCTEST", False)) + # !!!!!!!!!!! HF Specific !!!!!!!!!!! + + def parse(self, string, name=""): + """ + Overwrites the `parse` method to incorporate a skip for CUDA tests, and remove logs and dataset prints before + calling `super().parse` + """ + string = preprocess_string(string, self.skip_cuda_tests) + return super().parse(string, name) + + +class HfDoctestModule(Module): + """ + Overwrites the `DoctestModule` of the pytest package to make sure the `HfDocTestParser` is used when discovering + tests. + """ + + def collect(self) -> Iterable["DoctestItem"]: + class MockAwareDocTestFinder(doctest.DocTestFinder): + """A hackish doctest finder that overrides stdlib internals to fix a stdlib bug. + + https://github.com/pytest-dev/pytest/issues/3456 https://bugs.python.org/issue25532 + """ + + def _find_lineno(self, obj, source_lines): + """Doctest code does not take into account `@property`, this + is a hackish way to fix it. https://bugs.python.org/issue17446 + + Wrapped Doctests will need to be unwrapped so the correct line number is returned. This will be + reported upstream. #8796 + """ + if isinstance(obj, property): + obj = getattr(obj, "fget", obj) + + if hasattr(obj, "__wrapped__"): + # Get the main obj in case of it being wrapped + obj = inspect.unwrap(obj) + + # Type ignored because this is a private function. + return super()._find_lineno( # type:ignore[misc] + obj, + source_lines, + ) + + def _find(self, tests, obj, name, module, source_lines, globs, seen) -> None: + if _is_mocked(obj): + return + with _patch_unwrap_mock_aware(): + # Type ignored because this is a private function. + super()._find( # type:ignore[misc] + tests, obj, name, module, source_lines, globs, seen + ) + + if self.path.name == "conftest.py": + module = self.config.pluginmanager._importconftest( + self.path, + self.config.getoption("importmode"), + rootpath=self.config.rootpath, + ) + else: + try: + module = import_path( + self.path, + root=self.config.rootpath, + mode=self.config.getoption("importmode"), + ) + except ImportError: + if self.config.getvalue("doctest_ignore_import_errors"): + skip("unable to import module %r" % self.path) + else: + raise + + # !!!!!!!!!!! HF Specific !!!!!!!!!!! + finder = MockAwareDocTestFinder(parser=HfDocTestParser()) + # !!!!!!!!!!! HF Specific !!!!!!!!!!!
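+ # From here on this mirrors pytest's stock `DoctestModule.collect`: build a runner with the configured doctest option flags and yield one `DoctestItem` per non-empty doctest found.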
+ optionflags = get_optionflags(self) + runner = _get_runner( + verbose=False, + optionflags=optionflags, + checker=_get_checker(), + continue_on_failure=_get_continue_on_failure(self.config), + ) + for test in finder.find(module, module.__name__): + if test.examples: # skip empty doctests and cuda + yield DoctestItem.from_parent(self, name=test.name, runner=runner, dtest=test) diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 35aba10d7e58..7bf73e1a3a3c 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -69,6 +69,7 @@ is_note_seq_available, is_onnx_available, is_peft_available, + is_pytest_available, is_scipy_available, is_tensorboard_available, is_torch_available, diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index 9c916737d104..e355d473ce5e 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -278,6 +278,13 @@ except importlib_metadata.PackageNotFoundError: _peft_available = False +_pytest_available = importlib.util.find_spec("pytest") is not None +try: + _pytest_version = importlib_metadata.version("pytest") + logger.debug(f"Successfully imported pytest version {_pytest_version}") +except importlib_metadata.PackageNotFoundError: + _pytest_available = False + _torchvision_available = importlib.util.find_spec("torchvision") is not None try: + _torchvision_version = importlib_metadata.version("torchvision") @@ -374,6 +381,10 @@ def is_peft_available(): return _peft_available +def is_pytest_available(): + return _pytest_available + + def is_torchvision_available(): return _torchvision_available diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index edbf6f31a833..d19838eb537b 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -779,7 +779,7 @@ class CaptureLogger: >>> logger = logging.get_logger("diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.py") >>> with CaptureLogger(logger) as cl: ... logger.info(msg) - >>> assert cl.out, msg + "\n" + >>> assert cl.out, msg + "\\n" ``` """ diff --git a/utils/check_doctest_list.py b/utils/check_doctest_list.py new file mode 100644 index 000000000000..89eb981211e9 --- /dev/null +++ b/utils/check_doctest_list.py @@ -0,0 +1,85 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script is responsible for cleaning the list of doctests by making sure the entries all exist and are in +alphabetical order.
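+The list checked here is `utils/not_doctested.txt`; each line is a path, relative to the repo root, that is excluded from the doctests.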
+ +Usage (from the root of the repo): + +Check that the doctest list is properly sorted and all files exist (used in `make repo-consistency`): + +```bash +python utils/check_doctest_list.py +``` + +Auto-sort the doctest list if it is not properly sorted (used in `make fix-copies`): + +```bash +python utils/check_doctest_list.py --fix_and_overwrite +``` +""" +import argparse +import os + + +# All paths are set with the intent you should run this script from the root of the repo with the command +# python utils/check_doctest_list.py REPO_PATH = "." DOCTEST_FILE_PATHS = ["not_doctested.txt"] + + +def clean_doctest_list(doctest_file: str, overwrite: bool = False): + """ + Cleans the doctest list in a given file. + + Args: + doctest_file (`str`): + The path to the doctest file to check or clean. + overwrite (`bool`, *optional*, defaults to `False`): + Whether or not to fix problems. If `False`, will error when the file is not clean. + """ + non_existent_paths = [] + all_paths = [] + with open(doctest_file, "r", encoding="utf-8") as f: + for line in f: + line = line.strip().split(" ")[0] + path = os.path.join(REPO_PATH, line) + if not (os.path.isfile(path) or os.path.isdir(path)): + non_existent_paths.append(line) + all_paths.append(line) + + if len(non_existent_paths) > 0: + non_existent_paths = "\n".join([f"- {f}" for f in non_existent_paths]) + raise ValueError(f"`{doctest_file}` contains non-existent paths:\n{non_existent_paths}") + + sorted_paths = sorted(all_paths) + if all_paths != sorted_paths: + if not overwrite: + raise ValueError( + f"Files in `{doctest_file}` are not in alphabetical order, run `make fix-copies` to fix " + "this automatically." + ) + with open(doctest_file, "w", encoding="utf-8") as f: + f.write("\n".join(sorted_paths) + "\n") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--fix_and_overwrite", action="store_true", help="Whether to fix inconsistencies.") + args = parser.parse_args() + + for doctest_file in DOCTEST_FILE_PATHS: + doctest_file = os.path.join(REPO_PATH, "utils", doctest_file) + clean_doctest_list(doctest_file, args.fix_and_overwrite) diff --git a/utils/not_doctested.txt b/utils/not_doctested.txt new file mode 100644 index 000000000000..b6998d0eb0e9 --- /dev/null +++ b/utils/not_doctested.txt @@ -0,0 +1,100 @@ +docs/source/en/training/create_dataset.md +docs/source/en/training/wuerstchen.md +docs/source/en/training/adapt_a_model.md +docs/source/en/training/text2image.md +docs/source/en/training/custom_diffusion.md +docs/source/en/training/sdxl.md +docs/source/en/training/unconditional_training.md +docs/source/en/training/overview.md +docs/source/en/training/t2i_adapters.md +docs/source/en/training/lcm_distill.md +docs/source/en/training/instructpix2pix.md +docs/source/en/training/kandinsky.md +docs/source/en/training/lora.md +docs/source/en/training/controlnet.md +docs/source/en/training/dreambooth.md +docs/source/en/training/ddpo.md +docs/source/en/training/text_inversion.md +docs/source/en/training/distributed_inference.md +docs/source/en/optimization/torch2.0.md +docs/source/en/optimization/coreml.md +docs/source/en/optimization/tome.md +docs/source/en/optimization/xformers.md +docs/source/en/optimization/deepcache.md +docs/source/en/optimization/fp16.md +docs/source/en/optimization/memory.md +docs/source/en/optimization/habana.md +docs/source/en/optimization/open_vino.md +docs/source/en/optimization/mps.md +docs/source/en/optimization/opt_overview.md +docs/source/en/optimization/onnx.md
+docs/source/en/tutorials/basic_training.md +docs/source/ko/index.md +docs/source/ko/quicktour.md +docs/source/ko/in_translation.md +docs/source/ko/installation.md +docs/source/ko/stable_diffusion.md +docs/source/ko/training/create_dataset.md +docs/source/ko/training/wuerstchen.md +docs/source/ko/training/adapt_a_model.md +docs/source/ko/training/text2image.md +docs/source/ko/training/custom_diffusion.md +docs/source/ko/training/sdxl.md +docs/source/ko/training/unconditional_training.md +docs/source/ko/training/overview.md +docs/source/ko/training/t2i_adapters.md +docs/source/ko/training/lcm_distill.md +docs/source/ko/training/instructpix2pix.md +docs/source/ko/training/kandinsky.md +docs/source/ko/training/lora.md +docs/source/ko/training/controlnet.md +docs/source/ko/training/dreambooth.md +docs/source/ko/training/ddpo.md +docs/source/ko/training/text_inversion.md +docs/source/ko/training/distributed_inference.md +docs/source/ko/optimization/torch2.0.md +docs/source/ko/optimization/coreml.md +docs/source/ko/optimization/tome.md +docs/source/ko/optimization/xformers.md +docs/source/ko/optimization/deepcache.md +docs/source/ko/optimization/fp16.md +docs/source/ko/optimization/memory.md +docs/source/ko/optimization/habana.md +docs/source/ko/optimization/open_vino.md +docs/source/ko/optimization/mps.md +docs/source/ko/optimization/opt_overview.md +docs/source/ko/optimization/onnx.md +docs/source/ko/api/pipelines/stable_diffusion/stable_diffusion_xl.md +docs/source/ko/tutorials/basic_training.md +docs/source/ko/using-diffusers/loading.md +docs/source/ko/using-diffusers/unconditional_image_generation.md +docs/source/ko/using-diffusers/depth2img.md +docs/source/ko/using-diffusers/control_brightness.md +docs/source/ko/using-diffusers/contribute_pipeline.md +docs/source/ko/using-diffusers/img2img.md +docs/source/ko/using-diffusers/weighted_prompts.md +docs/source/ko/using-diffusers/schedulers.md +docs/source/ko/using-diffusers/custom_pipeline_examples.md +docs/source/ko/using-diffusers/using_safetensors.md +docs/source/ko/using-diffusers/reproducibility.md +docs/source/ko/using-diffusers/inpaint.md +docs/source/ko/using-diffusers/conditional_image_generation.md +docs/source/ko/using-diffusers/controlling_generation.md +docs/source/ko/using-diffusers/reusing_seeds.md +docs/source/ko/using-diffusers/textual_inversion_inference.md +docs/source/ko/using-diffusers/loading_overview.md +docs/source/ko/using-diffusers/custom_pipeline_overview.md +docs/source/ko/using-diffusers/other-formats.md +docs/source/ko/using-diffusers/stable_diffusion_jax_how_to.md +docs/source/ko/using-diffusers/pipeline_overview.md +docs/source/ko/using-diffusers/write_own_pipeline.md +docs/source/pt/index.md +docs/source/pt/quicktour.md +docs/source/pt/in_translation.md +docs/source/pt/installation.md +docs/source/pt/stable_diffusion.md +docs/source/ja/index.md +docs/source/ja/quicktour.md +docs/source/ja/in_translation.md +docs/source/ja/installation.md +docs/source/ja/stable_diffusion.md \ No newline at end of file diff --git a/utils/notification_service_doc_tests.py b/utils/notification_service_doc_tests.py new file mode 100644 index 000000000000..e45291484320 --- /dev/null +++ b/utils/notification_service_doc_tests.py @@ -0,0 +1,401 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import collections +import json +import math +import os +import re +import time +from fnmatch import fnmatch +from typing import Dict, List + +import requests +from slack_sdk import WebClient + + +client = WebClient(token=os.environ["CI_SLACK_BOT_TOKEN"]) + + +def handle_test_results(test_results): + expressions = test_results.split(" ") + + failed = 0 + success = 0 + + # When the output is short enough, the output is surrounded by = signs: "== OUTPUT ==" + # When it is too long, those signs are not present. + time_spent = expressions[-2] if "=" in expressions[-1] else expressions[-1] + + for i, expression in enumerate(expressions): + if "failed" in expression: + failed += int(expressions[i - 1]) + if "passed" in expression: + success += int(expressions[i - 1]) + + return failed, success, time_spent + + +def extract_first_line_failure(failures_short_lines): + failures = {} + file = None + in_error = False + for line in failures_short_lines.split("\n"): + if re.search(r"_ \[doctest\]", line): + in_error = True + file = line.split(" ")[2] + elif in_error and not line.split(" ")[0].isdigit(): + failures[file] = line + in_error = False + + return failures + + +class Message: + def __init__(self, title: str, doc_test_results: Dict): + self.title = title + + self._time_spent = doc_test_results["time_spent"].split(",")[0] + self.n_success = doc_test_results["success"] + self.n_failures = doc_test_results["failures"] + self.n_tests = self.n_success + self.n_failures + + # Failures and success of the doc tests + self.doc_test_results = doc_test_results + + # Set by `post` once the report message has been sent. + self.thread_ts = None + + @property + def time(self) -> str: + time_spent = [self._time_spent] + total_secs = 0 + + for time in time_spent: + time_parts = time.split(":") + + # Time can be formatted as xx:xx:xx, as .xx, or as x.xx if the time spent was less than a minute. + if len(time_parts) == 1: + time_parts = [0, 0, time_parts[0]] + + hours, minutes, seconds = int(time_parts[0]), int(time_parts[1]), float(time_parts[2]) + total_secs += hours * 3600 + minutes * 60 + seconds + + hours, minutes, seconds = total_secs // 3600, (total_secs % 3600) // 60, total_secs % 60 + return f"{int(hours)}h{int(minutes)}m{int(seconds)}s" + + @property + def header(self) -> Dict: + return {"type": "header", "text": {"type": "plain_text", "text": self.title}} + + @property + def no_failures(self) -> Dict: + return { + "type": "section", + "text": { + "type": "plain_text", + "text": f"🌞 There were no failures: all {self.n_tests} tests passed. The suite ran in {self.time}.", + "emoji": True, + }, + "accessory": { + "type": "button", + "text": {"type": "plain_text", "text": "Check Action results", "emoji": True}, + "url": f"https://github.com/huggingface/diffusers/actions/runs/{os.environ['GITHUB_RUN_ID']}", + }, + } + + @property + def failures(self) -> Dict: + return { + "type": "section", + "text": { + "type": "plain_text", + "text": ( + f"There were {self.n_failures} failures, out of {self.n_tests} tests.\nThe suite ran in" + f" {self.time}."
+ ), + "emoji": True, + }, + "accessory": { + "type": "button", + "text": {"type": "plain_text", "text": "Check Action results", "emoji": True}, + "url": f"https://github.com/huggingface/diffusers/actions/runs/{os.environ['GITHUB_RUN_ID']}", + }, + } + + @property + def category_failures(self) -> List[Dict]: + failure_blocks = [] + + MAX_ERROR_TEXT = 3000 - len("The following examples had failures:\n\n\n\n") - len("[Truncated]\n") + line_length = 40 + category_failures = {k: v["failed"] for k, v in self.doc_test_results.items() if isinstance(v, dict)} + + def single_category_failures(category, failures): + text = "" + if len(failures) == 0: + return "" + text += f"*{category} failures*:".ljust(line_length // 2).rjust(line_length // 2) + "\n" + + for idx, failure in enumerate(failures): + new_text = text + f"`{failure}`\n" + if len(new_text) > MAX_ERROR_TEXT: + text = text + "[Truncated]\n" + break + text = new_text + + return text + + for category, failures in category_failures.items(): + report = single_category_failures(category, failures) + if len(report) == 0: + continue + block = { + "type": "section", + "text": { + "type": "mrkdwn", + "text": f"The following examples had failures:\n\n\n{report}\n", + }, + } + failure_blocks.append(block) + + return failure_blocks + + @property + def payload(self) -> str: + blocks = [self.header] + + if self.n_failures > 0: + blocks.append(self.failures) + + if self.n_failures > 0: + blocks.extend(self.category_failures) + + if self.n_failures == 0: + blocks.append(self.no_failures) + + return json.dumps(blocks) + + @staticmethod + def error_out(): + payload = [ + { + "type": "section", + "text": { + "type": "plain_text", + "text": "There was an issue running the tests.", + }, + "accessory": { + "type": "button", + "text": {"type": "plain_text", "text": "Check Action results", "emoji": True}, + "url": f"https://github.com/huggingface/diffusers/actions/runs/{os.environ['GITHUB_RUN_ID']}", + }, + } + ] + + print("Sending the following payload") + print(json.dumps({"blocks": payload})) + + client.chat_postMessage( + channel=os.environ["CI_SLACK_CHANNEL_ID_DAILY"], + text="There was an issue running the tests.", + blocks=payload, + ) + + def post(self): + print("Sending the following payload") + print(json.dumps({"blocks": json.loads(self.payload)})) + + text = f"{self.n_failures} failures out of {self.n_tests} tests." if self.n_failures else "All tests passed."
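+ # Keep the full API response: replies in `post_reply` are threaded off its 'ts' field.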
+ + self.thread_ts = client.chat_postMessage( + channel=os.environ["CI_SLACK_CHANNEL_ID_DAILY"], + blocks=self.payload, + text=text, + ) + + def get_reply_blocks(self, job_name, job_link, failures, text): + # `text` must be less than 3001 characters in Slack SDK + # keep some room for adding "[Truncated]" when necessary + MAX_ERROR_TEXT = 3000 - len("[Truncated]") + + failure_text = "" + for key, value in failures.items(): + new_text = failure_text + f"*{key}*\n_{value}_\n\n" + if len(new_text) > MAX_ERROR_TEXT: + # `failure_text` here has length <= 3000 + failure_text = failure_text + "[Truncated]" + break + # `failure_text` here has length <= MAX_ERROR_TEXT + failure_text = new_text + + title = job_name + content = {"type": "section", "text": {"type": "mrkdwn", "text": text}} + + if job_link is not None: + content["accessory"] = { + "type": "button", + "text": {"type": "plain_text", "text": "GitHub Action job", "emoji": True}, + "url": job_link, + } + + return [ + {"type": "header", "text": {"type": "plain_text", "text": title.upper(), "emoji": True}}, + content, + {"type": "section", "text": {"type": "mrkdwn", "text": failure_text}}, + ] + + def post_reply(self): + if self.thread_ts is None: + raise ValueError("Can only post reply if a post has been made.") + + job_link = self.doc_test_results.pop("job_link") + self.doc_test_results.pop("failures") + self.doc_test_results.pop("success") + self.doc_test_results.pop("time_spent") + + sorted_dict = sorted(self.doc_test_results.items(), key=lambda t: t[0]) + for job, job_result in sorted_dict: + if len(job_result["failures"]): + text = f"*Num failures*: {len(job_result['failures'])}\n" + failures = job_result["failures"] + blocks = self.get_reply_blocks(job, job_link, failures, text=text) + + print("Sending the following reply") + print(json.dumps({"blocks": blocks})) + + client.chat_postMessage( + channel=os.environ["CI_SLACK_CHANNEL_ID_DAILY"], + text=f"Results for {job}", + blocks=blocks, + thread_ts=self.thread_ts["ts"], + ) + + time.sleep(1) + + +def get_job_links(): + run_id = os.environ["GITHUB_RUN_ID"] + url = f"https://api.github.com/repos/huggingface/diffusers/actions/runs/{run_id}/jobs?per_page=100" + result = requests.get(url).json() + jobs = {} + + try: + jobs.update({job["name"]: job["html_url"] for job in result["jobs"]}) + pages_to_iterate_over = math.ceil((result["total_count"] - 100) / 100) + + for i in range(pages_to_iterate_over): + result = requests.get(url + f"&page={i + 2}").json() + jobs.update({job["name"]: job["html_url"] for job in result["jobs"]}) + + return jobs + except Exception as e: + print("Unknown error, could not fetch links.", e) + + return {} + + +def retrieve_artifact(name: str): + _artifact = {} + + if os.path.exists(name): + files = os.listdir(name) + for file in files: + try: + with open(os.path.join(name, file), encoding="utf-8") as f: + _artifact[file.split(".")[0]] = f.read() + except UnicodeDecodeError as e: + raise ValueError(f"Could not open {os.path.join(name, file)}.") from e + + return _artifact + + +def retrieve_available_artifacts(): + class Artifact: + def __init__(self, name: str): + self.name = name + self.paths = [] + + def __str__(self): + return self.name + + def add_path(self, path: str): + self.paths.append({"name": self.name, "path": path}) + + _available_artifacts: Dict[str, Artifact] = {} + + directories = filter(os.path.isdir, os.listdir()) + for directory in directories: + artifact_name = directory + if artifact_name not in _available_artifacts: +
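+ # First time this artifact name is seen: register it before recording its path.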
_available_artifacts[artifact_name] = Artifact(artifact_name) + + _available_artifacts[artifact_name].add_path(directory) + + return _available_artifacts + + +if __name__ == "__main__": + github_actions_job_links = get_job_links() + available_artifacts = retrieve_available_artifacts() + + docs = collections.OrderedDict( + [ + ("*.py", "API Examples"), + ("*.md", "MD Examples"), + ] + ) + + # This dict will contain all the information relative to each doc test category: + # - failed: list of failed tests + # - failures: dict in the format 'test': 'error_message' + doc_test_results = { + v: { + "failed": [], + "failures": {}, + } + for v in docs.values() + } + + # Link to the GitHub Action job + doc_test_results["job_link"] = github_actions_job_links.get("run_doctests") + + artifact_path = available_artifacts["doc_tests_gpu_test_reports"].paths[0] + artifact = retrieve_artifact(artifact_path["name"]) + if "stats" in artifact: + failed, success, time_spent = handle_test_results(artifact["stats"]) + doc_test_results["failures"] = failed + doc_test_results["success"] = success + doc_test_results["time_spent"] = time_spent[1:-1] + ", " + + all_failures = extract_first_line_failure(artifact["failures_short"]) + for line in artifact["summary_short"].split("\n"): + if re.search("FAILED", line): + line = line.replace("FAILED ", "") + line = line.split()[0].replace("\n", "") + + if "::" in line: + file_path, test = line.split("::") + else: + file_path, test = line, line + + for file_regex in docs.keys(): + if fnmatch(file_path, file_regex): + category = docs[file_regex] + doc_test_results[category]["failed"].append(test) + + failure = all_failures[test] if test in all_failures else "N/A" + doc_test_results[category]["failures"][test] = failure + break + + message = Message("🤗 Results of the doc tests.", doc_test_results) + message.post() + message.post_reply()
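As a quick local sanity check of `preprocess_string` (a hypothetical snippet, not part of the diff; it assumes the new `diffusers.doctest_utils` module added above is importable and `pytest` is installed):

```python
from diffusers.doctest_utils import preprocess_string

# Build the markdown fence programmatically so this snippet avoids literal triple backticks.
fence = "`" * 3

# A `load_dataset` call gets "# doctest: +IGNORE_RESULT" appended on the fly.
md = f'{fence}python\n>>> from datasets import load_dataset\n>>> ds = load_dataset("fashion_mnist")\n{fence}\n'
assert "# doctest: +IGNORE_RESULT" in preprocess_string(md, skip_cuda_tests=False)

# A CUDA-looking snippet is dropped entirely when CUDA tests are skipped
# (e.g. with SKIP_CUDA_DOCTEST=1), so no doctest is collected for it.
cuda_md = f'{fence}python\n>>> import torch\n>>> model.to("cuda")\n{fence}\n'
assert preprocess_string(cuda_md, skip_cuda_tests=True) == ""
```

The first assert exercises the `load_dataset` rewrite; the second exercises the naive CUDA heuristic (`cuda`, `to(0)`, or `device=0`) described in `HfDocTestParser`.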