From 32feab41006897de182bfa684813be230027aca1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Miroslav=20=C5=A0ediv=C3=BD?= <6774676+eumiro@users.noreply.github.com>
Date: Tue, 22 Aug 2023 06:44:29 +0000
Subject: [PATCH] Refactor: Simplify code in Apache/Alibaba providers (#33227)

---
 .../alibaba/cloud/hooks/analyticdb_spark.py   |  2 +-
 airflow/providers/apache/beam/hooks/beam.py   |  6 +--
 .../providers/apache/beam/triggers/beam.py    | 47 +++++++++----------
 airflow/providers/apache/hive/hooks/hive.py   | 33 +++++--------
 airflow/providers/apache/livy/hooks/livy.py   | 12 ++---
 .../providers/apache/spark/hooks/spark_sql.py |  2 +-
 6 files changed, 44 insertions(+), 58 deletions(-)

diff --git a/airflow/providers/alibaba/cloud/hooks/analyticdb_spark.py b/airflow/providers/alibaba/cloud/hooks/analyticdb_spark.py
index 9881ca38ae4c..e06ee912281b 100644
--- a/airflow/providers/alibaba/cloud/hooks/analyticdb_spark.py
+++ b/airflow/providers/alibaba/cloud/hooks/analyticdb_spark.py
@@ -321,7 +321,7 @@ def _validate_extra_conf(conf: dict[Any, Any]) -> bool:
         if conf:
             if not isinstance(conf, dict):
                 raise ValueError("'conf' argument must be a dict")
-            if not all((v and isinstance(v, str)) or isinstance(v, int) for v in conf.values()):
+            if not all(isinstance(v, (str, int)) and v != "" for v in conf.values()):
                 raise ValueError("'conf' values must be either strings or ints")
         return True
 
diff --git a/airflow/providers/apache/beam/hooks/beam.py b/airflow/providers/apache/beam/hooks/beam.py
index 762dd2f07b67..72dc224626bb 100644
--- a/airflow/providers/apache/beam/hooks/beam.py
+++ b/airflow/providers/apache/beam/hooks/beam.py
@@ -104,10 +104,8 @@ def process_fd(
     fd_to_log = {proc.stderr: log.warning, proc.stdout: log.info}
     func_log = fd_to_log[fd]
 
-    while True:
-        line = fd.readline().decode()
-        if not line:
-            return
+    for line in iter(fd.readline, b""):
+        line = line.decode()
         if process_line_callback:
             process_line_callback(line)
         func_log(line.rstrip("\n"))
diff --git a/airflow/providers/apache/beam/triggers/beam.py b/airflow/providers/apache/beam/triggers/beam.py
index 9c29a9fbe272..0d201cd8c9a4 100644
--- a/airflow/providers/apache/beam/triggers/beam.py
+++ b/airflow/providers/apache/beam/triggers/beam.py
@@ -85,32 +85,29 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
     async def run(self) -> AsyncIterator[TriggerEvent]:  # type: ignore[override]
         """Get current pipeline status and yields a TriggerEvent."""
         hook = self._get_async_hook()
-        while True:
-            try:
-                return_code = await hook.start_python_pipeline_async(
-                    variables=self.variables,
-                    py_file=self.py_file,
-                    py_options=self.py_options,
-                    py_interpreter=self.py_interpreter,
-                    py_requirements=self.py_requirements,
-                    py_system_site_packages=self.py_system_site_packages,
+        try:
+            return_code = await hook.start_python_pipeline_async(
+                variables=self.variables,
+                py_file=self.py_file,
+                py_options=self.py_options,
+                py_interpreter=self.py_interpreter,
+                py_requirements=self.py_requirements,
+                py_system_site_packages=self.py_system_site_packages,
+            )
+        except Exception as e:
+            self.log.exception("Exception occurred while checking for pipeline state")
+            yield TriggerEvent({"status": "error", "message": str(e)})
+        else:
+            if return_code == 0:
+                yield TriggerEvent(
+                    {
+                        "status": "success",
+                        "message": "Pipeline has finished SUCCESSFULLY",
+                    }
                 )
-                if return_code == 0:
-                    yield TriggerEvent(
-                        {
-                            "status": "success",
-                            "message": "Pipeline has finished SUCCESSFULLY",
-                        }
-                    )
-                    return
-                else:
-                    yield TriggerEvent({"status": "error", "message": "Operation failed"})
-                    return
-
-            except Exception as e:
-                self.log.exception("Exception occurred while checking for pipeline state")
-                yield TriggerEvent({"status": "error", "message": str(e)})
-                return
+            else:
+                yield TriggerEvent({"status": "error", "message": "Operation failed"})
+        return
 
     def _get_async_hook(self) -> BeamAsyncHook:
         return BeamAsyncHook(runner=self.runner)
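
A note on the triggers/beam.py hunk above: the `while True` wrapper was dead
weight, since every branch returned on the first pass, and the body is
restructured as try/except/else so the handler guards only the awaited call
while the return-code branching sits in the else branch. Below is a minimal
runnable sketch of that control flow; `run_pipeline` is a hypothetical
stand-in for `hook.start_python_pipeline_async`, not a real Beam API.

    import asyncio

    async def run_pipeline() -> int:
        """Hypothetical stand-in for hook.start_python_pipeline_async()."""
        await asyncio.sleep(0)
        return 0

    async def run():
        # try/except/else: only the awaited call is guarded, so an exception
        # raised while building the result events cannot be misreported as a
        # pipeline failure.
        try:
            return_code = await run_pipeline()
        except Exception as e:
            yield {"status": "error", "message": str(e)}
        else:
            if return_code == 0:
                yield {"status": "success", "message": "Pipeline has finished SUCCESSFULLY"}
            else:
                yield {"status": "error", "message": "Operation failed"}
        return

    async def main():
        async for event in run():
            print(event)

    asyncio.run(main())
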
diff --git a/airflow/providers/apache/hive/hooks/hive.py b/airflow/providers/apache/hive/hooks/hive.py
index ea004860b494..7f02619024fd 100644
--- a/airflow/providers/apache/hive/hooks/hive.py
+++ b/airflow/providers/apache/hive/hooks/hive.py
@@ -277,13 +277,11 @@ def run_cli(
                 )
                 self.sub_process = sub_process
                 stdout = ""
-                while True:
-                    line = sub_process.stdout.readline()
-                    if not line:
-                        break
-                    stdout += line.decode("UTF-8")
+                for line in iter(sub_process.stdout.readline, b""):
+                    line = line.decode()
+                    stdout += line
                     if verbose:
-                        self.log.info(line.decode("UTF-8").strip())
+                        self.log.info(line.strip())
                 sub_process.wait()
 
                 if sub_process.returncode:
@@ -704,25 +702,20 @@ def _get_max_partition_from_part_specs(
         # Assuming all specs have the same keys.
         if partition_key not in part_specs[0].keys():
             raise AirflowException(f"Provided partition_key {partition_key} is not in part_specs.")
-        is_subset = None
-        if filter_map:
-            is_subset = set(filter_map.keys()).issubset(set(part_specs[0].keys()))
-        if filter_map and not is_subset:
+        if filter_map and not set(filter_map).issubset(part_specs[0]):
             raise AirflowException(
                 f"Keys in provided filter_map {', '.join(filter_map.keys())} "
                 f"are not subset of part_spec keys: {', '.join(part_specs[0].keys())}"
             )
 
-        candidates = [
-            p_dict[partition_key]
-            for p_dict in part_specs
-            if filter_map is None or all(item in p_dict.items() for item in filter_map.items())
-        ]
-
-        if not candidates:
-            return None
-        else:
-            return max(candidates)
+        return max(
+            (
+                p_dict[partition_key]
+                for p_dict in part_specs
+                if filter_map is None or all(item in p_dict.items() for item in filter_map.items())
+            ),
+            default=None,
+        )
 
     def max_partition(
         self,
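
A note on the two hive.py hunks above, which lean on two stdlib idioms:
iter(callable, sentinel) calls sub_process.stdout.readline until it returns
b"", replacing the manual while/break loop, and max(iterable, default=None)
absorbs the empty-input case that previously needed a temporary list and an
if/else. A self-contained sketch of both, assuming a POSIX echo binary and
made-up part_specs data:

    import subprocess

    # iter(f, sentinel) keeps calling f() until it returns the sentinel b"",
    # so no explicit "if not line: break" is needed.
    proc = subprocess.Popen(["echo", "hello\nworld"], stdout=subprocess.PIPE)
    stdout = ""
    for line in iter(proc.stdout.readline, b""):
        stdout += line.decode()
    proc.wait()
    print(stdout, end="")  # prints hello and world on separate lines

    # max(..., default=None) returns None for an empty iterable instead of
    # raising ValueError, so the "no candidates" branch disappears.
    part_specs = [{"ds": "2023-01-01"}, {"ds": "2023-02-01"}]
    print(max((spec["ds"] for spec in part_specs), default=None))  # 2023-02-01
    print(max((spec["ds"] for spec in []), default=None))          # None
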
return {"Response": {e.message}, "Status Code": {e.status}, "status": "error"} - attempt_num += 1 await asyncio.sleep(self.retry_delay) def _generate_base_url(self, conn: Connection) -> str: @@ -815,7 +813,7 @@ def _validate_list_of_stringables(vals: Sequence[str | int | float]) -> bool: if ( vals is None or not isinstance(vals, (tuple, list)) - or any(1 for val in vals if not isinstance(val, (str, int, float))) + or not all(isinstance(val, (str, int, float)) for val in vals) ): raise ValueError("List of strings expected") return True @@ -831,6 +829,6 @@ def _validate_extra_conf(conf: dict[Any, Any]) -> bool: if conf: if not isinstance(conf, dict): raise ValueError("'conf' argument must be a dict") - if not all((v and isinstance(v, str)) or isinstance(v, int) for v in conf.values()): + if not all(isinstance(v, (str, int)) and v != "" for v in conf.values()): raise ValueError("'conf' values must be either strings or ints") return True diff --git a/airflow/providers/apache/spark/hooks/spark_sql.py b/airflow/providers/apache/spark/hooks/spark_sql.py index 6864aa52fe0c..41dc741ccdd3 100644 --- a/airflow/providers/apache/spark/hooks/spark_sql.py +++ b/airflow/providers/apache/spark/hooks/spark_sql.py @@ -134,7 +134,7 @@ def _prepare_command(self, cmd: str | list[str]) -> list[str]: connection_cmd += ["--num-executors", str(self._num_executors)] if self._sql: sql = self._sql.strip() - if sql.endswith(".sql") or sql.endswith(".hql"): + if sql.endswith((".sql", ".hql")): connection_cmd += ["-f", sql] else: connection_cmd += ["-e", sql]