Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add doctests #7145

Open
wants to merge 21 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 20 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions .github/workflows/check_links.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
name: Check for Broken Links

on:
repository_dispatch:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is it a repository dispatch workflow?

workflow_dispatch:
schedule:
- cron: "0 0 * * *"

jobs:
check_for_broken_links:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3

- name: Link Checker
id: lychee
uses: lycheeverse/lychee-action@v1
with:
args: './**/*.md'
fail: true

- name: Create Issue From File
if: env.lychee_exit_code != 0
uses: diffusers/create-issue-from-file@v4
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What would this do?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And where does this workflow live within diffusers?

with:
title: Link Checker Report
content-filepath: ./lychee/out.md
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it automatically create the lychee directory?

labels: report, automated issue
80 changes: 80 additions & 0 deletions .github/workflows/doctests.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
name: Doctests

on:
push:
branches:
- doctest*
sayakpaul marked this conversation as resolved.
Show resolved Hide resolved
repository_dispatch:
schedule:
- cron: "0 0 * * *"

env:
HF_HOME: /mnt/cache
RUN_SLOW: yes
OMP_NUM_THREADS: 16
MKL_NUM_THREADS: 16

jobs:
run_doctests:
runs-on: [single-gpu, nvidia-gpu, a10, ci]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, I suspect some snippets might OOM out in case they don't have the memory optimization bits enabled? I know we cannot catch it early. Just making a note here.

container:
image: huggingface/diffusers-all-latest-gpu
sayakpaul marked this conversation as resolved.
Show resolved Hide resolved
options: --gpus 0 --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/

steps:
- name: Checkout diffusers
uses: actions/checkout@v3
with:
fetch-depth: 2

- name: NVIDIA-SMI
uses: actions/checkout@v3
run: |
nvidia-smi

- name: Install dependencies
run: python3 -m pip install -e .[quality,test,training]

- name: Environment
run: |
python3 utils/print_env.py

- name: Get doctest files
run: |
$(python3 -c 'from utils.tests_fetcher import get_all_doctest_files; to_test = get_all_doctest_files(); to_test = " ".join(to_test); fp = open("doc_tests.txt", "w"); fp.write(to_test); fp.close()')
sayakpaul marked this conversation as resolved.
Show resolved Hide resolved

- name: Run doctests
env:
HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
run: |
python3 -m pytest -v --make-reports doc_tests_gpu --doctest-modules $(cat doc_tests.txt) -sv --doctest-continue-on-failure --doctest-glob="*.md"

- name: Failure short reports
if: ${{ failure() }}
continue-on-error: true
run: cat reports/doc_tests_gpu/failures_short.txt

- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v3
with:
name: doc_tests_gpu_test_reports
path: reports/doc_tests_gpu

send_results:
name: Send results to webhook
runs-on: ubuntu-22.04
if: always()
needs: [run_doctests]
steps:
- uses: actions/checkout@v3
- uses: actions/download-artifact@v3
- name: Send message to Slack
env:
CI_SLACK_BOT_TOKEN: ${{ secrets.CI_SLACK_BOT_TOKEN }}
CI_SLACK_CHANNEL_ID: ${{ secrets.CI_SLACK_CHANNEL_ID_DAILY_DOCS }}
CI_SLACK_CHANNEL_ID_DAILY: ${{ secrets.CI_SLACK_CHANNEL_ID_DAILY_DOCS }}
Comment on lines +75 to +76
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just ensuring if we have configured the channels already?

CI_SLACK_CHANNEL_DUMMY_TESTS: ${{ secrets.CI_SLACK_CHANNEL_DUMMY_TESTS }}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to be informed about the dummy tests?

run: |
pip install slack_sdk
python utils/notification_service_doc_tests.py
DN6 marked this conversation as resolved.
Show resolved Hide resolved
2 changes: 2 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ repo-consistency:
python utils/check_dummies.py
python utils/check_repo.py
python utils/check_inits.py
python utils/check_doctest_list.py

# this target runs checks on all files

Expand Down Expand Up @@ -67,6 +68,7 @@ fixup: modified_only_fixup extra_style_checks autogenerate_code repo-consistency
fix-copies:
python utils/check_copies.py --fix_and_overwrite
python utils/check_dummies.py --fix_and_overwrite
python utils/check_doctest_list.py --fix_and_overwrite

# Run tests for the library

Expand Down
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,7 @@ skip-magic-trailing-comma = false

# Like Black, automatically detect the appropriate line ending.
line-ending = "auto"

[tool.pytest.ini_options]
doctest_optionflags="NUMBER NORMALIZE_WHITESPACE ELLIPSIS"
doctest_glob="**/*.md"
185 changes: 185 additions & 0 deletions src/diffusers/doctest_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
import doctest
import inspect
import os
import re
from typing import Iterable

from .utils import is_pytest_available


if is_pytest_available():
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be under tests instead?

from _pytest.doctest import (
Module,
_get_checker,
_get_continue_on_failure,
_get_runner,
_is_mocked,
_patch_unwrap_mock_aware,
get_optionflags,
import_path,
)
from _pytest.outcomes import skip
from pytest import DoctestItem
else:
Module = object
DoctestItem = object

"""
The following contains utils to run the documentation tests without having to overwrite any files.

The `preprocess_string` function adds `# doctest: +IGNORE_RESULT` markers on the fly anywhere a `load_dataset` call is
made as a print would otherwise fail the corresonding line.

To skip cuda tests, make sure to call `SKIP_CUDA_DOCTEST=1 pytest --doctest-modules <path_to_files_to_test>
"""


def preprocess_string(string, skip_cuda_tests):
"""Prepare a docstring or a `.md` file to be run by doctest.

The argument `string` would be the whole file content if it is a `.md` file. For a python file, it would be one of
its docstring. In each case, it may contain multiple python code examples. If `skip_cuda_tests` is `True` and a
cuda stuff is detective (with a heuristic), this method will return an empty string so no doctest will be run for
`string`.
"""
codeblock_pattern = r"(```(?:python|py)\s*\n\s*>>> )((?:(?!```)[^])*?```)"
codeblocks = re.split(re.compile(codeblock_pattern, flags=re.MULTILINE | re.DOTALL), string)
is_cuda_found = False
for i, codeblock in enumerate(codeblocks):
if "load_dataset(" in codeblock and "# doctest: +IGNORE_RESULT" not in codeblock:
codeblocks[i] = re.sub(r"(>>> .*load_dataset\(.*)", r"\1 # doctest: +IGNORE_RESULT", codeblock)
if (
(">>>" in codeblock or "..." in codeblock)
and re.search(r"cuda|to\(0\)|device=0", codeblock)
and skip_cuda_tests
):
is_cuda_found = True
break

modified_string = ""
if not is_cuda_found:
modified_string = "".join(codeblocks)

return modified_string


class HfDocTestParser(doctest.DocTestParser):
"""
Overwrites the DocTestParser from doctest to properly parse the codeblocks that are formatted with black. This
means that there are no extra lines at the end of our snippets. The `# doctest: +IGNORE_RESULT` marker is also
added anywhere a `load_dataset` call is made as a print would otherwise fail the corresponding line.

Tests involving cuda are skipped base on a naive pattern that should be updated if it is not enough.
"""

# This regular expression is used to find doctest examples in a
# string. It defines three groups: `source` is the source code
# (including leading indentation and prompts); `indent` is the
# indentation of the first (PS1) line of the source code; and
# `want` is the expected output (including leading indentation).
# fmt: off
_EXAMPLE_RE = re.compile(r'''
# Source consists of a PS1 line followed by zero or more PS2 lines.
(?P<source>
(?:^(?P<indent> [ ]*) >>> .*) # PS1 line
(?:\n [ ]* \.\.\. .*)*) # PS2 lines
\n?
# Want consists of any non-blank lines that do not start with PS1.
(?P<want> (?:(?![ ]*$) # Not a blank line
(?![ ]*>>>) # Not a line starting with PS1
# !!!!!!!!!!! HF Specific !!!!!!!!!!!
(?:(?!```).)* # Match any character except '`' until a '```' is found (this is specific to HF because black removes the last line)
# !!!!!!!!!!! HF Specific !!!!!!!!!!!
(?:\n|$) # Match a new line or end of string
)*)
''', re.MULTILINE | re.VERBOSE
)
# fmt: on

# !!!!!!!!!!! HF Specific !!!!!!!!!!!
skip_cuda_tests: bool = bool(os.environ.get("SKIP_CUDA_DOCTEST", False))
# !!!!!!!!!!! HF Specific !!!!!!!!!!!

def parse(self, string, name="<string>"):
"""
Overwrites the `parse` method to incorporate a skip for CUDA tests, and remove logs and dataset prints before
calling `super().parse`
"""
string = preprocess_string(string, self.skip_cuda_tests)
return super().parse(string, name)


class HfDoctestModule(Module):
"""
Overwrites the `DoctestModule` of the pytest package to make sure the HFDocTestParser is used when discovering
tests.
"""

def collect(self) -> Iterable["DoctestItem"]:
class MockAwareDocTestFinder(doctest.DocTestFinder):
"""A hackish doctest finder that overrides stdlib internals to fix a stdlib bug.

https://github.com/pytest-dev/pytest/issues/3456 https://bugs.python.org/issue25532
"""

def _find_lineno(self, obj, source_lines):
"""Doctest code does not take into account `@property`, this
is a hackish way to fix it. https://bugs.python.org/issue17446

Wrapped Doctests will need to be unwrapped so the correct line number is returned. This will be
reported upstream. #8796
"""
if isinstance(obj, property):
obj = getattr(obj, "fget", obj)

if hasattr(obj, "__wrapped__"):
# Get the main obj in case of it being wrapped
obj = inspect.unwrap(obj)

# Type ignored because this is a private function.
return super()._find_lineno( # type:ignore[misc]
obj,
source_lines,
)

def _find(self, tests, obj, name, module, source_lines, globs, seen) -> None:
if _is_mocked(obj):
return
with _patch_unwrap_mock_aware():
# Type ignored because this is a private function.
super()._find( # type:ignore[misc]
tests, obj, name, module, source_lines, globs, seen
)

if self.path.name == "conftest.py":
module = self.config.pluginmanager._importconftest(
self.path,
self.config.getoption("importmode"),
rootpath=self.config.rootpath,
)
else:
try:
module = import_path(
self.path,
root=self.config.rootpath,
mode=self.config.getoption("importmode"),
)
except ImportError:
if self.config.getvalue("doctest_ignore_import_errors"):
skip("unable to import module %r" % self.path)
else:
raise

# !!!!!!!!!!! HF Specific !!!!!!!!!!!
finder = MockAwareDocTestFinder(parser=HfDocTestParser())
# !!!!!!!!!!! HF Specific !!!!!!!!!!!
optionflags = get_optionflags(self)
runner = _get_runner(
verbose=False,
optionflags=optionflags,
checker=_get_checker(),
continue_on_failure=_get_continue_on_failure(self.config),
)
for test in finder.find(module, module.__name__):
if test.examples: # skip empty doctests and cuda
yield DoctestItem.from_parent(self, name=test.name, runner=runner, dtest=test)
1 change: 1 addition & 0 deletions src/diffusers/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
is_note_seq_available,
is_onnx_available,
is_peft_available,
is_pytest_available,
is_scipy_available,
is_tensorboard_available,
is_torch_available,
Expand Down
11 changes: 11 additions & 0 deletions src/diffusers/utils/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,13 @@
except importlib_metadata.PackageNotFoundError:
_peft_available = False

_pytest_available = importlib.util.find_spec("pytest") is not None
try:
_pytest_version = importlib_metadata.version("pytest")
logger.debug(f"Successfully imported pytest version {_pytest_version}")
except importlib_metadata.PackageNotFoundError:
_pytest_available = False

_torchvision_available = importlib.util.find_spec("torchvision") is not None
try:
_torchvision_version = importlib_metadata.version("torchvision")
Expand Down Expand Up @@ -374,6 +381,10 @@ def is_peft_available():
return _peft_available


def is_pytest_available():
return _pytest_available


def is_torchvision_available():
return _torchvision_available

Expand Down
2 changes: 1 addition & 1 deletion src/diffusers/utils/testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -779,7 +779,7 @@ class CaptureLogger:
>>> logger = logging.get_logger("diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.py")
>>> with CaptureLogger(logger) as cl:
... logger.info(msg)
>>> assert cl.out, msg + "\n"
>>> assert cl.out, msg + \n
```
"""

Expand Down
Loading
Loading