diff --git a/sandboxed_api/tools/generator2/code.py b/sandboxed_api/tools/generator2/code.py index f0f27d5c..6727b13e 100644 --- a/sandboxed_api/tools/generator2/code.py +++ b/sandboxed_api/tools/generator2/code.py @@ -586,8 +586,17 @@ def __eq__(self, other): class _TranslationUnit(object): """Class wrapping clang's _TranslationUnit. Provides extra utilities.""" - def __init__(self, path, tu, limit_scan_depth=False): - # type: (Text, cindex.TranslationUnit, bool) -> None + def __init__(self, path, tu, limit_scan_depth=False, func_names=None): + # type: (Text, cindex.TranslationUnit, bool, Optional[List[Text]]) -> None + """Initializes the translation unit. + + Args: + path: path to source of the tranlation unit + tu: cindex tranlation unit + limit_scan_depth: whether scan should be limited to single file + func_names: list of function names to take into consideration, empty means + all functions. + """ self.path = path self.limit_scan_depth = limit_scan_depth self._tu = tu @@ -598,6 +607,7 @@ def __init__(self, path, tu, limit_scan_depth=False): self.defines = {} self.required_defines = set() self.types_to_skip = set() + self.func_names = func_names or [] def _process(self): # type: () -> None @@ -630,6 +640,9 @@ def _process(self): self.forward_decls[Type(self, cursor.type)] = cursor if (cursor.kind == cindex.CursorKind.FUNCTION_DECL and cursor.linkage != cindex.LinkageKind.INTERNAL): + # Skip non-interesting functions + if self.func_names and cursor.spelling not in self.func_names: + continue if self.limit_scan_depth: if (cursor.location and cursor.location.file.name == self.path): self.functions.add(Function(self, cursor)) @@ -638,8 +651,7 @@ def _process(self): def get_functions(self): # type: () -> Set[Function] - if not self._processed: - self._process() + self._process() return self.functions def _walk_preorder(self): @@ -663,27 +675,37 @@ def search_for_macro_name(self, cursor): class Analyzer(object): """Class responsible for analysis.""" + # pylint: disable=line-too-long @staticmethod - def process_files(input_paths, compile_flags, limit_scan_depth=False): - # type: (Text, List[Text], bool) -> List[_TranslationUnit] + def process_files( + input_paths, compile_flags, limit_scan_depth=False, func_names=None + ): + # type: (Text, List[Text], bool, Optional[List[Text]]) -> List[_TranslationUnit] """Processes files with libclang and returns TranslationUnit objects.""" _init_libclang() tus = [] for path in input_paths: tu = Analyzer._analyze_file_for_tu( - path, compile_flags=compile_flags, limit_scan_depth=limit_scan_depth) + path, + compile_flags=compile_flags, + limit_scan_depth=limit_scan_depth, + func_names=func_names, + ) tus.append(tu) return tus # pylint: disable=line-too-long @staticmethod - def _analyze_file_for_tu(path, - compile_flags=None, - test_file_existence=True, - unsaved_files=None, - limit_scan_depth=False): - # type: (Text, Optional[List[Text]], bool, Optional[Tuple[Text, Union[Text, IO[Text]]]], bool) -> _TranslationUnit + def _analyze_file_for_tu( + path, + compile_flags=None, + test_file_existence=True, + unsaved_files=None, + limit_scan_depth=False, + func_names=None, + ): + # type: (Text, Optional[List[Text]], bool, Optional[Tuple[Text, Union[Text, IO[Text]]]], bool, Optional[List[Text]]) -> _TranslationUnit """Returns Analysis object for given path.""" compile_flags = compile_flags or [] if test_file_existence and not os.path.isfile(path): @@ -701,11 +723,11 @@ def _analyze_file_for_tu(path, return _TranslationUnit( path, index.parse( - path, - args=args, - unsaved_files=unsaved_files, - options=_PARSE_OPTIONS), - limit_scan_depth=limit_scan_depth) + path, args=args, unsaved_files=unsaved_files, options=_PARSE_OPTIONS + ), + limit_scan_depth=limit_scan_depth, + func_names=func_names, + ) class Generator(object): @@ -735,20 +757,20 @@ def __init__(self, translation_units): self.functions = None _init_libclang() - def generate(self, - name, - function_names, - namespace=None, - output_file=None, - embed_dir=None, - embed_name=None): + def generate( + self, + name, + namespace=None, + output_file=None, + embed_dir=None, + embed_name=None, + ): # pylint: disable=line-too-long - # type: (Text, List[Text], Optional[Text], Optional[Text], Optional[Text], Optional[Text]) -> Text + # type: (Text, Optional[Text], Optional[Text], Optional[Text], Optional[Text]) -> Text """Generates structures, functions and typedefs. Args: name: name of the class that will contain generated interface - function_names: list of function names to export to the interface namespace: namespace of the interface output_file: path to the output file, used to generate header guards; defaults to None that does not generate the guard #include directives; @@ -759,9 +781,9 @@ def generate(self, Returns: generated interface as a string """ - related_types = self._get_related_types(function_names) + related_types = self._get_related_types() forward_decls = self._get_forward_decls(related_types) - functions = self._get_functions(function_names) + functions = self._get_functions() related_types = [(t.stringify() + ';') for t in related_types] defines = self._get_defines() @@ -776,18 +798,15 @@ def generate(self, } return self.format_template(**api) - def _get_functions(self, func_names=None): - # type: (Optional[List[Text]]) -> List[Function] + def _get_functions(self): + # type: () -> List[Function] """Gets Function objects that will be used to generate interface.""" if self.functions is not None: return self.functions self.functions = [] # TODO(szwl): for d in translation_unit.diagnostics:, handle that for translation_unit in self.translation_units: - self.functions += [ - f for f in translation_unit.get_functions() - if not func_names or f.name in func_names - ] + self.functions += translation_unit.get_functions() # allow only nonmangled functions - C++ overloads are not handled in # code generation self.functions = [f for f in self.functions if not f.is_mangled()] @@ -797,8 +816,8 @@ def _get_functions(self, func_names=None): self.functions.sort(key=lambda x: x.name) return self.functions - def _get_related_types(self, func_names=None): - # type: (Optional[List[Text]]) -> List[Type] + def _get_related_types(self): + # type: () -> List[Type] """Gets type definitions related to chosen functions. Types related to one function will land in the same translation unit, @@ -806,19 +825,14 @@ def _get_related_types(self, func_names=None): This is necessary as we can't compare types from two different translation units. - Args: - func_names: list of function names to take into consideration, empty means - all functions. - Returns: list of types in correct (ready to render) order """ processed = set() - fn_related_types = set() types = [] types_to_skip = set() - for f in self._get_functions(func_names): + for f in self._get_functions(): fn_related_types = f.get_related_types() types += sorted(r for r in fn_related_types if r not in processed) processed.update(fn_related_types) diff --git a/sandboxed_api/tools/generator2/code_test.py b/sandboxed_api/tools/generator2/code_test.py index 8ef60574..f924f1b2 100644 --- a/sandboxed_api/tools/generator2/code_test.py +++ b/sandboxed_api/tools/generator2/code_test.py @@ -33,15 +33,20 @@ """ -def analyze_string(content, path='tmp.cc', limit_scan_depth=False): +def analyze_string( + content, path='tmp.cc', limit_scan_depth=False, func_names=None +): """Returns Analysis object for in memory content.""" - return analyze_strings(path, [(path, content)], limit_scan_depth) + return analyze_strings(path, [(path, content)], limit_scan_depth, func_names) -def analyze_strings(path, unsaved_files, limit_scan_depth=False): +def analyze_strings( + path, unsaved_files, limit_scan_depth=False, func_names=None +): """Returns Analysis object for in memory content.""" - return code.Analyzer._analyze_file_for_tu(path, None, False, unsaved_files, - limit_scan_depth) + return code.Analyzer._analyze_file_for_tu( + path, None, False, unsaved_files, limit_scan_depth, func_names + ) class CodeAnalysisTest(parameterized.TestCase): @@ -126,8 +131,8 @@ def testCodeGeneratorOutput(self): 'function_a', 'types_1', 'types_2', 'types_3', 'types_4', 'types_5', 'types_6' ] - generator = code.Generator([analyze_string(body)]) - result = generator.generate('Test', functions, 'sapi::Tests', None, None) + generator = code.Generator([analyze_string(body, func_names=functions)]) + result = generator.generate('Test', 'sapi::Tests', None, None) self.assertMultiLineEqual(code_test_util.CODE_GOLD, result) def testElaboratedArgument(self): @@ -135,18 +140,18 @@ def testElaboratedArgument(self): struct x { int a; }; extern "C" int function(struct x a) { return a.a; } """ - generator = code.Generator([analyze_string(body)]) + generator = code.Generator([analyze_string(body, func_names=['function'])]) with self.assertRaisesRegex(ValueError, r'Elaborate.*mapped.*'): - generator.generate('Test', ['function'], 'sapi::Tests', None, None) + generator.generate('Test', 'sapi::Tests', None, None) def testElaboratedArgument2(self): body = """ typedef struct { int a; char b; } x; extern "C" int function(x a) { return a.a; } """ - generator = code.Generator([analyze_string(body)]) + generator = code.Generator([analyze_string(body, func_names=['function'])]) with self.assertRaisesRegex(ValueError, r'Elaborate.*mapped.*'): - generator.generate('Test', ['function'], 'sapi::Tests', None, None) + generator.generate('Test', 'sapi::Tests', None, None) def testGetMappedType(self): body = """ @@ -155,7 +160,7 @@ def testGetMappedType(self): extern "C" uint function(uintp a) { return *a; } """ generator = code.Generator([analyze_string(body)]) - result = generator.generate('Test', [], 'sapi::Tests', None, None) + result = generator.generate('Test', 'sapi::Tests', None, None) self.assertMultiLineEqual(code_test_util.CODE_GOLD_MAPPED, result) @parameterized.named_parameters( @@ -182,7 +187,7 @@ def testArgumentNames(self, body, names): self.assertLen(functions, 1) self.assertLen(functions[0].argument_types, len(names)) # Extra check for generation, in case rendering throws error for this test. - generator.generate('Test', [], 'sapi::Tests', None, None) + generator.generate('Test', 'sapi::Tests', None, None) for t in functions[0].argument_types: self.assertIn(t.name, names) @@ -203,7 +208,7 @@ def testEnumGeneration(self): } """ generator = code.Generator([analyze_string(body)]) - result = generator.generate('Test', [], 'sapi::Tests', None, None) + result = generator.generate('Test', 'sapi::Tests', None, None) self.assertMultiLineEqual(code_test_util.CODE_ENUM_GOLD, result) def testTypeEq(self): @@ -223,7 +228,7 @@ def testTypeEq(self): self.assertLen(set(args), 2) # Extra check for generation, in case rendering throws error for this test. - generator.generate('Test', [], 'sapi::Tests', None, None) + generator.generate('Test', 'sapi::Tests', None, None) def testTypedefRelatedTypes(self): body = """ @@ -262,7 +267,7 @@ def testTypedefRelatedTypes(self): self.assertSameElements(names, ['data_s', 'data_p']) # Extra check for generation, in case rendering throws error for this test. - generator.generate('Test', [], 'sapi::Tests', None, None) + generator.generate('Test', 'sapi::Tests', None, None) def testTypedefDuplicateType(self): body = """ @@ -291,7 +296,7 @@ def testTypedefDuplicateType(self): self.assertSameElements(['data_s', 's'], names) # Extra check for generation, in case rendering throws error for this test. - generator.generate('Test', [], 'sapi::Tests', None, None) + generator.generate('Test', 'sapi::Tests', None, None) def testStructureRelatedTypes(self): body = """ @@ -342,7 +347,7 @@ def testStructureRelatedTypes(self): self.assertLen(types, 1) # Extra check for generation, in case rendering throws error for this test. - generator.generate('Test', [], 'sapi::Tests', None, None) + generator.generate('Test', 'sapi::Tests', None, None) def testUnionRelatedTypes(self): body = """ @@ -382,7 +387,7 @@ def testUnionRelatedTypes(self): self.assertSameElements(names, ['union_1', 'uint']) # Extra check for generation, in case rendering throws error for this test. - generator.generate('Test', [], 'sapi::Tests', None, None) + generator.generate('Test', 'sapi::Tests', None, None) def testFunctionPointerRelatedTypes(self): body = """ @@ -415,7 +420,7 @@ def testFunctionPointerRelatedTypes(self): self.assertSameElements(names, ['funcp', 'uint', 'uchar']) # Extra check for generation, in case rendering throws error for this test. - generator.generate('Test', [], 'sapi::Tests', None, None) + generator.generate('Test', 'sapi::Tests', None, None) def testForwardDeclaration(self): body = """ @@ -457,7 +462,7 @@ def testForwardDeclaration(self): forward_decls = generator._get_forward_decls(generator._get_related_types()) self.assertLen(forward_decls, 1) self.assertEqual(forward_decls[0], 'struct struct_6_def;') - generator.generate('Test', [], 'sapi::Tests', None, None) + generator.generate('Test', 'sapi::Tests', None, None) def testEnumRelatedTypes(self): body = """ @@ -486,7 +491,7 @@ def testEnumRelatedTypes(self): self.assertLen(args[5].get_related_types(), 1) # Extra check for generation, in case rendering throws error for this test. - generator.generate('Test', [], 'sapi::Tests', None, None) + generator.generate('Test', 'sapi::Tests', None, None) def testArrayAsParam(self): body = """ @@ -577,7 +582,7 @@ def testTypeOrder(self, func, a1, a2): args = functions[0].arguments() getattr(self, func)(args[a1], args[a2]) # Extra check for generation, in case rendering throws error for this test. - generator.generate('Test', [], 'sapi::Tests', None, None) + generator.generate('Test', 'sapi::Tests', None, None) def testFilterFunctionsFromInputFilesOnly(self): file1_code = """ @@ -643,7 +648,7 @@ def testTypeToString(self): self.assertMultiLineEqual(expected, types[1].stringify()) # Extra check for generation, in case rendering throws error for this test. - generator.generate('Test', [], 'sapi::Tests', None, None) + generator.generate('Test', 'sapi::Tests', None, None) def testCollectDefines(self): body = """ @@ -665,7 +670,7 @@ def testCollectDefines(self): generator._get_related_types() tu = generator.translation_units[0] - tu._process() + tu.get_functions() self.assertLen(tu.required_defines, 4) defines = generator._get_defines() @@ -676,7 +681,7 @@ def testCollectDefines(self): self.assertIn('#define SIZE4 10', defines) # Extra check for generation, in case rendering throws error for this test. - generator.generate('Test', [], 'sapi::Tests', None, None) + generator.generate('Test', 'sapi::Tests', None, None) def testYaraCase(self): body = """ @@ -697,7 +702,7 @@ def testYaraCase(self): generator._get_related_types() tu = generator.translation_units[0] - tu._process() + tu.get_functions() self.assertLen(tu.required_defines, 2) defines = generator._get_defines() @@ -708,7 +713,7 @@ def testYaraCase(self): self.assertTrue(defines[1].startswith(gold)) # Extra check for generation, in case rendering throws error for this test. - generator.generate('Test', [], 'sapi::Tests', None, None) + generator.generate('Test', 'sapi::Tests', None, None) def testDoubleFunction(self): body = """ @@ -726,7 +731,7 @@ def testDoubleFunction(self): self.assertLen(tu.functions, 1) # Extra check for generation, in case rendering throws error for this test. - generator.generate('Test', [], 'sapi::Tests', None, None) + generator.generate('Test', 'sapi::Tests', None, None) def testDefineStructBody(self): body = """ @@ -744,7 +749,7 @@ def testDefineStructBody(self): self.assertLen(generator.translation_units, 1) # initialize all internal data - generator.generate('Test', [], 'sapi::Tests', None, None) + generator.generate('Test', 'sapi::Tests', None, None) tu = generator.translation_units[0] self.assertLen(tu.functions, 1) @@ -762,7 +767,7 @@ def testJpegTurboCase(self): self.assertLen(generator.translation_units, 1) # initialize all internal data - generator.generate('Test', [], 'sapi::Tests', None, None) + generator.generate('Test', 'sapi::Tests', None, None) tu = generator.translation_units[0] self.assertLen(tu.functions, 1) @@ -784,7 +789,7 @@ def testMultipleTypesWhenConst(self): self.assertLen(generator.translation_units, 1) # Initialize all internal data - generator.generate('Test', [], 'sapi::Tests', None, None) + generator.generate('Test', 'sapi::Tests', None, None) tu = generator.translation_units[0] self.assertLen(tu.functions, 2) @@ -802,7 +807,7 @@ def testReference(self): self.assertLen(generator.translation_units, 1) # Initialize all internal data - generator.generate('Test', [], 'sapi::Tests', None, None) + generator.generate('Test', 'sapi::Tests', None, None) tu = generator.translation_units[0] self.assertLen(tu.functions, 1) @@ -822,7 +827,7 @@ def testCppHeader(self): unsaved_files = [(path, content)] generator = code.Generator([analyze_strings(path, unsaved_files)]) # Initialize all internal data - generator.generate('Test', [], 'sapi::Tests', None, None) + generator.generate('Test', 'sapi::Tests', None, None) # generator should filter out mangled function functions = generator._get_functions() diff --git a/sandboxed_api/tools/generator2/sapi_generator.py b/sandboxed_api/tools/generator2/sapi_generator.py index 56a6dab9..c34fcfaf 100644 --- a/sandboxed_api/tools/generator2/sapi_generator.py +++ b/sandboxed_api/tools/generator2/sapi_generator.py @@ -57,12 +57,17 @@ def main(c_flags): c_flags.pop(0) logging.debug(FLAGS.sapi_functions) extract_includes(FLAGS.sapi_isystem, c_flags) - tus = code.Analyzer.process_files(FLAGS.sapi_in, c_flags, - FLAGS.sapi_limit_scan_depth) + tus = code.Analyzer.process_files( + FLAGS.sapi_in, c_flags, FLAGS.sapi_limit_scan_depth, FLAGS.sapi_functions + ) generator = code.Generator(tus) - result = generator.generate(FLAGS.sapi_name, FLAGS.sapi_functions, - FLAGS.sapi_ns, FLAGS.sapi_out, - FLAGS.sapi_embed_dir, FLAGS.sapi_embed_name) + result = generator.generate( + FLAGS.sapi_name, + FLAGS.sapi_ns, + FLAGS.sapi_out, + FLAGS.sapi_embed_dir, + FLAGS.sapi_embed_name, + ) if FLAGS.sapi_out: with open(FLAGS.sapi_out, 'w') as out_file: