Skip to content

Commit

Permalink
generator: Ignore unneeded functions early
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 621527466
Change-Id: Ifc6e43d96a5a4a5981eded7b08aba54d8c1ee63a
  • Loading branch information
happyCoder92 authored and copybara-github committed Apr 3, 2024
1 parent cb276d4 commit 6159168
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 82 deletions.
100 changes: 57 additions & 43 deletions sandboxed_api/tools/generator2/code.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,8 +586,17 @@ def __eq__(self, other):
class _TranslationUnit(object):
"""Class wrapping clang's _TranslationUnit. Provides extra utilities."""

def __init__(self, path, tu, limit_scan_depth=False):
# type: (Text, cindex.TranslationUnit, bool) -> None
def __init__(self, path, tu, limit_scan_depth=False, func_names=None):
# type: (Text, cindex.TranslationUnit, bool, Optional[List[Text]]) -> None
"""Initializes the translation unit.
Args:
path: path to source of the tranlation unit
tu: cindex tranlation unit
limit_scan_depth: whether scan should be limited to single file
func_names: list of function names to take into consideration, empty means
all functions.
"""
self.path = path
self.limit_scan_depth = limit_scan_depth
self._tu = tu
Expand All @@ -598,6 +607,7 @@ def __init__(self, path, tu, limit_scan_depth=False):
self.defines = {}
self.required_defines = set()
self.types_to_skip = set()
self.func_names = func_names or []

def _process(self):
# type: () -> None
Expand Down Expand Up @@ -630,6 +640,9 @@ def _process(self):
self.forward_decls[Type(self, cursor.type)] = cursor
if (cursor.kind == cindex.CursorKind.FUNCTION_DECL and
cursor.linkage != cindex.LinkageKind.INTERNAL):
# Skip non-interesting functions
if self.func_names and cursor.spelling not in self.func_names:
continue
if self.limit_scan_depth:
if (cursor.location and cursor.location.file.name == self.path):
self.functions.add(Function(self, cursor))
Expand All @@ -638,8 +651,7 @@ def _process(self):

def get_functions(self):
# type: () -> Set[Function]
if not self._processed:
self._process()
self._process()
return self.functions

def _walk_preorder(self):
Expand All @@ -663,27 +675,37 @@ def search_for_macro_name(self, cursor):
class Analyzer(object):
"""Class responsible for analysis."""

# pylint: disable=line-too-long
@staticmethod
def process_files(input_paths, compile_flags, limit_scan_depth=False):
# type: (Text, List[Text], bool) -> List[_TranslationUnit]
def process_files(
input_paths, compile_flags, limit_scan_depth=False, func_names=None
):
# type: (Text, List[Text], bool, Optional[List[Text]]) -> List[_TranslationUnit]
"""Processes files with libclang and returns TranslationUnit objects."""
_init_libclang()

tus = []
for path in input_paths:
tu = Analyzer._analyze_file_for_tu(
path, compile_flags=compile_flags, limit_scan_depth=limit_scan_depth)
path,
compile_flags=compile_flags,
limit_scan_depth=limit_scan_depth,
func_names=func_names,
)
tus.append(tu)
return tus

# pylint: disable=line-too-long
@staticmethod
def _analyze_file_for_tu(path,
compile_flags=None,
test_file_existence=True,
unsaved_files=None,
limit_scan_depth=False):
# type: (Text, Optional[List[Text]], bool, Optional[Tuple[Text, Union[Text, IO[Text]]]], bool) -> _TranslationUnit
def _analyze_file_for_tu(
path,
compile_flags=None,
test_file_existence=True,
unsaved_files=None,
limit_scan_depth=False,
func_names=None,
):
# type: (Text, Optional[List[Text]], bool, Optional[Tuple[Text, Union[Text, IO[Text]]]], bool, Optional[List[Text]]) -> _TranslationUnit
"""Returns Analysis object for given path."""
compile_flags = compile_flags or []
if test_file_existence and not os.path.isfile(path):
Expand All @@ -701,11 +723,11 @@ def _analyze_file_for_tu(path,
return _TranslationUnit(
path,
index.parse(
path,
args=args,
unsaved_files=unsaved_files,
options=_PARSE_OPTIONS),
limit_scan_depth=limit_scan_depth)
path, args=args, unsaved_files=unsaved_files, options=_PARSE_OPTIONS
),
limit_scan_depth=limit_scan_depth,
func_names=func_names,
)


class Generator(object):
Expand Down Expand Up @@ -735,20 +757,20 @@ def __init__(self, translation_units):
self.functions = None
_init_libclang()

def generate(self,
name,
function_names,
namespace=None,
output_file=None,
embed_dir=None,
embed_name=None):
def generate(
self,
name,
namespace=None,
output_file=None,
embed_dir=None,
embed_name=None,
):
# pylint: disable=line-too-long
# type: (Text, List[Text], Optional[Text], Optional[Text], Optional[Text], Optional[Text]) -> Text
# type: (Text, Optional[Text], Optional[Text], Optional[Text], Optional[Text]) -> Text
"""Generates structures, functions and typedefs.
Args:
name: name of the class that will contain generated interface
function_names: list of function names to export to the interface
namespace: namespace of the interface
output_file: path to the output file, used to generate header guards;
defaults to None that does not generate the guard #include directives;
Expand All @@ -759,9 +781,9 @@ def generate(self,
Returns:
generated interface as a string
"""
related_types = self._get_related_types(function_names)
related_types = self._get_related_types()
forward_decls = self._get_forward_decls(related_types)
functions = self._get_functions(function_names)
functions = self._get_functions()
related_types = [(t.stringify() + ';') for t in related_types]
defines = self._get_defines()

Expand All @@ -776,18 +798,15 @@ def generate(self,
}
return self.format_template(**api)

def _get_functions(self, func_names=None):
# type: (Optional[List[Text]]) -> List[Function]
def _get_functions(self):
# type: () -> List[Function]
"""Gets Function objects that will be used to generate interface."""
if self.functions is not None:
return self.functions
self.functions = []
# TODO(szwl): for d in translation_unit.diagnostics:, handle that
for translation_unit in self.translation_units:
self.functions += [
f for f in translation_unit.get_functions()
if not func_names or f.name in func_names
]
self.functions += translation_unit.get_functions()
# allow only nonmangled functions - C++ overloads are not handled in
# code generation
self.functions = [f for f in self.functions if not f.is_mangled()]
Expand All @@ -797,28 +816,23 @@ def _get_functions(self, func_names=None):
self.functions.sort(key=lambda x: x.name)
return self.functions

def _get_related_types(self, func_names=None):
# type: (Optional[List[Text]]) -> List[Type]
def _get_related_types(self):
# type: () -> List[Type]
"""Gets type definitions related to chosen functions.
Types related to one function will land in the same translation unit,
we gather the types, sort it and put as a sublist in types list.
This is necessary as we can't compare types from two different translation
units.
Args:
func_names: list of function names to take into consideration, empty means
all functions.
Returns:
list of types in correct (ready to render) order
"""
processed = set()
fn_related_types = set()
types = []
types_to_skip = set()

for f in self._get_functions(func_names):
for f in self._get_functions():
fn_related_types = f.get_related_types()
types += sorted(r for r in fn_related_types if r not in processed)
processed.update(fn_related_types)
Expand Down
Loading

0 comments on commit 6159168

Please sign in to comment.