From 8feefbbc9da9c30edbfcfff38561a2eabb34d8a3 Mon Sep 17 00:00:00 2001 From: MariusWirtz Date: Wed, 27 Oct 2021 20:22:25 +0200 Subject: [PATCH] Minor changes to tm1-hook --- airflow_tm1/hooks/tm1.py | 27 +++++++++++---------------- 1 file changed, 11 insertions(+), 16 deletions(-) diff --git a/airflow_tm1/hooks/tm1.py b/airflow_tm1/hooks/tm1.py index b8bc1cb..4b46668 100644 --- a/airflow_tm1/hooks/tm1.py +++ b/airflow_tm1/hooks/tm1.py @@ -1,7 +1,5 @@ -from typing import Optional - -from airflow.hooks.base_hook import BaseHook from TM1py.Services import TM1Service +from airflow.hooks.base import BaseHook class TM1Hook(BaseHook): @@ -9,20 +7,18 @@ class TM1Hook(BaseHook): Interact with IBM Cognos TM1, using the TM1py library. """ - def __init__(self, tm1_conn_id: str = "tm1_default") -> None: + def __init__(self, tm1_conn_id: str, **kwargs) -> None: """ A hook that uses TM1py to connect to a TM1 database. :param tm1_conn_id: The name of the TM1 connection to use. :type tm1_conn_id: str """ + super().__init__(**kwargs) self.tm1_conn_id = tm1_conn_id - self.tm1: Optional[TM1Service] = None + self.tm1: TM1Service = None self.address = None self.port = None - self.user = None - self.password = None - self.db = None - self.server_version = None + self.instance_name = None def get_conn(self) -> TM1Service: """ @@ -32,8 +28,6 @@ def get_conn(self) -> TM1Service: conn = self.get_connection(self.tm1_conn_id) self.address = conn.host self.port = conn.port - self.user = conn.login - self.password = conn.password # check for relevant additional parameters in conn.extra # except session_id as not sure if this makes sense in an Airflow context @@ -43,7 +37,6 @@ def get_conn(self) -> TM1Service: "namespace", "ssl", "session_context", - "logging", "timeout", "connection_pool_size", ] @@ -60,12 +53,14 @@ def get_conn(self) -> TM1Service: self.tm1 = TM1Service( address=self.address, port=self.port, - user=self.user, - password=self.password, + user=conn.login, + password=conn.password, **extra_args ) - self.db = self.tm1.server.get_server_name() - self.server_version = self.tm1.server.get_product_version() + self.instance_name = self.tm1.server.get_server_name() return self.tm1 + + def logout(self): + self.tm1.logout()