From 4ea96ce85fdceeab9e48091784cd111b7102c310 Mon Sep 17 00:00:00 2001 From: Oleg Ovcharuk Date: Mon, 23 Dec 2024 17:41:03 +0300 Subject: [PATCH] Fix auth credentials --- tests/aio/test_credentials.py | 3 ++- ydb/_topic_reader/topic_reader_asyncio.py | 5 ++++- ydb/_topic_writer/topic_writer_asyncio.py | 5 ++++- ydb/aio/credentials.py | 14 +++++++++++--- 4 files changed, 21 insertions(+), 6 deletions(-) diff --git a/tests/aio/test_credentials.py b/tests/aio/test_credentials.py index a6f1d170..5000541c 100644 --- a/tests/aio/test_credentials.py +++ b/tests/aio/test_credentials.py @@ -36,9 +36,10 @@ async def test_yandex_service_account_credentials(): tests.auth.test_credentials.PRIVATE_KEY, server.get_endpoint(), ) - t = (await credentials.auth_metadata())[0][1] + t = await credentials.get_auth_token() assert t == "test_token" assert credentials.get_expire_time() <= 42 + server.stop() diff --git a/ydb/_topic_reader/topic_reader_asyncio.py b/ydb/_topic_reader/topic_reader_asyncio.py index 6833492d..351efb9a 100644 --- a/ydb/_topic_reader/topic_reader_asyncio.py +++ b/ydb/_topic_reader/topic_reader_asyncio.py @@ -516,7 +516,10 @@ async def _read_messages_loop(self): async def _update_token_loop(self): while True: await asyncio.sleep(self._update_token_interval) - await self._update_token(token=self._get_token_function()) + token = self._get_token_function() + if asyncio.iscoroutine(token): + token = await token + await self._update_token(token=token) async def _update_token(self, token: str): await self._update_token_event.wait() diff --git a/ydb/_topic_writer/topic_writer_asyncio.py b/ydb/_topic_writer/topic_writer_asyncio.py index d759072c..869808f7 100644 --- a/ydb/_topic_writer/topic_writer_asyncio.py +++ b/ydb/_topic_writer/topic_writer_asyncio.py @@ -686,7 +686,10 @@ def write(self, messages: List[InternalMessage]): async def _update_token_loop(self): while True: await asyncio.sleep(self._update_token_interval) - await self._update_token(token=self._get_token_function()) + token = self._get_token_function() + if asyncio.iscoroutine(token): + token = await token + await self._update_token(token=token) async def _update_token(self, token: str): await self._update_token_event.wait() diff --git a/ydb/aio/credentials.py b/ydb/aio/credentials.py index 08db1fd0..03c96a37 100644 --- a/ydb/aio/credentials.py +++ b/ydb/aio/credentials.py @@ -1,11 +1,13 @@ -import time - import abc import asyncio import logging -from ydb import issues, credentials +import time + +from ydb import credentials +from ydb import issues logger = logging.getLogger(__name__) +YDB_AUTH_TICKET_HEADER = "x-ydb-auth-ticket" class _OneToManyValue(object): @@ -64,6 +66,12 @@ def __init__(self): async def _make_token_request(self): pass + async def get_auth_token(self) -> str: + for header, token in await self.auth_metadata(): + if header == YDB_AUTH_TICKET_HEADER: + return token + return "" + async def _refresh(self): current_time = time.time() self._log_refresh_start(current_time)