Skip to content

Commit

Permalink
Use Pandoc for converting Github Markdown into docstring RST
Browse files Browse the repository at this point in the history
* Proof of concept usage of pandoc

* Regenerate opset

* Use gfm instead of markdown format

Seems like it mostly affects usage of unicode chars (like for ellipsis, quotes) - it leaves them as ASCII instead.

* Use simple replace to preserve Tensor<T> notation

* Point relative hyperlinks to current onnx/docs

* Fix indents when docstring isn't generated

* Add comment to pandoc req

* Batch pandoc calls for ~9x faster runtime

* Move pandoc import to call site

* Move pandoc import further down

* Don't parametrise on separator

Co-authored-by: Jakub Bachurski <[email protected]>
  • Loading branch information
JakubBachurskiQC and jbachurski authored Jan 19, 2023
1 parent 27707ce commit 4942117
Show file tree
Hide file tree
Showing 5 changed files with 4,115 additions and 2,418 deletions.
1 change: 1 addition & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ dependencies:
- numpydoc
- onnx>=1.13.0
- onnxruntime>=1.13.1
- pandoc # ONNX docstrings in opset generation
- pip
- pre-commit
- pytest>=6
Expand Down
103 changes: 88 additions & 15 deletions src/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,8 @@
("If", "else_branch", ("Callable[[], Iterable[Var]]", "AttrGraph")),
]

with importlib.resources.path("spox", ".") as path:
_TEMPLATE_DIR = path.parent / "templates"
with importlib.resources.path("spox", ".") as _resource_path:
_TEMPLATE_DIR = _resource_path.parent / "templates"


@dataclass
Expand Down Expand Up @@ -235,20 +235,74 @@ def get_constructor_return(schema: onnx.defs.OpSchema) -> str:
return "Var"


def format_github_markdown(doc: str) -> str:
_PANDOC_SEP = "\U0001f6a7" # U+1F6A7 CONSTRUCTION SIGN
_PANDOC_GFM_TO_RST_CACHE: Dict[str, str] = {}


def _pandoc_gfm_to_rst_run(*args: str) -> Tuple[str, ...]:
if not args:
return ()

import pandoc

sep = f"\n\n{_PANDOC_SEP}{_PANDOC_SEP}\n\n"
acc = sep.join([_PANDOC_SEP] + list(args) + [_PANDOC_SEP])
acc_results = pandoc.write(pandoc.read(acc, format="gfm"), format="rst")
_, *results, _ = acc_results.split(sep)
for arg, result in zip(args, results):
if _PANDOC_SEP in result:
raise ValueError(
f"Pandoc separator character '{_PANDOC_SEP}' found in a result (bad convert)."
)
_PANDOC_GFM_TO_RST_CACHE[arg] = result + "\n"
return results


def _pandoc_gfm_to_rst(*args: str) -> Tuple[str, ...]:
args = tuple(arg.strip() for arg in args)
if any(_PANDOC_SEP in arg for arg in args):
raise ValueError(
f"Pandoc separator character '{_PANDOC_SEP}' cannot appear in any of the arguments."
)
valid = [
i
for i, arg in enumerate(args)
if not (arg in _PANDOC_GFM_TO_RST_CACHE or not arg)
]
results = _pandoc_gfm_to_rst_run(*[args[i] for i in valid])
sub: List[Optional[str]] = [None] * len(args)
for i, result in zip(valid, results):
sub[i] = result
for i, arg in enumerate(args):
if not arg:
sub[i] = ""
elif arg in _PANDOC_GFM_TO_RST_CACHE:
sub[i] = _PANDOC_GFM_TO_RST_CACHE[arg]
if any(r is None for r in sub):
raise ValueError("Missing processed pandoc result.")
return tuple(sub) # type: ignore


def pandoc_gfm_to_rst(doc: str) -> str:
(result,) = _pandoc_gfm_to_rst(doc)
return result


def format_github_markdown(doc: str, *, to_batch: Optional[List[str]] = None) -> str:
"""Jinja filter. Makes some attempt at fixing "Markdown" into RST."""
lines = [line.replace("\t", " " * 4).rstrip() for line in doc.splitlines()]
lines = [line for line in lines if line.rstrip()]
space_lcm = 0
while lines and all(line[: space_lcm + 1].isspace() for line in lines):
space_lcm += 1
lines = [line[space_lcm:] for line in lines]
doc = "\n".join(lines).strip()
doc = doc.replace("<br>", "\n\n")
doc = re.sub(r"<i>(.*)</i>", r"`\1`", doc)
doc = re.sub(r"<b>(.*)</b>", r"**\1**", doc)
doc = re.sub(r"\[(.+)\]\((.+)\)", r"\1 (\2)", doc)
return doc
# Sometimes Tensor<T> is used in the docs (~17 instances at 1.13)
# and is treated as invalid HTML tags by pandoc.
doc = doc.replace("<T>", "&lt;T&gt;")
# Point hyperlinks to onnx/docs
rel = "https://github.com/onnx/onnx/blob/main/docs"
doc = re.sub(
r"\[(.*)]\((\w+.md)\)", lambda match: f"[{match[1]}]({rel}/{match[2]})", doc
)
if to_batch is not None:
to_batch.append(doc)
return doc
else:
return pandoc_gfm_to_rst(doc).rstrip()


def is_variadic(param: onnx.defs.OpSchema.FormalParameter) -> bool:
Expand Down Expand Up @@ -317,6 +371,22 @@ def write_schemas_code(

built_schemas: Set[onnx.defs.OpSchema] = set()

pandoc_batch: List[str] = []
for schema in schemas:
if schema in inherited_schemas:
continue
todo = [schema.doc] + [
p.description
for p in (
list(schema.inputs)
+ list(schema.outputs)
+ list(schema.attributes.values())
)
]
for doc in todo:
format_github_markdown(doc, to_batch=pandoc_batch)
_pandoc_gfm_to_rst(*pandoc_batch)

# Operator classes
for schema in sorted(schemas, key=lambda s: s.name):
if schema in inherited_schemas:
Expand Down Expand Up @@ -548,6 +618,7 @@ def main(


if __name__ == "__main__":
gen_all_docstrings = True
ai_onnx_v17_schemas, ai_onnx_v17_module = main(
"ai.onnx",
17,
Expand All @@ -558,6 +629,7 @@ def main(
subgraphs_solutions=V16_SUBGRAPH_SOLUTIONS,
attr_type_overrides=DEFAULT_ATTR_TYPE_OVERRIDES,
allow_extra_constructor_arguments=["Split"],
gen_docstrings=gen_all_docstrings,
)
ai_onnx_ml_v3_schemas, ai_onnx_ml_v3_module = main(
"ai.onnx.ml",
Expand All @@ -575,4 +647,5 @@ def main(
"TreeEnsembleClassifier": "treeensembleclassifier3",
"TreeEnsembleRegressor": "treeensembleregressor3",
},
gen_docstrings=gen_all_docstrings,
)
Loading

0 comments on commit 4942117

Please sign in to comment.