From c1b79ad231827e3af4a51ca84ec3a40de2f6bfdc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A4=A9=E9=82=91?= Date: Thu, 28 Nov 2024 15:52:26 +0800 Subject: [PATCH 1/6] add retry and cache for credential of OdpsDataset --- requirements/runtime.txt | 1 + tzrec/datasets/odps_dataset.py | 36 ++++++++++++++++++++++++++++++---- 2 files changed, 33 insertions(+), 4 deletions(-) diff --git a/requirements/runtime.txt b/requirements/runtime.txt index 2fe3893..417db42 100644 --- a/requirements/runtime.txt +++ b/requirements/runtime.txt @@ -9,6 +9,7 @@ grpcio-tools<1.63.0 pandas pyfg @ https://tzrec.oss-cn-beijing.aliyuncs.com/third_party/pyfg-0.3.7-cp311-cp311-linux_x86_64.whl ; python_version=="3.11" pyfg @ https://tzrec.oss-cn-beijing.aliyuncs.com/third_party/pyfg-0.3.7-cp310-cp310-linux_x86_64.whl ; python_version=="3.10" +pyodps>0.11.6 scikit-learn tensorboard torch==2.5.0 diff --git a/tzrec/datasets/odps_dataset.py b/tzrec/datasets/odps_dataset.py index a988067..20f3fbd 100644 --- a/tzrec/datasets/odps_dataset.py +++ b/tzrec/datasets/odps_dataset.py @@ -20,7 +20,7 @@ import urllib3 from alibabacloud_credentials import providers from odps import ODPS -from odps.accounts import AliyunAccount, BaseAccount, CredentialProviderAccount +from odps.accounts import AliyunAccount, BaseAccount, StsAccount from odps.apis.storage_api import ( ArrowReader, ReadRowsRequest, @@ -34,6 +34,7 @@ WriteRowsRequest, ) from odps.errors import ODPSError +from Tea.exceptions import RetryError from torch import distributed as dist from tzrec.constant import Mode @@ -77,6 +78,35 @@ TYPE_PA_TO_TABLE = {v: k for k, v in TYPE_TABLE_TO_PA.items()} +class _CredentialProviderAccount(StsAccount): + def __init__(self, credential_provider): + self.provider = credential_provider + try: + self.credential = self.provider.get_credential() + except Exception: + self.credential = self.provider.get_credentials() + super(_CredentialProviderAccount, self).__init__(None, None, None) + + def sign_request(self, req, endpoint, region_name=None): + max_retry_count = 3 + retry_cnt = 0 + while True: + try: + self.access_id = self.credential.get_access_key_id() + self.secret_access_key = self.credential.get_access_key_secret() + self.sts_token = self.credential.get_security_token() + break + except RetryError as e: + if retry_cnt >= max_retry_count: + raise e + retry_cnt += 1 + time.sleep(random.choice([5, 9, 12])) + continue + return super(_CredentialProviderAccount, self).sign_request( + req, endpoint, region_name=region_name + ) + + def _parse_odps_config_file(odps_config_path: str) -> Tuple[str, str, str]: """Parse odps config file.""" if os.path.exists(odps_config_path): @@ -110,9 +140,7 @@ def _create_odps_account() -> Tuple[BaseAccount, str]: account = AliyunAccount(account_id, account_key) elif "ALIBABA_CLOUD_CREDENTIALS_URI" in os.environ: p = providers.DefaultCredentialsProvider() - # prevent too much request to credential server after forked - p.get_credentials().get_credential() - account = CredentialProviderAccount(p) + account = _CredentialProviderAccount(p) try: odps_endpoint = os.environ["ODPS_ENDPOINT"] except KeyError as err: From b11e55e9395a19969c0f921490184a2696582017 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A4=A9=E9=82=91?= Date: Thu, 28 Nov 2024 16:05:34 +0800 Subject: [PATCH 2/6] add retry and cache for credential of OdpsDataset --- tzrec/datasets/odps_dataset.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tzrec/datasets/odps_dataset.py b/tzrec/datasets/odps_dataset.py index 20f3fbd..990e6f4 100644 --- a/tzrec/datasets/odps_dataset.py +++ b/tzrec/datasets/odps_dataset.py @@ -79,6 +79,7 @@ class _CredentialProviderAccount(StsAccount): + # pyre-ignore [2,3] def __init__(self, credential_provider): self.provider = credential_provider try: @@ -87,6 +88,7 @@ def __init__(self, credential_provider): self.credential = self.provider.get_credentials() super(_CredentialProviderAccount, self).__init__(None, None, None) + # pyre-ignore [2,3] def sign_request(self, req, endpoint, region_name=None): max_retry_count = 3 retry_cnt = 0 From 6c27d3ab14ac5f5212c18b1e820b73e71c6a8b3c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A4=A9=E9=82=91?= Date: Thu, 28 Nov 2024 16:07:05 +0800 Subject: [PATCH 3/6] bump up version --- tzrec/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tzrec/version.py b/tzrec/version.py index bb8363e..c705736 100644 --- a/tzrec/version.py +++ b/tzrec/version.py @@ -9,4 +9,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.6.4" +__version__ = "0.6.5" From d385ee238bf957163019b1c0986d158ccd7166fe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A4=A9=E9=82=91?= Date: Thu, 28 Nov 2024 17:22:08 +0800 Subject: [PATCH 4/6] bump up version --- tzrec/datasets/odps_dataset.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tzrec/datasets/odps_dataset.py b/tzrec/datasets/odps_dataset.py index 990e6f4..3b068fb 100644 --- a/tzrec/datasets/odps_dataset.py +++ b/tzrec/datasets/odps_dataset.py @@ -94,9 +94,10 @@ def sign_request(self, req, endpoint, region_name=None): retry_cnt = 0 while True: try: - self.access_id = self.credential.get_access_key_id() - self.secret_access_key = self.credential.get_access_key_secret() - self.sts_token = self.credential.get_security_token() + credential = self.credential.get_credential() + self.access_id = credential.get_access_key_id() + self.secret_access_key = credential.get_access_key_secret() + self.sts_token = credential.get_security_token() break except RetryError as e: if retry_cnt >= max_retry_count: From 7590f38bd3a80798db096e023613f4fd761f99e5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A4=A9=E9=82=91?= Date: Thu, 28 Nov 2024 20:50:12 +0800 Subject: [PATCH 5/6] use CredClient --- requirements/runtime.txt | 2 +- tzrec/datasets/odps_dataset.py | 43 +++++----------------------------- 2 files changed, 7 insertions(+), 38 deletions(-) diff --git a/requirements/runtime.txt b/requirements/runtime.txt index 417db42..3bca7a9 100644 --- a/requirements/runtime.txt +++ b/requirements/runtime.txt @@ -9,7 +9,7 @@ grpcio-tools<1.63.0 pandas pyfg @ https://tzrec.oss-cn-beijing.aliyuncs.com/third_party/pyfg-0.3.7-cp311-cp311-linux_x86_64.whl ; python_version=="3.11" pyfg @ https://tzrec.oss-cn-beijing.aliyuncs.com/third_party/pyfg-0.3.7-cp310-cp310-linux_x86_64.whl ; python_version=="3.10" -pyodps>0.11.6 +pyodps>0.12.0 scikit-learn tensorboard torch==2.5.0 diff --git a/tzrec/datasets/odps_dataset.py b/tzrec/datasets/odps_dataset.py index 3b068fb..e0d4b6e 100644 --- a/tzrec/datasets/odps_dataset.py +++ b/tzrec/datasets/odps_dataset.py @@ -18,9 +18,9 @@ import pyarrow as pa import urllib3 -from alibabacloud_credentials import providers +from alibabacloud_credentials.client import Client as CredClient from odps import ODPS -from odps.accounts import AliyunAccount, BaseAccount, StsAccount +from odps.accounts import AliyunAccount, BaseAccount, CredentialProviderAccount from odps.apis.storage_api import ( ArrowReader, ReadRowsRequest, @@ -34,7 +34,6 @@ WriteRowsRequest, ) from odps.errors import ODPSError -from Tea.exceptions import RetryError from torch import distributed as dist from tzrec.constant import Mode @@ -78,38 +77,6 @@ TYPE_PA_TO_TABLE = {v: k for k, v in TYPE_TABLE_TO_PA.items()} -class _CredentialProviderAccount(StsAccount): - # pyre-ignore [2,3] - def __init__(self, credential_provider): - self.provider = credential_provider - try: - self.credential = self.provider.get_credential() - except Exception: - self.credential = self.provider.get_credentials() - super(_CredentialProviderAccount, self).__init__(None, None, None) - - # pyre-ignore [2,3] - def sign_request(self, req, endpoint, region_name=None): - max_retry_count = 3 - retry_cnt = 0 - while True: - try: - credential = self.credential.get_credential() - self.access_id = credential.get_access_key_id() - self.secret_access_key = credential.get_access_key_secret() - self.sts_token = credential.get_security_token() - break - except RetryError as e: - if retry_cnt >= max_retry_count: - raise e - retry_cnt += 1 - time.sleep(random.choice([5, 9, 12])) - continue - return super(_CredentialProviderAccount, self).sign_request( - req, endpoint, region_name=region_name - ) - - def _parse_odps_config_file(odps_config_path: str) -> Tuple[str, str, str]: """Parse odps config file.""" if os.path.exists(odps_config_path): @@ -142,8 +109,10 @@ def _create_odps_account() -> Tuple[BaseAccount, str]: ) account = AliyunAccount(account_id, account_key) elif "ALIBABA_CLOUD_CREDENTIALS_URI" in os.environ: - p = providers.DefaultCredentialsProvider() - account = _CredentialProviderAccount(p) + credentials_client = CredClient() + # prevent too much request to credential server after forked + credentials_client.get_credential() + account = CredentialProviderAccount(credentials_client) try: odps_endpoint = os.environ["ODPS_ENDPOINT"] except KeyError as err: From 47c14229487894525b76577bf1e20d0865c07a9f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A4=A9=E9=82=91?= Date: Thu, 28 Nov 2024 20:52:17 +0800 Subject: [PATCH 6/6] use CredClient --- requirements/runtime.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/runtime.txt b/requirements/runtime.txt index 3bca7a9..7d08923 100644 --- a/requirements/runtime.txt +++ b/requirements/runtime.txt @@ -9,7 +9,7 @@ grpcio-tools<1.63.0 pandas pyfg @ https://tzrec.oss-cn-beijing.aliyuncs.com/third_party/pyfg-0.3.7-cp311-cp311-linux_x86_64.whl ; python_version=="3.11" pyfg @ https://tzrec.oss-cn-beijing.aliyuncs.com/third_party/pyfg-0.3.7-cp310-cp310-linux_x86_64.whl ; python_version=="3.10" -pyodps>0.12.0 +pyodps>=0.12.0 scikit-learn tensorboard torch==2.5.0