diff --git a/.gitignore b/.gitignore index c862d4a..3d4f354 100644 --- a/.gitignore +++ b/.gitignore @@ -145,8 +145,4 @@ dmypy.json cython_debug/ # PyCharm -# JetBrains specific template is maintained in a separate JetBrains.gitignore that can -# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore -# and can be added to the global gitignore or merged into this file. For a more nuclear -# option (not recommended) you can uncomment the following to ignore the entire idea folder. -#.idea/ +.idea/ diff --git a/src/pytsql/grammar/tsqlParser.py b/src/pytsql/grammar/tsqlParser.py index fbf7127..c5474ac 100644 --- a/src/pytsql/grammar/tsqlParser.py +++ b/src/pytsql/grammar/tsqlParser.py @@ -24890,6 +24890,13 @@ def getRuleIndex(self): return tsqlParser.RULE_data_type + @staticmethod + def is_top_level_statement(node: ParserRuleContext): + """Check wether node is a top level SQL statement.""" + cur = node.parentCtx + while isinstance(cur, tsqlParser.Sql_clauseContext) or isinstance(cur, tsqlParser.Sql_clausesContext): + cur = cur.parentCtx + return isinstance(cur, tsqlParser.BatchContext) def data_type(self): diff --git a/src/pytsql/tsql.py b/src/pytsql/tsql.py index 910f4dc..7fdaaf2 100644 --- a/src/pytsql/tsql.py +++ b/src/pytsql/tsql.py @@ -112,7 +112,9 @@ def visitChildren(self, node: antlr4.ParserRuleContext) -> List[str]: else: result = super().visitChildren(node) - if isinstance(node, tsqlParser.Declare_statementContext): + if isinstance( + node, tsqlParser.Declare_statementContext + ) and tsqlParser.is_top_level_statement(node): self.dynamics.extend(result) return result diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 696b28f..074b121 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -38,6 +38,12 @@ def pytest_addoption(parser): def pytest_generate_tests(metafunc): if "backend" in metafunc.fixturenames: - metafunc.parametrize( - "backend", [metafunc.config.getoption("backend")], scope="module" - ) + try: + metafunc.parametrize( + "backend", [metafunc.config.getoption("backend")], scope="module" + ) + except ValueError: + # some metafunc.config objects don't have an option "backend" + metafunc.parametrize( + "backend", ["default_backend"], scope="module" + ) diff --git a/tests/integration/test_multiple_statements.py b/tests/integration/test_multiple_statements.py index 7a07b8a..26fd6f6 100644 --- a/tests/integration/test_multiple_statements.py +++ b/tests/integration/test_multiple_statements.py @@ -156,6 +156,49 @@ def test_multiple_uses(engine): assert "table4" in test_database_names +def test_stored_procedure_declaration(engine): + statement = """ + DROP DATABASE IF EXISTS stored_procedure_declaration + CREATE DATABASE stored_procedure_declaration + USE stored_procedure_declaration +GO + +/****** Object: Table [dbo].[table1] Script Date: 2/23/2021 2:48:02 PM ******/ +CREATE PROCEDURE CREATEALLDATES + ( + @StartDate AS DATE, @EndDate AS DATE + ) AS + DECLARE @Current AS DATE = DATEADD(DD, 0, @StartDate); DROP TABLE IF EXISTS ##alldates CREATE TABLE ##alldates ( + dt DATE PRIMARY KEY + ) WHILE @Current <= @EndDate BEGIN + INSERT INTO ##alldates + VALUES (@Current); + SET @Current = DATEADD(DD, 1, @Current) -- add 1 to current day +END +GO +IF OBJECT_ID ( N'dbo.get_db_sampling_factor' , N'FN' ) IS NOT NULL DROP FUNCTION get_db_sampling_factor ; +GO +""" + executes(statement, engine, None) + + +def test_top_level_declaration(engine): + statement = """ + DROP DATABASE IF EXISTS top_level_declaration + CREATE DATABASE top_level_declaration + USE top_level_declaration +GO + +DECLARE @Current AS DATE = '2022-01-01' +GO +SELECT @Current as a INTO dummy01 +GO +SELECT @Current as b INTO dummy02 +GO +""" + executes(statement, engine, None) + + def get_table( engine: Engine, table_name: str, schema: Optional[str] = None ) -> sa.Table: diff --git a/tests/unit/test_dynamics.py b/tests/unit/test_dynamics.py index e351c1a..6327e1c 100644 --- a/tests/unit/test_dynamics.py +++ b/tests/unit/test_dynamics.py @@ -7,7 +7,7 @@ def test_declaration_in_control_flow(): seed = """ IF 1 = 1 DECLARE @A INT = 5 - SELECT * FROM x + SELECT @A """ splits = _split(seed) assert len(splits) == 2 @@ -15,12 +15,28 @@ def test_declaration_in_control_flow(): assert_strings_equal_disregarding_whitespace( splits[0], "IF 1 = 1 DECLARE @A INT = 5" ) + # unfortunately we can't be right here because otherwise we would need to get + # the output of the declaration assert_strings_equal_disregarding_whitespace( splits[1], - """ + """SELECT @A""", + ) + + +def test_select_in_control_flow(): + seed = """ + IF 1 = 0 + BEGIN DECLARE @A INT = 5 SELECT * FROM x - """, + END + """ + splits = _split(seed) + assert len(splits) == 1 + + # this is beyond the complexity we want to manage with isolate_top_level_statements=True + assert_strings_equal_disregarding_whitespace( + splits[0], "IF 1 = 0 BEGIN DECLARE @A INT = 5 SELECT * FROM x END" )