Skip to content

Commit

Permalink
python: add 'make generate' OS agnostic (#2288)
Browse files Browse the repository at this point in the history
  • Loading branch information
apwojcik authored Oct 11, 2023
1 parent df5f8c9 commit 76bb3b2
Show file tree
Hide file tree
Showing 5 changed files with 117 additions and 57 deletions.
14 changes: 13 additions & 1 deletion tools/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,16 @@
# THE SOFTWARE.
#####################################################################################

add_custom_target(generate bash ${CMAKE_CURRENT_SOURCE_DIR}/generate.sh)
find_package(Python 3 COMPONENTS Interpreter)
if(NOT Python_EXECUTABLE)
message(WARNING "Python 3 interpreter not found - skipping 'generate' target!")
return()
endif()

find_program(CLANG_FORMAT clang-format PATHS /opt/rocm/llvm $ENV{HIP_PATH})
if(NOT CLANG_FORMAT)
message(WARNING "clang-format not found - skipping 'generate' target!")
return()
endif()

add_custom_target(generate ${Python_EXECUTABLE} generate.py -f ${CLANG_FORMAT} WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
20 changes: 10 additions & 10 deletions tools/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import runpy
from functools import wraps
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from pathlib import Path

type_map: Dict[str, Callable[['Parameter'], None]] = {}
cpp_type_map: Dict[str, str] = {}
Expand Down Expand Up @@ -1281,18 +1282,17 @@ def template_eval(template, **kwargs):
return template


def run(args: List[str]) -> None:
runpy.run_path(args[0])
if len(args) > 1:
f = open(args[1]).read()
r = template_eval(f)
def run(path: Union[Path, str]) -> str:
return template_eval(open(path).read())


if __name__ == "__main__":
sys.modules['api'] = sys.modules['__main__']
runpy.run_path(sys.argv[1])
if len(sys.argv) > 2:
r = run(sys.argv[2])
sys.stdout.write(r)
else:
sys.stdout.write(generate_c_header())
sys.stdout.write(generate_c_api_body())
# sys.stdout.write(generate_cpp_header())


if __name__ == "__main__":
sys.modules['api'] = sys.modules['__main__']
run(sys.argv[1:])
88 changes: 88 additions & 0 deletions tools/generate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
#####################################################################################
# The MIT License (MIT)
#
# Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#####################################################################################
import api, argparse, os, runpy, subprocess, sys, te
from pathlib import Path

clang_format_path = Path('clang-format.exe' if os.name ==
'nt' else '/opt/rocm/llvm/bin/clang-format')
work_dir = Path().cwd()
src_dir = (work_dir / '../src').absolute()
migraphx_py_path = src_dir / 'api/migraphx.py'


def clang_format(buffer, **kwargs):
return subprocess.run(f'{clang_format_path} -style=file',
capture_output=True,
shell=True,
check=True,
input=buffer.encode('utf-8'),
cwd=work_dir,
**kwargs).stdout.decode('utf-8')


def api_generate(input_path: Path, output_path: Path):
with open(output_path, 'w') as f:
f.write(clang_format(api.run(input_path)))


def te_generate(input_path: Path, output_path: Path):
with open(output_path, 'w') as f:
f.write(clang_format(te.run(input_path)))


def main():
parser = argparse.ArgumentParser()
parser.add_argument('-f', '--clang-format', type=Path)
args = parser.parse_args()

global clang_format_path
if args.clang_format:
clang_format_path = args.clang_format

if not clang_format_path.is_file():
print(f"{clang_format_path}: invalid path or not installed",
file=sys.stderr)
return

try:
files = Path('include').absolute().iterdir()
for f in [f for f in files if f.is_file()]:
te_generate(f, src_dir / f'include/migraphx/{f.name}')
runpy.run_path(str(migraphx_py_path))
api_generate(work_dir / 'api/migraphx.h',
src_dir / 'api/include/migraphx/migraphx.h')
print('Finished generating header migraphx.h')
api_generate(work_dir / 'api/api.cpp', src_dir / 'api/api.cpp')
print('Finished generating source api.cpp')
except subprocess.CalledProcessError as ex:
if ex.stdout:
print(ex.stdout.decode('utf-8'))
if ex.stderr:
print(ex.stdout.decode('utf-8'))
print(f"Command '{ex.cmd}' returned {ex.returncode}")
raise


if __name__ == "__main__":
main()
43 changes: 0 additions & 43 deletions tools/generate.sh

This file was deleted.

9 changes: 6 additions & 3 deletions tools/te.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,9 @@ def template_eval(template, **kwargs):
return template


f = open(sys.argv[1]).read()
r = template_eval(f)
sys.stdout.write(r)
def run(p):
return template_eval(open(p).read())


if __name__ == '__main__':
sys.stdout.write(run(sys.argv[1]))

0 comments on commit 76bb3b2

Please sign in to comment.