From b5dc03983b239beb2cbaccdac527007cbe355c10 Mon Sep 17 00:00:00 2001 From: Matthew Wallace Date: Mon, 4 Dec 2023 16:42:57 -0700 Subject: [PATCH] Allow Unix socket connection rather than just TCP --- dbt/adapters/mariadb/connections.py | 10 ++++++++-- dbt/adapters/mysql/connections.py | 10 ++++++++-- dbt/adapters/mysql5/connections.py | 10 ++++++++-- 3 files changed, 24 insertions(+), 6 deletions(-) diff --git a/dbt/adapters/mariadb/connections.py b/dbt/adapters/mariadb/connections.py index 9583a7d..6b30f9a 100644 --- a/dbt/adapters/mariadb/connections.py +++ b/dbt/adapters/mariadb/connections.py @@ -17,7 +17,8 @@ @dataclass(init=False) class MariaDBCredentials(Credentials): - server: str + server: Optional[str] = None + unix_socket: Optional[str] = None port: Optional[int] = None database: Optional[str] = None schema: str @@ -62,6 +63,7 @@ def _connection_keys(self): """ return ( "server", + "unix_socket", "port", "database", "schema", @@ -81,7 +83,6 @@ def open(cls, connection): credentials = cls.get_credentials(connection.credentials) kwargs = {} - kwargs["host"] = credentials.server kwargs["user"] = credentials.username kwargs["passwd"] = credentials.password kwargs["buffered"] = True @@ -89,6 +90,11 @@ def open(cls, connection): if credentials.ssl_disabled: kwargs["ssl_disabled"] = credentials.ssl_disabled + if credentials.server: + kwargs["host"] = credentials.server + elif credentials.unix_socket: + kwargs["unix_socket"] = credentials.unix_socket + if credentials.port: kwargs["port"] = credentials.port diff --git a/dbt/adapters/mysql/connections.py b/dbt/adapters/mysql/connections.py index 459782c..d8932dd 100644 --- a/dbt/adapters/mysql/connections.py +++ b/dbt/adapters/mysql/connections.py @@ -17,7 +17,8 @@ @dataclass(init=False) class MySQLCredentials(Credentials): - server: str + server: Optional[str] = None + unix_socket: Optional[str] = None port: Optional[int] = None database: Optional[str] = None schema: str @@ -61,6 +62,7 @@ def _connection_keys(self): """ return ( "server", + "unix_socket", "port", "database", "schema", @@ -80,11 +82,15 @@ def open(cls, connection): credentials = cls.get_credentials(connection.credentials) kwargs = {} - kwargs["host"] = credentials.server kwargs["user"] = credentials.username kwargs["passwd"] = credentials.password kwargs["buffered"] = True + if credentials.server: + kwargs["host"] = credentials.server + elif credentials.unix_socket: + kwargs["unix_socket"] = credentials.unix_socket + if credentials.port: kwargs["port"] = credentials.port diff --git a/dbt/adapters/mysql5/connections.py b/dbt/adapters/mysql5/connections.py index 6c9df5c..f1481a2 100644 --- a/dbt/adapters/mysql5/connections.py +++ b/dbt/adapters/mysql5/connections.py @@ -17,7 +17,8 @@ @dataclass(init=False) class MySQLCredentials(Credentials): - server: str + server: Optional[str] = None + unix_socket: Optional[str] = None port: Optional[int] = None database: Optional[str] = None schema: str @@ -62,6 +63,7 @@ def _connection_keys(self): """ return ( "server", + "unix_socket", "port", "database", "schema", @@ -81,7 +83,6 @@ def open(cls, connection): credentials = cls.get_credentials(connection.credentials) kwargs = {} - kwargs["host"] = credentials.server kwargs["user"] = credentials.username kwargs["passwd"] = credentials.password kwargs["buffered"] = True @@ -89,6 +90,11 @@ def open(cls, connection): if credentials.ssl_disabled: kwargs["ssl_disabled"] = credentials.ssl_disabled + if credentials.server: + kwargs["host"] = credentials.server + elif credentials.unix_socket: + kwargs["unix_socket"] = credentials.unix_socket + if credentials.port: kwargs["port"] = credentials.port