Disallow untyped def's (#767)
* Disallow untyped `def`'s

---------

Co-authored-by: Mike Alfare <[email protected]>
Fokko and mikealfare authored Jun 24, 2023
1 parent a1d161c commit 8ea1597
Showing 8 changed files with 101 additions and 72 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Fixes-20230510-163110.yaml
@@ -0,0 +1,6 @@
kind: Fixes
body: Disallow untyped `def`'s
time: 2023-05-10T16:31:10.593358+02:00
custom:
Author: Fokko
Issue: "760"
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
@@ -39,7 +39,7 @@ repos:
alias: flake8-check
stages: [manual]
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.1.1
rev: v1.2.0
hooks:
- id: mypy
# N.B.: Mypy is... a bit fragile.
@@ -52,7 +52,7 @@ repos:
# of our control to the mix. Unfortunately, there's nothing we can
# do about per pre-commit's author.
# See https://github.com/pre-commit/pre-commit/issues/730 for details.
args: [--show-error-codes, --ignore-missing-imports, --explicit-package-bases, --warn-unused-ignores]
args: [--show-error-codes, --ignore-missing-imports, --explicit-package-bases, --warn-unused-ignores, --disallow-untyped-defs]
files: ^dbt/adapters/.*
language: system
- id: mypy
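
Roughly speaking, the new `--disallow-untyped-defs` argument makes mypy reject any function whose signature is not fully annotated. A small illustration of what the hook now flags (the `add_one` functions are hypothetical, not part of this commit):

```python
# Flagged under --disallow-untyped-defs:
#   error: Function is missing a type annotation  [no-untyped-def]
def add_one(x):
    return x + 1


# Accepted: parameters and return type are annotated.
def add_one_typed(x: int) -> int:
    return x + 1
```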
2 changes: 1 addition & 1 deletion dbt/adapters/spark/column.py
@@ -26,7 +26,7 @@ def can_expand_to(self: Self, other_column: Self) -> bool: # type: ignore
"""returns True if both columns are strings"""
return self.is_string() and other_column.is_string()

def literal(self, value):
def literal(self, value: Any) -> str:
return "cast({} as {})".format(value, self.dtype)

@property
80 changes: 47 additions & 33 deletions dbt/adapters/spark/connections.py
@@ -1,5 +1,4 @@
from contextlib import contextmanager
from typing import Tuple

import dbt.exceptions
from dbt.adapters.base import Credentials
@@ -23,10 +22,10 @@
pyodbc = None
from datetime import datetime
import sqlparams

from dbt.contracts.connection import Connection
from hologram.helpers import StrEnum
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Union
from typing import Any, Dict, Optional, Union, Tuple, List, Generator, Iterable

try:
from thrift.transport.TSSLSocket import TSSLSocket
@@ -45,7 +44,7 @@
NUMBERS = DECIMALS + (int, float)


def _build_odbc_connnection_string(**kwargs) -> str:
def _build_odbc_connnection_string(**kwargs: Any) -> str:
return ";".join([f"{k}={v}" for k, v in kwargs.items()])


@@ -78,17 +77,17 @@ class SparkCredentials(Credentials):
retry_all: bool = False

@classmethod
def __pre_deserialize__(cls, data):
def __pre_deserialize__(cls, data: Any) -> Any:
data = super().__pre_deserialize__(data)
if "database" not in data:
data["database"] = None
return data

@property
def cluster_id(self):
def cluster_id(self) -> Optional[str]:
return self.cluster

def __post_init__(self):
def __post_init__(self) -> None:
# spark classifies database and schema as the same thing
if self.database is not None and self.database != self.schema:
raise dbt.exceptions.DbtRuntimeError(
@@ -141,31 +140,34 @@ def __post_init__(self):
) from e

@property
def type(self):
def type(self) -> str:
return "spark"

@property
def unique_field(self):
def unique_field(self) -> str:
return self.host

def _connection_keys(self) -> Tuple[str, ...]:
return ("host", "port", "cluster", "endpoint", "schema", "organization")
return "host", "port", "cluster", "endpoint", "schema", "organization"


class PyhiveConnectionWrapper(object):
"""Wrap a Spark connection in a way that no-ops transactions"""

# https://forums.databricks.com/questions/2157/in-apache-spark-sql-can-we-roll-back-the-transacti.html # noqa

def __init__(self, handle):
handle: "pyodbc.Connection"
_cursor: "Optional[pyodbc.Cursor]"

def __init__(self, handle: "pyodbc.Connection") -> None:
self.handle = handle
self._cursor = None

def cursor(self):
def cursor(self) -> "PyhiveConnectionWrapper":
self._cursor = self.handle.cursor()
return self

def cancel(self):
def cancel(self) -> None:
if self._cursor:
# Handle bad response in the pyhive lib when
# the connection is cancelled
@@ -174,7 +176,7 @@ def cancel(self):
except EnvironmentError as exc:
logger.debug("Exception while cancelling query: {}".format(exc))

def close(self):
def close(self) -> None:
if self._cursor:
# Handle bad response in the pyhive lib when
# the connection is cancelled
@@ -184,13 +186,14 @@ def close(self):
logger.debug("Exception while closing cursor: {}".format(exc))
self.handle.close()

def rollback(self, *args, **kwargs):
def rollback(self, *args: Any, **kwargs: Any) -> None:
logger.debug("NotImplemented: rollback")

def fetchall(self):
def fetchall(self) -> List["pyodbc.Row"]:
assert self._cursor, "Cursor not available"
return self._cursor.fetchall()

def execute(self, sql, bindings=None):
def execute(self, sql: str, bindings: Optional[List[Any]] = None) -> None:
if sql.strip().endswith(";"):
sql = sql.strip()[:-1]

@@ -212,6 +215,8 @@ def execute(self, sql, bindings=None):
if bindings is not None:
bindings = [self._fix_binding(binding) for binding in bindings]

assert self._cursor, "Cursor not available"

self._cursor.execute(sql, bindings, async_=True)
poll_state = self._cursor.poll()
state = poll_state.operationState
@@ -245,7 +250,7 @@ def execute(self, sql, bindings=None):
logger.debug("Poll status: {}, query complete".format(state))

@classmethod
def _fix_binding(cls, value):
def _fix_binding(cls, value: Any) -> Union[float, str]:
"""Convert complex datatypes to primitives that can be loaded by
the Spark driver"""
if isinstance(value, NUMBERS):
@@ -256,12 +261,14 @@ def _fix_binding(cls, value):
return value

@property
def description(self):
def description(self) -> Tuple[Tuple[str, Any, int, int, int, int, bool]]:
assert self._cursor, "Cursor not available"
return self._cursor.description


class PyodbcConnectionWrapper(PyhiveConnectionWrapper):
def execute(self, sql, bindings=None):
def execute(self, sql: str, bindings: Optional[List[Any]] = None) -> None:
assert self._cursor, "Cursor not available"
if sql.strip().endswith(";"):
sql = sql.strip()[:-1]
# pyodbc does not handle a None type binding!
@@ -282,7 +289,7 @@ class SparkConnectionManager(SQLConnectionManager):
SPARK_CONNECTION_URL = "{host}:{port}" + SPARK_CLUSTER_HTTP_PATH

@contextmanager
def exception_handler(self, sql):
def exception_handler(self, sql: str) -> Generator[None, None, None]:
try:
yield

@@ -299,30 +306,30 @@ def exception_handler(self, sql):
else:
raise dbt.exceptions.DbtRuntimeError(str(exc))

def cancel(self, connection):
def cancel(self, connection: Connection) -> None:
connection.handle.cancel()

@classmethod
def get_response(cls, cursor) -> AdapterResponse:
def get_response(cls, cursor: Any) -> AdapterResponse:
# https://github.com/dbt-labs/dbt-spark/issues/142
message = "OK"
return AdapterResponse(_message=message)

# No transactions on Spark....
def add_begin_query(self, *args, **kwargs):
def add_begin_query(self, *args: Any, **kwargs: Any) -> None:
logger.debug("NotImplemented: add_begin_query")

def add_commit_query(self, *args, **kwargs):
def add_commit_query(self, *args: Any, **kwargs: Any) -> None:
logger.debug("NotImplemented: add_commit_query")

def commit(self, *args, **kwargs):
def commit(self, *args: Any, **kwargs: Any) -> None:
logger.debug("NotImplemented: commit")

def rollback(self, *args, **kwargs):
def rollback(self, *args: Any, **kwargs: Any) -> None:
logger.debug("NotImplemented: rollback")

@classmethod
def validate_creds(cls, creds, required):
def validate_creds(cls, creds: Any, required: Iterable[str]) -> None:
method = creds.method

for key in required:
@@ -333,7 +340,7 @@ def validate_creds(cls, creds, required):
)

@classmethod
def open(cls, connection):
def open(cls, connection: Connection) -> Connection:
if connection.state == ConnectionState.OPEN:
logger.debug("Connection is already open, skipping open.")
return connection
@@ -450,7 +457,7 @@ def open(cls, connection):
SessionConnectionWrapper,
)

handle = SessionConnectionWrapper(Connection())
handle = SessionConnectionWrapper(Connection()) # type: ignore
else:
raise dbt.exceptions.DbtProfileError(
f"invalid credential method: {creds.method}"
@@ -487,7 +494,7 @@ def open(cls, connection):
else:
raise dbt.exceptions.FailedToConnectError("failed to connect") from e
else:
raise exc
raise exc # type: ignore

connection.handle = handle
connection.state = ConnectionState.OPEN
@@ -507,7 +514,14 @@ def data_type_code_to_name(cls, type_code: Union[type, str]) -> str: # type: ignore
return type_code.__name__.upper()


def build_ssl_transport(host, port, username, auth, kerberos_service_name, password=None):
def build_ssl_transport(
host: str,
port: int,
username: str,
auth: str,
kerberos_service_name: str,
password: Optional[str] = None,
) -> "thrift_sasl.TSaslClientTransport":
transport = None
if port is None:
port = 10000
@@ -531,7 +545,7 @@ def build_ssl_transport(host, port, username, auth, kerberos_service_name, password=None):
# to be nonempty.
password = "x"

def sasl_factory():
def sasl_factory() -> sasl.Client:
sasl_client = sasl.Client()
sasl_client.setAttr("host", host)
if sasl_auth == "GSSAPI":
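
One pattern worth calling out in the connections.py changes above: the cursor is now declared as `Optional` and each use is preceded by `assert self._cursor, "Cursor not available"`, which both guards against using the wrapper before `cursor()` is called and lets mypy narrow the optional type. A self-contained sketch of that idea (the `FakeCursor` class is a hypothetical stand-in for the pyhive/pyodbc cursor):

```python
from typing import Any, List, Optional


class FakeCursor:
    """Stand-in for a DB-API cursor; only what this sketch needs."""

    def fetchall(self) -> List[Any]:
        return []


class ConnectionWrapper:
    _cursor: Optional[FakeCursor]

    def __init__(self) -> None:
        self._cursor = None

    def cursor(self) -> "ConnectionWrapper":
        self._cursor = FakeCursor()
        return self

    def fetchall(self) -> List[Any]:
        # The assert guards against calling fetchall() before cursor(),
        # and narrows Optional[FakeCursor] to FakeCursor for mypy.
        assert self._cursor, "Cursor not available"
        return self._cursor.fetchall()
```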
31 changes: 17 additions & 14 deletions dbt/adapters/spark/impl.py
@@ -1,7 +1,10 @@
import re
from concurrent.futures import Future
from dataclasses import dataclass
from typing import Any, Dict, Iterable, List, Optional, Union, Type, Tuple, Callable
from typing import Any, Dict, Iterable, List, Optional, Union, Type, Tuple, Callable, Set

from dbt.adapters.base.relation import InformationSchema
from dbt.contracts.graph.manifest import Manifest

from typing_extensions import TypeAlias

@@ -109,27 +112,27 @@ def date_function(cls) -> str:
return "current_timestamp()"

@classmethod
def convert_text_type(cls, agate_table, col_idx):
def convert_text_type(cls, agate_table: agate.Table, col_idx: int) -> str:
return "string"

@classmethod
def convert_number_type(cls, agate_table, col_idx):
def convert_number_type(cls, agate_table: agate.Table, col_idx: int) -> str:
decimals = agate_table.aggregate(agate.MaxPrecision(col_idx))
return "double" if decimals else "bigint"

@classmethod
def convert_date_type(cls, agate_table, col_idx):
def convert_date_type(cls, agate_table: agate.Table, col_idx: int) -> str:
return "date"

@classmethod
def convert_time_type(cls, agate_table, col_idx):
def convert_time_type(cls, agate_table: agate.Table, col_idx: int) -> str:
return "time"

@classmethod
def convert_datetime_type(cls, agate_table, col_idx):
def convert_datetime_type(cls, agate_table: agate.Table, col_idx: int) -> str:
return "timestamp"

def quote(self, identifier):
def quote(self, identifier: str) -> str: # type: ignore
return "`{}`".format(identifier)

def _get_relation_information(self, row: agate.Row) -> RelationInfo:
@@ -344,7 +347,7 @@ def _get_columns_for_catalog(self, relation: BaseRelation) -> Iterable[Dict[str, Any]]:
as_dict["table_database"] = None
yield as_dict

def get_catalog(self, manifest):
def get_catalog(self, manifest: Manifest) -> Tuple[agate.Table, List[Exception]]:
schema_map = self._get_catalog_schemas(manifest)
if len(schema_map) > 1:
raise dbt.exceptions.CompilationError(
Expand All @@ -370,9 +373,9 @@ def get_catalog(self, manifest):

def _get_one_catalog(
self,
information_schema,
schemas,
manifest,
information_schema: InformationSchema,
schemas: Set[str],
manifest: Manifest,
) -> agate.Table:
if len(schemas) != 1:
raise dbt.exceptions.CompilationError(
@@ -388,7 +391,7 @@ def _get_one_catalog(
columns.extend(self._get_columns_for_catalog(relation))
return agate.Table.from_object(columns, column_types=DEFAULT_TYPE_TESTER)

def check_schema_exists(self, database, schema):
def check_schema_exists(self, database: str, schema: str) -> bool:
results = self.execute_macro(LIST_SCHEMAS_MACRO_NAME, kwargs={"database": database})

exists = True if schema in [row[0] for row in results] else False
@@ -425,7 +428,7 @@ def get_rows_different_sql(
# This is for use in the test suite
# Spark doesn't have 'commit' and 'rollback', so this override
# doesn't include those commands.
def run_sql_for_tests(self, sql, fetch, conn):
def run_sql_for_tests(self, sql, fetch, conn): # type: ignore
cursor = conn.handle.cursor()
try:
cursor.execute(sql)
Expand Down Expand Up @@ -477,7 +480,7 @@ def standardize_grants_dict(self, grants_table: agate.Table) -> dict:
grants_dict.update({privilege: [grantee]})
return grants_dict

def debug_query(self):
def debug_query(self) -> None:
"""Override for DebugTask method"""
self.execute("select 1 as id")
