diff --git a/addon_imps/citations/__init__.py b/addon_imps/citations/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/addon_imps/citations/zotero_org.py b/addon_imps/citations/zotero_org.py new file mode 100644 index 00000000..f8926b63 --- /dev/null +++ b/addon_imps/citations/zotero_org.py @@ -0,0 +1,6 @@ +from addon_toolkit.interfaces.storage import StorageAddonImp + + +class ZoteroOrgCitationImp(StorageAddonImp): + async def get_external_account_id(self, auth_result_extras: dict[str, str]) -> str: + return auth_result_extras["userID"] diff --git a/addon_imps/storage/box_dot_com.py b/addon_imps/storage/box_dot_com.py index 57ca8b87..73cef9d2 100644 --- a/addon_imps/storage/box_dot_com.py +++ b/addon_imps/storage/box_dot_com.py @@ -12,6 +12,11 @@ class BoxDotComStorageImp(storage.StorageAddonImp): see https://developer.box.com/reference/ """ + async def get_external_account_id(self, auth_result_extras: dict[str, str]) -> str: + async with self.network.GET("/users/me") as _response: + _json = await _response.json_content() + return str(_json["id"]) + async def list_root_items(self, page_cursor: str = "") -> storage.ItemSampleResult: return storage.ItemSampleResult( items=[await self.get_item_info(_box_root_id())], diff --git a/addon_service/addon_imp/instantiation.py b/addon_service/addon_imp/instantiation.py index 7c97c74d..a06d9106 100644 --- a/addon_service/addon_imp/instantiation.py +++ b/addon_service/addon_imp/instantiation.py @@ -1,4 +1,6 @@ -from addon_service.common.aiohttp_session import get_singleton_client_session__blocking +from asgiref.sync import async_to_sync + +from addon_service.common.aiohttp_session import get_singleton_client_session from addon_service.common.network import GravyvaletHttpRequestor from addon_service.models import AuthorizedStorageAccount from addon_toolkit.interfaces.storage import ( @@ -7,7 +9,7 @@ ) -def get_storage_addon_instance( +async def get_storage_addon_instance( imp_cls: type[StorageAddonImp], account: AuthorizedStorageAccount, config: StorageConfig, @@ -16,8 +18,11 @@ def get_storage_addon_instance( return imp_cls( config=config, network=GravyvaletHttpRequestor( - client_session=get_singleton_client_session__blocking(), + client_session=await get_singleton_client_session(), prefix_url=config.external_api_url, account=account, ), ) + + +get_storage_addon_instance__blocking = async_to_sync(get_storage_addon_instance) diff --git a/addon_service/admin/__init__.py b/addon_service/admin/__init__.py index f2170d74..bc43d59e 100644 --- a/addon_service/admin/__init__.py +++ b/addon_service/admin/__init__.py @@ -33,3 +33,13 @@ class OAuth2ClientConfigAdmin(GravyvaletModelAdmin): "created", "modified", ) + + +@admin.register(models.OAuth1ClientConfig) +@linked_many_field("external_storage_services") +class OAuth1ClientConfigAdmin(GravyvaletModelAdmin): + readonly_fields = ( + "id", + "created", + "modified", + ) diff --git a/addon_service/authorized_storage_account/callbacks.py b/addon_service/authorized_storage_account/callbacks.py new file mode 100644 index 00000000..eccf244c --- /dev/null +++ b/addon_service/authorized_storage_account/callbacks.py @@ -0,0 +1,17 @@ +from addon_service.addon_imp.instantiation import get_storage_addon_instance +from addon_service.authorized_storage_account.models import AuthorizedStorageAccount + + +async def after_successful_auth( + account: AuthorizedStorageAccount, + auth_result_extras: dict[str, str] | None = None, +): + _imp = await get_storage_addon_instance( + account.imp_cls, # type: ignore[arg-type] + account, + account.storage_imp_config(), + ) + account.external_account_id = await _imp.get_external_account_id( + auth_result_extras or {} + ) + await account.asave() diff --git a/addon_service/authorized_storage_account/models.py b/addon_service/authorized_storage_account/models.py index 8db07988..bbe56cf1 100644 --- a/addon_service/authorized_storage_account/models.py +++ b/addon_service/authorized_storage_account/models.py @@ -14,8 +14,9 @@ from addon_service.common.service_types import ServiceTypes from addon_service.common.validators import validate_addon_capability from addon_service.credentials.models import ExternalCredentials -from addon_service.oauth import utils as oauth_utils -from addon_service.oauth.models import ( +from addon_service.oauth1 import utils as oauth1_utils +from addon_service.oauth2 import utils as oauth2_utils +from addon_service.oauth2.models import ( OAuth2ClientConfig, OAuth2TokenMetadata, ) @@ -23,11 +24,14 @@ AddonCapabilities, AddonImp, ) +from addon_toolkit.credentials import ( + Credentials, + OAuth1Credentials, +) from addon_toolkit.interfaces.storage import StorageConfig class AuthorizedStorageAccountManager(models.Manager): - def active(self): """filter to accounts owned by non-deactivated users""" return self.get_queryset().filter(account_owner__deactivated__isnull=True) @@ -68,12 +72,21 @@ class AuthorizedStorageAccount(AddonsServiceBaseModel): blank=True, related_name="authorized_storage_account", ) + _temporary_oauth1_credentials = models.OneToOneField( + "addon_service.ExternalCredentials", + on_delete=models.CASCADE, + primary_key=False, + null=True, + blank=True, + related_name="temporary_authorized_storage_account", + ) oauth2_token_metadata = models.ForeignKey( "addon_service.OAuth2TokenMetadata", on_delete=models.CASCADE, # probs not null=True, blank=True, related_name="authorized_storage_accounts", + related_query_name="%(class)s_authorized_storage_account", ) class Meta: @@ -108,17 +121,40 @@ def credentials(self): @credentials.setter def credentials(self, credentials_data): + if self.temporary_oauth1_credentials: + self._temporary_oauth1_credentials.delete() + self._temporary_oauth1_credentials = None + self._set_credentials("_credentials", credentials_data) + + @property + def temporary_oauth1_credentials(self) -> OAuth1Credentials | None: + if self._temporary_oauth1_credentials: + return self._temporary_oauth1_credentials.decrypted_credentials + return None + + @temporary_oauth1_credentials.setter + def temporary_oauth1_credentials(self, credentials_data: OAuth1Credentials): + if self.credentials_format is not CredentialsFormats.OAUTH1A: + raise ValidationError( + "Trying to set temporary credentials for non OAuth1A account" + ) + self._set_credentials("_temporary_oauth1_credentials", credentials_data) + + def _set_credentials(self, credentials_field: str, credentials_data: Credentials): creds_type = type(credentials_data) + if not hasattr(self, credentials_field): + raise ValidationError("Trying to set credentials to non-existing field") if creds_type is not self.credentials_format.dataclass: raise ValidationError( - f"Expectd credentials of type type {self.credentials_format.dataclass}." + f"Expected credentials of type type {self.credentials_format.dataclass}." f"Got credentials of type {creds_type}." ) - if not self._credentials: - self._credentials = ExternalCredentials.new() + if not getattr(self, credentials_field, None): + setattr(self, credentials_field, ExternalCredentials.new()) try: - self._credentials.decrypted_credentials = credentials_data - self._credentials.save() + creds = getattr(self, credentials_field) + creds.decrypted_credentials = credentials_data + creds.save() except TypeError as e: raise ValidationError(e) @@ -129,7 +165,7 @@ def authorized_capabilities(self) -> AddonCapabilities: @authorized_capabilities.setter def authorized_capabilities(self, new_capabilities: AddonCapabilities): - """set int_authorized_capabilities without caring it's int""" + """set int_authorized_capabilities without caring its int""" self.int_authorized_capabilities = new_capabilities.value @property @@ -158,18 +194,32 @@ def authorized_operation_names(self) -> list[str]: @property def auth_url(self) -> str | None: - """Generates the url required to initiate OAuth2 credentials exchange. + """Generates the url required to initiate OAuth credentials exchange. - Returns None if the ExternalStorageService does not support OAuth2 - or if the initial credentials exchange has already ocurred. + Returns None if the ExternalStorageService does not support OAuth + or if the initial credentials exchange has already occurred. """ - if self.credentials_format is not CredentialsFormats.OAUTH2: - return None + match self.credentials_format: + case CredentialsFormats.OAUTH2: + return self.oauth2_auth_url + case CredentialsFormats.OAUTH1A: + return self.oauth1_auth_url + @property + def oauth1_auth_url(self) -> str: + client_config = self.external_service.oauth1_client_config + if self._temporary_oauth1_credentials: + return oauth1_utils.build_auth_url( + auth_uri=client_config.auth_url, + temporary_oauth_token=self.temporary_oauth1_credentials.oauth_token, + ) + + @property + def oauth2_auth_url(self) -> str | None: state_token = self.oauth2_token_metadata.state_token if not state_token: return None - return oauth_utils.build_auth_url( + return oauth2_utils.build_auth_url( auth_uri=self.external_service.oauth2_client_config.auth_uri, client_id=self.external_service.oauth2_client_config.client_id, state_token=state_token, @@ -178,26 +228,39 @@ def auth_url(self) -> str | None: ) @property - def api_base_url(self): + def api_base_url(self) -> str: return self._api_base_url or self.external_service.api_base_url @api_base_url.setter - def api_base_url(self, value): + def api_base_url(self, value: str): self._api_base_url = value @property def imp_cls(self) -> type[AddonImp]: return self.external_service.addon_imp.imp_cls + @transaction.atomic + def initiate_oauth1_flow(self): + if self.credentials_format is not CredentialsFormats.OAUTH1A: + raise ValueError("Cannot initiate OAuth1 flow for non-OAuth1 credentials") + client_config = self.external_service.oauth1_client_config + request_token_result, _ = async_to_sync(oauth1_utils.get_temporary_token)( + client_config.request_token_url, + client_config.client_key, + client_config.client_secret, + ) + self.temporary_oauth1_credentials = request_token_result + self.save() + @transaction.atomic def initiate_oauth2_flow(self, authorized_scopes=None): if self.credentials_format is not CredentialsFormats.OAUTH2: - raise ValueError("Cannot initaite OAuth flow for non-OAuth credentials") + raise ValueError("Cannot initiate OAuth2 flow for non-OAuth2 credentials") self.oauth2_token_metadata = OAuth2TokenMetadata.objects.create( authorized_scopes=( authorized_scopes or self.external_service.supported_scopes ), - state_nonce=oauth_utils.generate_state_nonce(), + state_nonce=oauth2_utils.generate_state_nonce(), ) self.save() @@ -209,8 +272,8 @@ def storage_imp_config(self) -> StorageConfig: external_account_id=self.external_account_id, ) - def clean(self, *args, **kwargs): - super().clean(*args, **kwargs) + def clean(self): + super().clean() self.validate_api_base_url() self.validate_oauth_state() @@ -249,13 +312,14 @@ def validate_oauth_state(self): ) ### - # async functions for use in oauth callback flows - - async def refresh_oauth_access_token(self) -> None: - _oauth_client_config, _oauth_token_metadata = ( - await self._load_client_config_and_token_metadata() - ) - _fresh_token_result = await oauth_utils.get_refreshed_access_token( + # async functions for use in oauth2 callback flows + + async def refresh_oauth2_access_token(self) -> None: + ( + _oauth_client_config, + _oauth_token_metadata, + ) = await self._load_oauth2_client_config_and_token_metadata() + _fresh_token_result = await oauth2_utils.get_refreshed_access_token( token_endpoint_url=_oauth_client_config.token_endpoint_url, refresh_token=_oauth_token_metadata.refresh_token, auth_callback_url=_oauth_client_config.auth_callback_url, @@ -263,12 +327,12 @@ async def refresh_oauth_access_token(self) -> None: client_secret=_oauth_client_config.client_secret, ) await _oauth_token_metadata.update_with_fresh_token(_fresh_token_result) - await sync_to_async(self.refresh_from_db)() + await self.arefresh_from_db() - refresh_oauth_access_token__blocking = async_to_sync(refresh_oauth_access_token) + refresh_oauth_access_token__blocking = async_to_sync(refresh_oauth2_access_token) @sync_to_async - def _load_client_config_and_token_metadata( + def _load_oauth2_client_config_and_token_metadata( self, ) -> tuple[OAuth2ClientConfig, OAuth2TokenMetadata]: # wrap db access in `sync_to_async` diff --git a/addon_service/authorized_storage_account/serializers.py b/addon_service/authorized_storage_account/serializers.py index 7598df7c..3c3ecd38 100644 --- a/addon_service/authorized_storage_account/serializers.py +++ b/addon_service/authorized_storage_account/serializers.py @@ -1,3 +1,4 @@ +from asgiref.sync import async_to_sync from django.core.exceptions import ValidationError as ModelValidationError from rest_framework_json_api import serializers from rest_framework_json_api.relations import ( @@ -7,6 +8,7 @@ from rest_framework_json_api.utils import get_resource_type_from_model from addon_service.addon_operation.models import AddonOperationModel +from addon_service.authorized_storage_account.callbacks import after_successful_auth from addon_service.common import view_names from addon_service.common.credentials_formats import CredentialsFormats from addon_service.models import ( @@ -15,6 +17,7 @@ ExternalStorageService, UserReference, ) +from addon_service.osf_models.fields import encrypt_string from addon_service.serializer_fields import ( CredentialsField, DataclassRelatedLinkField, @@ -95,13 +98,22 @@ def create(self, validated_data): authorized_account.initiate_oauth2_flow( validated_data.get("authorized_scopes") ) + elif external_service.credentials_format is CredentialsFormats.OAUTH1A: + authorized_account.initiate_oauth1_flow() + self.context["request"].session["oauth1a_account_id"] = encrypt_string( + authorized_account.pk + ) else: authorized_account.credentials = validated_data["credentials"] + try: authorized_account.save() except ModelValidationError as e: raise serializers.ValidationError(e) + if external_service.credentials_format.is_direct_from_user: + async_to_sync(after_successful_auth)(authorized_account) + return authorized_account class Meta: diff --git a/addon_service/common/credentials_formats.py b/addon_service/common/credentials_formats.py index c01260a4..faaafb0c 100644 --- a/addon_service/common/credentials_formats.py +++ b/addon_service/common/credentials_formats.py @@ -1,20 +1,27 @@ -from enum import Enum +from enum import ( + Enum, + unique, +) from addon_toolkit import credentials +@unique class CredentialsFormats(Enum): UNSPECIFIED = 0 OAUTH2 = 1 ACCESS_KEY_SECRET_KEY = 2 USERNAME_PASSWORD = 3 PERSONAL_ACCESS_TOKEN = 4 + OAUTH1A = 5 @property def dataclass(self): match self: case CredentialsFormats.OAUTH2: return credentials.AccessTokenCredentials + case CredentialsFormats.OAUTH1A: + return credentials.OAuth1Credentials case CredentialsFormats.ACCESS_KEY_SECRET_KEY: return credentials.AccessKeySecretKeyCredentials case CredentialsFormats.PERSONAL_ACCESS_TOKEN: @@ -22,3 +29,11 @@ def dataclass(self): case CredentialsFormats.USERNAME_PASSWORD: return credentials.UsernamePasswordCredentials raise ValueError(f"No dataclass support for credentials type {self.name}") + + @property + def is_direct_from_user(self): + return self in { + CredentialsFormats.ACCESS_KEY_SECRET_KEY, + CredentialsFormats.USERNAME_PASSWORD, + CredentialsFormats.PERSONAL_ACCESS_TOKEN, + } diff --git a/addon_service/common/known_imps.py b/addon_service/common/known_imps.py index 25c75866..3b6c2ea9 100644 --- a/addon_service/common/known_imps.py +++ b/addon_service/common/known_imps.py @@ -5,6 +5,7 @@ import enum +from addon_imps.citations import zotero_org from addon_imps.storage import box_dot_com from addon_service.common.enum_decorators import enum_names_same_as from addon_toolkit import AddonImp @@ -54,6 +55,7 @@ class KnownAddonImps(enum.Enum): """Static mapping from API-facing name for an AddonImp to the Imp itself""" BOX_DOT_COM = box_dot_com.BoxDotComStorageImp + ZOTERO_ORG = zotero_org.ZoteroOrgCitationImp if __debug__: BLARG = my_blarg.MyBlargStorage @@ -65,6 +67,7 @@ class AddonImpNumbers(enum.Enum): """Static mapping from each AddonImp name to a unique integer (for database use)""" BOX_DOT_COM = 1001 + ZOTERO_ORG = 1002 if __debug__: BLARG = -7 diff --git a/addon_service/common/network.py b/addon_service/common/network.py index 14c3b6c5..2a1fc417 100644 --- a/addon_service/common/network.py +++ b/addon_service/common/network.py @@ -14,6 +14,7 @@ from addon_service import models as db from addon_service.common import exceptions +from addon_service.common.credentials_formats import CredentialsFormats from addon_toolkit.constrained_network import ( HttpRequestInfo, HttpRequestor, @@ -65,12 +66,12 @@ def __init__( # abstract method from HttpRequestor: @contextlib.asynccontextmanager - async def do_send(self, request: HttpRequestInfo): + async def _do_send(self, request: HttpRequestInfo): try: async with self._try_send(request) as _response: yield _response except exceptions.ExpiredAccessToken: - await _PrivateNetworkInfo.get(self).account.refresh_oauth_access_token() + await _PrivateNetworkInfo.get(self).account.refresh_oauth2_access_token() # if this one fails, don't try refreshing again async with self._try_send(request) as _response: yield _response @@ -86,7 +87,10 @@ async def _try_send(self, request: HttpRequestInfo): headers=await _private.get_headers(), # TODO: content ) as _response: - if _response.status == HTTPStatus.UNAUTHORIZED: + if ( + _response.status == HTTPStatus.UNAUTHORIZED + and _private.account.credentials_format == CredentialsFormats.OAUTH2 + ): # assume unauthorized because of token expiration. # if not, will fail again after refresh (which is fine) raise exceptions.ExpiredAccessToken diff --git a/addon_service/credentials/models.py b/addon_service/credentials/models.py index f37c8e0b..acbb39fc 100644 --- a/addon_service/credentials/models.py +++ b/addon_service/credentials/models.py @@ -94,7 +94,15 @@ def authorized_accounts(self): other types of accounts for the same user could point to the same set of credentials """ try: - return (self.authorized_storage_account,) + return [ + *filter( + bool, + [ + getattr(self, "authorized_storage_account", None), + getattr(self, "temporary_authorized_storage_account", None), + ], + ) + ] except ExternalCredentials.authorized_storage_account.RelatedObjectDoesNotExist: return None diff --git a/addon_service/credentials/serializers.py b/addon_service/credentials/serializers.py index 72c6edf6..8040c48a 100644 --- a/addon_service/credentials/serializers.py +++ b/addon_service/credentials/serializers.py @@ -9,6 +9,7 @@ SUPPORTED_CREDENTIALS_FORMATS = set(CredentialsFormats) - { CredentialsFormats.UNSPECIFIED, CredentialsFormats.OAUTH2, + CredentialsFormats.OAUTH1A, } @@ -17,7 +18,8 @@ def __init__(self, write_only=True, required=False, *args, **kwargs): super().__init__(write_only=write_only, required=required) def to_internal_value(self, data): - if not data: + # this issue still hasn't been fixed on FE, so keeping this for now + if not data or not any(data.values()): return None # consider empty {} same as omitting the field # No access to the credentials format here, so just try all of them for creds_format in SUPPORTED_CREDENTIALS_FORMATS: diff --git a/addon_service/external_storage_service/models.py b/addon_service/external_storage_service/models.py index d9b832c8..3b308d82 100644 --- a/addon_service/external_storage_service/models.py +++ b/addon_service/external_storage_service/models.py @@ -40,6 +40,14 @@ class ExternalStorageService(AddonsServiceBaseModel): # Distinct from `display_name` to avoid over-coupling wb_key = models.CharField(null=False, blank=True, default="") + oauth1_client_config = models.ForeignKey( + "addon_service.OAuth1ClientConfig", + on_delete=models.SET_NULL, + related_name="external_storage_services", + null=True, + blank=True, + ) + oauth2_client_config = models.ForeignKey( "addon_service.OAuth2ClientConfig", on_delete=models.SET_NULL, diff --git a/addon_service/management/commands/do_box_test.py b/addon_service/management/commands/do_box_test.py index f1180ab8..40dd4f9c 100644 --- a/addon_service/management/commands/do_box_test.py +++ b/addon_service/management/commands/do_box_test.py @@ -96,7 +96,7 @@ def _setup_oauth(self, user_uri: str, client_id, client_secret): ) _account.initiate_oauth2_flow() self.stdout.write( - self.style.SUCCESS("set up for oauth! now do the flow in a browser:") + self.style.SUCCESS("set up for oauth2! now do the flow in a browser:") ) self.stdout.write(_account.auth_url) self.stdout.write( diff --git a/addon_service/migrations/0001_initial.py b/addon_service/migrations/0001_initial.py index 1651c456..68f01b77 100644 --- a/addon_service/migrations/0001_initial.py +++ b/addon_service/migrations/0001_initial.py @@ -1,4 +1,4 @@ -# Generated by Django 4.2.7 on 2024-07-01 13:30 +# Generated by Django 4.2.7 on 2024-07-09 13:26 import django.contrib.postgres.fields import django.db.models.deletion @@ -50,6 +50,32 @@ class Migration(migrations.Migration): "verbose_name_plural": "Authorized Storage Accounts", }, ), + migrations.CreateModel( + name="OAuth1ClientConfig", + fields=[ + ( + "id", + addon_service.common.str_uuid_field.StrUUIDField( + default=addon_service.common.str_uuid_field.str_uuid4, + editable=False, + primary_key=True, + serialize=False, + ), + ), + ("created", models.DateTimeField(editable=False)), + ("modified", models.DateTimeField()), + ("request_token_url", models.URLField()), + ("auth_url", models.URLField()), + ("auth_callback_url", models.URLField()), + ("access_token_url", models.URLField()), + ("client_key", models.CharField(null=True)), + ("client_secret", models.CharField(null=True)), + ], + options={ + "verbose_name": "OAuth1 Client Config", + "verbose_name_plural": "OAuth1 Client Configs", + }, + ), migrations.CreateModel( name="OAuth2ClientConfig", fields=[ @@ -206,6 +232,16 @@ class Migration(migrations.Migration): ), ("api_base_url", models.URLField(blank=True, default="")), ("wb_key", models.CharField(blank=True, default="")), + ( + "oauth1_client_config", + models.ForeignKey( + blank=True, + null=True, + on_delete=django.db.models.deletion.SET_NULL, + related_name="external_storage_services", + to="addon_service.oauth1clientconfig", + ), + ), ( "oauth2_client_config", models.ForeignKey( @@ -309,6 +345,16 @@ class Migration(migrations.Migration): to="addon_service.externalcredentials", ), ), + migrations.AddField( + model_name="authorizedstorageaccount", + name="_temporary_oauth1_credentials", + field=models.OneToOneField( + blank=True, + null=True, + on_delete=django.db.models.deletion.CASCADE, + to="addon_service.externalcredentials", + ), + ), migrations.AddField( model_name="authorizedstorageaccount", name="account_owner", diff --git a/addon_service/models.py b/addon_service/models.py index 8c8db3aa..98b2f0cf 100644 --- a/addon_service/models.py +++ b/addon_service/models.py @@ -1,4 +1,4 @@ -""" Import models here so they auto-detect for makemigrations """ +"""Import models here so they auto-detect for makemigrations""" from addon_service.addon_imp.models import AddonImpModel from addon_service.addon_operation.models import AddonOperationModel @@ -7,7 +7,8 @@ from addon_service.configured_storage_addon.models import ConfiguredStorageAddon from addon_service.credentials.models import ExternalCredentials from addon_service.external_storage_service.models import ExternalStorageService -from addon_service.oauth.models import ( +from addon_service.oauth1.models import OAuth1ClientConfig +from addon_service.oauth2.models import ( OAuth2ClientConfig, OAuth2TokenMetadata, ) @@ -25,6 +26,7 @@ "ExternalStorageService", "OAuth2ClientConfig", "OAuth2TokenMetadata", + "OAuth1ClientConfig", "ResourceReference", "UserReference", ) diff --git a/addon_service/oauth1/__init__.py b/addon_service/oauth1/__init__.py new file mode 100644 index 00000000..0e29ee3a --- /dev/null +++ b/addon_service/oauth1/__init__.py @@ -0,0 +1,5 @@ +from . import utils +from .models import OAuth1ClientConfig + + +__all__ = ("OAuth1ClientConfig", "utils") diff --git a/addon_service/oauth1/models.py b/addon_service/oauth1/models.py new file mode 100644 index 00000000..e11fc0e8 --- /dev/null +++ b/addon_service/oauth1/models.py @@ -0,0 +1,33 @@ +from django.db import models + +from addon_service.common.base_model import AddonsServiceBaseModel + + +class OAuth1ClientConfig(AddonsServiceBaseModel): + """ + Model for storing attributes that are required for managing + OAuth1 credentials exchanges with an ExternalService on behalf + of a registered client (e.g. the OSF) + """ + + # URI that allows to obtain temporary request token to proceed with user auth + request_token_url = models.URLField(null=False) + # URI to which user will be redirected to authenticate + auth_url = models.URLField(null=False) + # URI to which user will be redirected after authentication + auth_callback_url = models.URLField(null=False) + # URI to obtain access token + access_token_url = models.URLField(null=False) + + client_key = models.CharField(null=True) + client_secret = models.CharField(null=True) + + class Meta: + verbose_name = "OAuth1 Client Config" + verbose_name_plural = "OAuth1 Client Configs" + app_label = "addon_service" + + def __repr__(self): + return f'<{self.__class__.__qualname__}(pk="{self.pk}", auth_uri="{self.auth_url}, access_token_url="{self.access_token_url}", request_token_url="{self.request_token_url}", client_key="{self.client_key}")>' + + __str__ = __repr__ diff --git a/addon_service/oauth1/utils.py b/addon_service/oauth1/utils.py new file mode 100644 index 00000000..87d34d8d --- /dev/null +++ b/addon_service/oauth1/utils.py @@ -0,0 +1,148 @@ +from base64 import b64encode +from hashlib import sha1 +from hmac import HMAC +from http import ( + HTTPMethod, + HTTPStatus, +) +from secrets import token_urlsafe +from time import time +from urllib.parse import ( + parse_qs, + quote_plus, +) + +from aiohttp import ContentTypeError + +from addon_service.common.aiohttp_session import get_singleton_client_session +from addon_toolkit.credentials import OAuth1Credentials +from addon_toolkit.iri_utils import iri_with_query + + +async def get_temporary_token( + temporary_token_url: str, + oauth_consumer_key: str, + oauth_consumer_secret: str, +) -> tuple[OAuth1Credentials, dict]: + """ + Obtaining unauthorised request token needed to construct Authorization url + https://oauth.net/core/1.0a/#auth_step1 + """ + signed_headers = construct_signed_headers( + temporary_token_url, oauth_consumer_key, oauth_consumer_secret + ) + return await _get_token(temporary_token_url, signed_headers) + + +async def get_access_token( + access_token_url: str, + oauth_consumer_key: str, + oauth_consumer_secret: str, + oauth_token: str, + oauth_token_secret: str, + oauth_verifier: str, +) -> tuple[OAuth1Credentials, dict]: + """ + Getting final access token needed to access protected resources from Service provider + + """ + signed_headers = construct_signed_headers( + access_token_url, + oauth_consumer_key, + oauth_consumer_secret, + oauth_token=oauth_token, + oauth_verifier=oauth_verifier, + oauth_token_secret=oauth_token_secret, + ) + return await _get_token(access_token_url, signed_headers) + + +async def _get_token( + request_token_url, signed_headers +) -> tuple[OAuth1Credentials, dict]: + _client = await get_singleton_client_session() + async with _client.post( + request_token_url, headers=signed_headers + ) as _token_response: + if not HTTPStatus(_token_response.status).is_success: + raise RuntimeError(await _token_response.text()) + try: + return OAuth1Credentials.from_dict(await _token_response.json()) + except ContentTypeError: + raw_result = parse_qs(await _token_response.text()) + result_dict = {key: value[0] for key, value in raw_result.items() if value} + return OAuth1Credentials.from_dict(result_dict) + + +def _construct_params(params_to_encode: dict) -> str: + return ",".join( + f'{key}="{value}"' for key, value in sorted(params_to_encode.items()) + ) + + +def construct_signed_headers( + url: str, + oauth_consumer_key: str, + oauth_consumer_secret: str, + http_method: HTTPMethod = HTTPMethod.POST, + oauth_token: str | None = None, + oauth_token_secret: str | None = None, + oauth_verifier: str | None = None, +): + oauth_params = construct_headers(oauth_consumer_key, oauth_token, oauth_verifier) + signature = generate_signature( + http_method, url, oauth_params, oauth_consumer_secret, oauth_token_secret + ) + oauth_params |= {"oauth_signature": signature} + return {"Authorization": f"OAuth {_construct_params(oauth_params)}"} + + +def generate_signature( + http_method: HTTPMethod, + url: str, + headers: dict, + oauth_consumer_secret: str, + oauth_token_secret: str | None = None, +) -> str: + params_str = "&".join(f"{key}={value}" for key, value in sorted(headers.items())) + signature_base = f"{http_method}&{quote_plus(url)}&{quote_plus(params_str)}" + key = f"{oauth_consumer_secret}&{oauth_token_secret or ''}" + hmac = HMAC(key.encode(), signature_base.encode(), sha1).digest() + return quote_plus(b64encode(hmac)) + + +def construct_headers( + oauth_consumer_key: str, + oauth_token: str | None = None, + oauth_verifier: str | None = None, +) -> dict[str, str]: + initial_payload = { + "oauth_consumer_key": oauth_consumer_key, + "oauth_signature_method": "HMAC-SHA1", + "oauth_timestamp": f"{int(time())}", + "oauth_nonce": generate_nonce(32), + "oauth_version": "1.0", + "oauth_token": oauth_token, + "oauth_verifier": oauth_verifier, + } + + return {key: value for key, value in initial_payload.items() if value} + + +def build_auth_url( + *, + auth_uri: str, + temporary_oauth_token: str, +) -> str: + """build a URL that will initiate authorization when visited by a user + + see https://www.rfc-editor.org/rfc/rfc6749.html#section-4.1.1 + """ + query_params = { + "oauth_token": temporary_oauth_token, + } + return iri_with_query(auth_uri, query_params) + + +def generate_nonce(nonce_length: int = 16): + return token_urlsafe(nonce_length) diff --git a/addon_service/oauth1/views.py b/addon_service/oauth1/views.py new file mode 100644 index 00000000..6da1aa43 --- /dev/null +++ b/addon_service/oauth1/views.py @@ -0,0 +1,33 @@ +from http import HTTPStatus + +from asgiref.sync import async_to_sync +from django.http import HttpResponse + +from addon_service.authorized_storage_account.callbacks import after_successful_auth +from addon_service.authorized_storage_account.models import AuthorizedStorageAccount +from addon_service.oauth1.utils import get_access_token +from addon_service.osf_models.fields import decrypt_string + + +def oauth1_callback_view(request): + oauth_token = request.GET["oauth_token"] + oauth_verifier = request.GET["oauth_verifier"] + + pk = decrypt_string(request.session.get("oauth1a_account_id")) + del request.session["oauth1a_account_id"] + + account = AuthorizedStorageAccount.objects.get(pk=pk) + + oauth1_client_config = account.external_service.oauth1_client_config + final_credentials, other_info = async_to_sync(get_access_token)( + access_token_url=oauth1_client_config.access_token_url, + oauth_consumer_key=oauth1_client_config.client_key, + oauth_consumer_secret=oauth1_client_config.client_secret, + oauth_token=oauth_token, + oauth_token_secret=account.temporary_oauth1_credentials.oauth_token_secret, + oauth_verifier=oauth_verifier, + ) + account.credentials = final_credentials + account.save() + async_to_sync(after_successful_auth)(account, other_info) + return HttpResponse(status=HTTPStatus.OK) # TODO: redirect diff --git a/addon_service/oauth/__init__.py b/addon_service/oauth2/__init__.py similarity index 100% rename from addon_service/oauth/__init__.py rename to addon_service/oauth2/__init__.py diff --git a/addon_service/oauth/models.py b/addon_service/oauth2/models.py similarity index 100% rename from addon_service/oauth/models.py rename to addon_service/oauth2/models.py diff --git a/addon_service/oauth/utils.py b/addon_service/oauth2/utils.py similarity index 97% rename from addon_service/oauth/utils.py rename to addon_service/oauth2/utils.py index 19bf8ab1..e4eb64e0 100644 --- a/addon_service/oauth/utils.py +++ b/addon_service/oauth2/utils.py @@ -117,7 +117,7 @@ async def _token_request( async with _client.post(token_endpoint_url, data=request_body) as _token_response: if not HTTPStatus(_token_response.status).is_success: raise RuntimeError(await _token_response.json()) - raise RuntimeError # TODO: https://www.rfc-editor.org/rfc/rfc6749.html#section-5.2 + # TODO: https://www.rfc-editor.org/rfc/rfc6749.html#section-5.2 return FreshTokenResult.from_token_response_json(await _token_response.json()) diff --git a/addon_service/oauth/views.py b/addon_service/oauth2/views.py similarity index 68% rename from addon_service/oauth/views.py rename to addon_service/oauth2/views.py index 437bb34d..4a42254f 100644 --- a/addon_service/oauth/views.py +++ b/addon_service/oauth2/views.py @@ -1,16 +1,16 @@ +import asyncio from http import HTTPStatus from asgiref.sync import sync_to_async from django.db import transaction from django.http import HttpResponse -from addon_service.common.aiohttp_session import get_singleton_client_session -from addon_service.common.network import GravyvaletHttpRequestor +from addon_service.authorized_storage_account.callbacks import after_successful_auth from addon_service.models import ( OAuth2ClientConfig, OAuth2TokenMetadata, ) -from addon_service.oauth.utils import get_initial_access_token +from addon_service.oauth2.utils import get_initial_access_token @transaction.non_atomic_requests # async views and ATOMIC_REQUESTS do not mix @@ -33,7 +33,7 @@ async def oauth2_callback_view(request): client_secret=_oauth_client_config.client_secret, ) _accounts = await _token_metadata.update_with_fresh_token(_fresh_token_result) - await _update_external_account_ids(_accounts) + await asyncio.gather(*[after_successful_auth(_account) for _account in _accounts]) return HttpResponse(status=HTTPStatus.OK) # TODO: redirect @@ -47,15 +47,3 @@ def _resolve_state_token( ) -> tuple[OAuth2TokenMetadata, OAuth2ClientConfig]: _token_metadata = OAuth2TokenMetadata.objects.get_by_state_token(state_token) return (_token_metadata, _token_metadata.client_details) - - -async def _update_external_account_ids(accounts): - for _account in accounts: - _account.external_account_id = await _account.imp_cls.get_external_account_id( - network=GravyvaletHttpRequestor( - client_session=await get_singleton_client_session(), - prefix_url=_account.external_service.api_base_url, - account=_account, - ), - ) - await sync_to_async(_account.save)() diff --git a/addon_service/tasks/invocation.py b/addon_service/tasks/invocation.py index 7a39c47e..1363bb67 100644 --- a/addon_service/tasks/invocation.py +++ b/addon_service/tasks/invocation.py @@ -2,7 +2,7 @@ from asgiref.sync import sync_to_async from django.db import transaction -from addon_service.addon_imp.instantiation import get_storage_addon_instance +from addon_service.addon_imp.instantiation import get_storage_addon_instance__blocking from addon_service.common.dibs import dibs from addon_service.common.invocation_status import InvocationStatus from addon_service.models import AddonOperationInvocation @@ -20,7 +20,7 @@ def perform_invocation__blocking(invocation: AddonOperationInvocation) -> None: # implemented as a sync function for django transactions with dibs(invocation): # TODO: handle dibs errors try: - _imp = get_storage_addon_instance( + _imp = get_storage_addon_instance__blocking( invocation.imp_cls, # type: ignore[arg-type] #(TODO: generic impstantiation) invocation.thru_account, invocation.storage_imp_config(), diff --git a/addon_service/tests/_factories.py b/addon_service/tests/_factories.py index a0e7ee6f..73e20784 100644 --- a/addon_service/tests/_factories.py +++ b/addon_service/tests/_factories.py @@ -36,6 +36,18 @@ class Meta: client_secret = factory.Faker("word") +class OAuth1ClientConfigFactory(DjangoModelFactory): + class Meta: + model = db.OAuth1ClientConfig + + auth_url = "https://api.example/auth/" + auth_callback_url = "https://api.example/auth/" + access_token_url = "https://osf.example/oauth/access" + request_token_url = "https://api.example.com/oauth/request" + client_key = factory.Faker("word") + client_secret = factory.Faker("word") + + class AddonOperationInvocationFactory(DjangoModelFactory): class Meta: model = db.AddonOperationInvocation @@ -64,7 +76,6 @@ class Meta: max_concurrent_downloads = factory.Faker("pyint") max_upload_mb = factory.Faker("pyint") int_addon_imp = known_imps.get_imp_number(known_imps.get_imp_by_name("BLARG")) - oauth2_client_config = factory.SubFactory(OAuth2ClientConfigFactory) supported_scopes = ["service.url/grant_all"] @classmethod @@ -89,6 +100,16 @@ def _create( ) +class ExternalStorageOAuth2ServiceFactory(ExternalStorageServiceFactory): + credentials_format = CredentialsFormats.OAUTH2 + oauth2_client_config = factory.SubFactory(OAuth2ClientConfigFactory) + + +class ExternalStorageOAuth1ServiceFactory(ExternalStorageServiceFactory): + credentials_format = CredentialsFormats.OAUTH1A + oauth1_client_config = factory.SubFactory(OAuth1ClientConfigFactory) + + class AuthorizedStorageAccountFactory(DjangoModelFactory): class Meta: model = db.AuthorizedStorageAccount @@ -112,7 +133,9 @@ def _create( account = super()._create( model_class=model_class, external_storage_service=external_storage_service - or ExternalStorageServiceFactory(credentials_format=credentials_format), + or ExternalStorageOAuth2ServiceFactory( + credentials_format=credentials_format + ), account_owner=account_owner or UserReferenceFactory(), *args, **kwargs, diff --git a/addon_service/tests/_helpers.py b/addon_service/tests/_helpers.py index 920f9830..2c8dfd59 100644 --- a/addon_service/tests/_helpers.py +++ b/addon_service/tests/_helpers.py @@ -3,8 +3,14 @@ import secrets from collections import defaultdict from http import HTTPStatus -from typing import Any -from unittest.mock import patch +from typing import ( + TYPE_CHECKING, + Any, +) +from unittest.mock import ( + AsyncMock, + patch, +) from urllib.parse import ( parse_qs, urlparse, @@ -20,6 +26,10 @@ from addon_service.common.aiohttp_session import get_singleton_client_session +if TYPE_CHECKING: + from addon_service.external_storage_service import ExternalStorageService + + class MockOSF: _configured_caller_uri: str | None = None _permissions: dict[str, dict[str, str | bool]] @@ -45,13 +55,17 @@ def __init__(self, permissions=None): @contextlib.contextmanager def mocking(self): - with patch( - "addon_service.authentication.GVCombinedAuthentication.authenticate", - side_effect=self._mock_user_check, - ), patch( - "addon_service.common.osf.has_osf_permission_on_resource", - side_effect=self._mock_resource_check, - ), patch_encryption_key_derivation(): + with ( + patch( + "addon_service.authentication.GVCombinedAuthentication.authenticate", + side_effect=self._mock_user_check, + ), + patch( + "addon_service.common.osf.has_osf_permission_on_resource", + side_effect=self._mock_resource_check, + ), + patch_encryption_key_derivation(), + ): yield self def configure_assumed_caller(self, caller_uri): @@ -100,7 +114,7 @@ def _mock_resource_check(self, request, uri, required_permission, *args, **kwarg return bool(required_permission.lower() in permissions) -class MockExternalService: +class MockOAuth2ExternalService: def __init__(self, external_service): self._static_access_token = None self._static_refresh_token = None @@ -159,6 +173,70 @@ async def _route_post(self, url, *args, **kwargs): raise RuntimeError(f"Received unrecognized endpoint {url}") +@dataclasses.dataclass +class MockOAuth1ServiceProvider: + _external_service: "ExternalStorageService" + _static_request_token: str + _static_request_secret: str + _static_verifier: str + _static_oauth_token: str + _static_oauth_secret: str + + def __post_init__(self): + if self._external_service.oauth1_client_config is not None: + self._access_token_url = ( + self._external_service.oauth1_client_config.access_token_url + ) + self._request_token_url = ( + self._external_service.oauth1_client_config.request_token_url + ) + + @property + def auth_url(self): + return self._external_service.auth_url + + def set_internal_client(self, client): + """Attach a DRF APIClient for making requests internally""" + self._internal_client = client + + @contextlib.contextmanager + def mocking(self): + with patch( + "addon_service.oauth1.utils.get_singleton_client_session", + AsyncMock(return_value=AsyncMock(post=self._route_post)), + ): + yield self + + def initiate_oauth_exchange(self): + self._internal_client.get( + reverse("oauth1-callback"), + {"oauth_token": "oauth_token", "oauth_verifier": "oauth_verifier"}, + ) + return _FakeAiohttpResponse() + + @contextlib.asynccontextmanager + async def _route_post(self, url, *args, **kwargs): + if url.startswith(self._access_token_url): + yield _FakeAiohttpResponse( + status=HTTPStatus.CREATED, + data={ + "oauth_token": self._static_oauth_token, + "oauth_token_secret": self._static_oauth_secret, + }, + ) + elif url.startswith(self._request_token_url): + yield _FakeAiohttpResponse( + status=HTTPStatus.CREATED, + data={ + "oauth_token": self._static_request_token, + "oauth_token_secret": self._static_request_secret, + "oauth_verifier": self._static_verifier, + }, + ) + else: + raise RuntimeError(f"Received unrecognized endpoint {url}") + + @dataclasses.dataclass class _FakeAiohttpResponse: status: HTTPStatus = HTTPStatus.OK diff --git a/addon_service/tests/e2e_tests/test_oauth_flow.py b/addon_service/tests/e2e_tests/test_oauth_flow.py index b2dd8678..c0beda6f 100644 --- a/addon_service/tests/e2e_tests/test_oauth_flow.py +++ b/addon_service/tests/e2e_tests/test_oauth_flow.py @@ -1,3 +1,5 @@ +import secrets + from asgiref.sync import async_to_sync from django.conf import settings from django.urls import reverse @@ -16,10 +18,6 @@ from addon_toolkit import AddonCapabilities -MOCK_ACCESS_TOKEN = "access" -MOCK_REFRESH_TOKEN = "refresh" - - def _make_post_payload(*, external_service, capabilities=None, credentials=None): return { "data": { @@ -40,17 +38,20 @@ def _make_post_payload(*, external_service, capabilities=None, credentials=None) class TestOAuth2Flow(APITestCase): + MOCK_ACCESS_TOKEN = "access" + MOCK_REFRESH_TOKEN = "refresh" + @classmethod def setUpTestData(cls): cls._user = _factories.UserReferenceFactory() - cls._service = _factories.ExternalStorageServiceFactory() + cls._service = _factories.ExternalStorageOAuth2ServiceFactory() def setUp(self): super().setUp() self.addCleanup(close_singleton_client_session__blocking) - self._mock_service = _helpers.MockExternalService(self._service) + self._mock_service = _helpers.MockOAuth2ExternalService(self._service) self._mock_service.configure_static_tokens( - access=MOCK_ACCESS_TOKEN, refresh=MOCK_REFRESH_TOKEN + access=self.MOCK_ACCESS_TOKEN, refresh=self.MOCK_REFRESH_TOKEN ) self.client.cookies[settings.USER_REFERENCE_COOKIE] = self._user.user_uri @@ -82,13 +83,77 @@ def test_oauth_account_setup(self): _account.refresh_from_db() with self.subTest("Credentials set post-exchange"): - self.assertEqual(_account.credentials.access_token, MOCK_ACCESS_TOKEN) + self.assertEqual(_account.credentials.access_token, self.MOCK_ACCESS_TOKEN) with self.subTest("Refresh token set post-exchange"): self.assertEqual( - _account.oauth2_token_metadata.refresh_token, MOCK_REFRESH_TOKEN + _account.oauth2_token_metadata.refresh_token, self.MOCK_REFRESH_TOKEN ) async def _get(self, _account: AuthorizedStorageAccount): aiohttp_client_session = await get_singleton_client_session() async with self._mock_service.mocking(): return await aiohttp_client_session.get(_account.auth_url) + + +class TestOAuth1AFlow(APITestCase): + MOCK_REQUEST_TOKEN = secrets.token_hex(12) + MOCK_REQUEST_TOKEN_SECRET = secrets.token_hex(12) + MOCK_ACCESS_TOKEN = secrets.token_hex(12) + MOCK_ACCESS_TOKEN_SECRET = secrets.token_hex(12) + MOCK_VERIFIER = secrets.token_hex(12) + + @classmethod + def setUpTestData(cls): + cls._user = _factories.UserReferenceFactory() + cls._service = _factories.ExternalStorageOAuth1ServiceFactory() + + def setUp(self): + super().setUp() + self.addCleanup(close_singleton_client_session__blocking) + self._mock_service = _helpers.MockOAuth1ServiceProvider( + _external_service=self._service, + _static_request_token=self.MOCK_REQUEST_TOKEN, + _static_request_secret=self.MOCK_REQUEST_TOKEN_SECRET, + _static_oauth_token=self.MOCK_ACCESS_TOKEN, + _static_oauth_secret=self.MOCK_ACCESS_TOKEN_SECRET, + _static_verifier=self.MOCK_VERIFIER, + ) + + self.client.cookies[settings.USER_REFERENCE_COOKIE] = self._user.user_uri + self._mock_osf = _helpers.MockOSF() + self.enterContext(self._mock_osf.mocking()) + self.enterContext(self._mock_service.mocking()) + + def test_oauth_account_setup(self): + with self.subTest("Preconditions"): + self.assertEqual( + self._service.credentials_format, CredentialsFormats.OAUTH1A + ) + + # Set up Account + _resp = self.client.post( + reverse("authorized-storage-accounts-list"), + _make_post_payload(external_service=self._service), + format="vnd.api+json", + ) + _account = AuthorizedStorageAccount.objects.get(id=_resp.data["id"]) + + with self.subTest("Account Initial Conditions"): + self.assertIsNotNone(_account.temporary_oauth1_credentials) + self.assertIsNone(_account.credentials) + + self._mock_service.set_internal_client(self.client) + self._mock_service.initiate_oauth_exchange() + + _account.refresh_from_db() + with self.subTest("Credentials set post-exchange"): + + self.assertIsNone(_account.temporary_oauth1_credentials) + self.assertEqual(_account.credentials.oauth_token, self.MOCK_ACCESS_TOKEN) + self.assertEqual( + _account.credentials.oauth_token_secret, self.MOCK_ACCESS_TOKEN_SECRET + ) + + async def _get(self, _account: AuthorizedStorageAccount): + aiohttp_client_session = await get_singleton_client_session() + return await aiohttp_client_session.get(_account.auth_url) diff --git a/addon_service/tests/test_by_type/test_authorized_storage_account.py b/addon_service/tests/test_by_type/test_authorized_storage_account.py index 7076e2f6..ef68ec80 100644 --- a/addon_service/tests/test_by_type/test_authorized_storage_account.py +++ b/addon_service/tests/test_by_type/test_authorized_storage_account.py @@ -28,7 +28,10 @@ VALID_CREDENTIALS_FORMATS = set(CredentialsFormats) - {CredentialsFormats.UNSPECIFIED} -NON_OAUTH_FORMATS = VALID_CREDENTIALS_FORMATS - {CredentialsFormats.OAUTH2} +NON_OAUTH_FORMATS = VALID_CREDENTIALS_FORMATS - { + CredentialsFormats.OAUTH2, + CredentialsFormats.OAUTH1A, +} MOCK_CREDENTIALS = { CredentialsFormats.OAUTH2: None, @@ -114,7 +117,7 @@ def test_get_detail(self): ) def test_post(self): - external_service = _factories.ExternalStorageServiceFactory() + external_service = _factories.ExternalStorageOAuth2ServiceFactory() self.assertFalse(external_service.authorized_storage_accounts.exists()) _resp = self.client.post( @@ -132,7 +135,7 @@ def test_post(self): def test_post__sets_credentials(self): for creds_format in NON_OAUTH_FORMATS: - external_service = _factories.ExternalStorageServiceFactory() + external_service = _factories.ExternalStorageOAuth2ServiceFactory() external_service.int_credentials_format = creds_format.value external_service.save() @@ -151,7 +154,7 @@ def test_post__sets_credentials(self): ) def test_post__sets_auth_url(self): - external_service = _factories.ExternalStorageServiceFactory( + external_service = _factories.ExternalStorageOAuth2ServiceFactory( credentials_format=CredentialsFormats.OAUTH2 ) @@ -167,7 +170,7 @@ def test_post__sets_auth_url(self): def tet_post__does_not_set_auth_url(self): for creds_format in NON_OAUTH_FORMATS: with self.subTest(creds_format=creds_format): - external_service = _factories.ExternalStorageServiceFactory( + external_service = _factories.ExternalStorageOAuth2ServiceFactory( credentials_format=creds_format ) @@ -186,7 +189,7 @@ def test_post__api_base_url__success(self): ServiceTypes.PUBLIC | ServiceTypes.HOSTED, ]: with self.subTest(service_type=service_type): - service = _factories.ExternalStorageServiceFactory( + service = _factories.ExternalStorageOAuth2ServiceFactory( service_type=service_type ) _resp = self.client.post( @@ -205,7 +208,7 @@ def test_post__api_base_url__success(self): self.assertTrue(account._api_base_url) def test_post__api_base_url__invalid__required(self): - service = _factories.ExternalStorageServiceFactory( + service = _factories.ExternalStorageOAuth2ServiceFactory( service_type=ServiceTypes.HOSTED ) service.api_base_url = "" @@ -219,7 +222,7 @@ def test_post__api_base_url__invalid__required(self): self.assertEqual(_resp.status_code, 400) def test_post__api_base_url__invalid__unsupported(self): - service = _factories.ExternalStorageServiceFactory( + service = _factories.ExternalStorageOAuth2ServiceFactory( service_type=ServiceTypes.PUBLIC ) _resp = self.client.post( @@ -232,7 +235,7 @@ def test_post__api_base_url__invalid__unsupported(self): self.assertEqual(_resp.status_code, 400) def test_post__api_base_url__invalid__bad_url(self): - service = _factories.ExternalStorageServiceFactory( + service = _factories.ExternalStorageOAuth2ServiceFactory( service_type=ServiceTypes.HOSTED ) _resp = self.client.post( @@ -355,7 +358,7 @@ def test_auth_url__no_active_state_token(self): def test_initiate_oauth2_flow(self): account = db.AuthorizedStorageAccount.objects.create( - external_storage_service=_factories.ExternalStorageServiceFactory( + external_storage_service=_factories.ExternalStorageOAuth2ServiceFactory( credentials_format=CredentialsFormats.OAUTH2 ), account_owner=self._user, @@ -384,7 +387,7 @@ def test_set_credentials__oauth__fails_if_state_token_exists(self): def test_set_credentials__create(self): for creds_format in NON_OAUTH_FORMATS: - external_service = _factories.ExternalStorageServiceFactory( + external_service = _factories.ExternalStorageOAuth2ServiceFactory( credentials_format=creds_format ) account = db.AuthorizedStorageAccount( diff --git a/addon_service/tests/test_by_type/test_external_storage_service.py b/addon_service/tests/test_by_type/test_external_storage_service.py index abaeac96..24d128b6 100644 --- a/addon_service/tests/test_by_type/test_external_storage_service.py +++ b/addon_service/tests/test_by_type/test_external_storage_service.py @@ -15,7 +15,7 @@ class TestExternalStorageServiceAPI(APITestCase): @classmethod def setUpTestData(cls): - cls._ess = _factories.ExternalStorageServiceFactory() + cls._ess = _factories.ExternalStorageOAuth2ServiceFactory() @property def _detail_path(self): @@ -58,7 +58,7 @@ def test_methods_not_allowed(self): class TestExternalStorageServiceModel(TestCase): @classmethod def setUpTestData(cls): - cls._ess = _factories.ExternalStorageServiceFactory() + cls._ess = _factories.ExternalStorageOAuth2ServiceFactory() def test_can_load(self): _resource_from_db = db.ExternalStorageService.objects.get(id=self._ess.id) @@ -83,19 +83,19 @@ def test_authorized_storage_accounts__several(self): ) def test_validation__invalid_format(self): - service = _factories.ExternalStorageServiceFactory() + service = _factories.ExternalStorageOAuth2ServiceFactory() service.int_credentials_format = -1 with self.assertRaises(ValidationError): service.save() def test_validation__unsupported_format(self): - service = _factories.ExternalStorageServiceFactory() + service = _factories.ExternalStorageOAuth2ServiceFactory() service.int_credentials_format = CredentialsFormats.UNSPECIFIED.value with self.assertRaises(ValidationError): service.save() def test_validation__oauth_creds_require_client_config(self): - service = _factories.ExternalStorageServiceFactory( + service = _factories.ExternalStorageOAuth2ServiceFactory( credentials_format=CredentialsFormats.OAUTH2 ) service.oauth2_client_config = None @@ -107,7 +107,7 @@ def test_validation__oauth_creds_require_client_config(self): class TestExternalStorageServiceViewSet(TestCase): @classmethod def setUpTestData(cls): - cls._ess = _factories.ExternalStorageServiceFactory() + cls._ess = _factories.ExternalStorageOAuth2ServiceFactory() cls._view = ExternalStorageServiceViewSet.as_view({"get": "retrieve"}) cls._user = _factories.UserReferenceFactory() @@ -158,7 +158,7 @@ def test_wrong_user(self): class TestExternalStorageServiceRelatedView(TestCase): @classmethod def setUpTestData(cls): - cls._ess = _factories.ExternalStorageServiceFactory() + cls._ess = _factories.ExternalStorageOAuth2ServiceFactory() cls._related_view = ExternalStorageServiceViewSet.as_view( {"get": "retrieve_related"}, ) diff --git a/addon_service/tests/test_hmac_api_auth.py b/addon_service/tests/test_hmac_api_auth.py index 2a1f1b2e..b44ece1d 100644 --- a/addon_service/tests/test_hmac_api_auth.py +++ b/addon_service/tests/test_hmac_api_auth.py @@ -171,7 +171,7 @@ class TestHmacApiAuth(APITestCase): def setUpTestData(cls): cls._user = _factories.UserReferenceFactory() cls._resource = _factories.ResourceReferenceFactory() - cls._service = _factories.ExternalStorageServiceFactory() + cls._service = _factories.ExternalStorageOAuth2ServiceFactory() cls._account = _factories.AuthorizedStorageAccountFactory( account_owner=cls._user, external_storage_service=cls._service, diff --git a/addon_service/urls.py b/addon_service/urls.py index f34ad49f..04484cd4 100644 --- a/addon_service/urls.py +++ b/addon_service/urls.py @@ -64,6 +64,7 @@ def _register_viewset(viewset): urlpatterns = [ *_router.urls, - path(r"oauth/callback/", views.oauth2_callback_view, name="oauth2-callback"), + path(r"oauth2/callback/", views.oauth2_callback_view, name="oauth2-callback"), + path(r"oauth1/callback/", views.oauth1_callback_view, name="oauth1-callback"), path(r"status/", views.status, name="status"), ] diff --git a/addon_service/views.py b/addon_service/views.py index 0a0893fa..dad5ab79 100644 --- a/addon_service/views.py +++ b/addon_service/views.py @@ -15,7 +15,8 @@ ) from addon_service.configured_storage_addon.views import ConfiguredStorageAddonViewSet from addon_service.external_storage_service.views import ExternalStorageServiceViewSet -from addon_service.oauth.views import oauth2_callback_view +from addon_service.oauth1.views import oauth1_callback_view +from addon_service.oauth2.views import oauth2_callback_view from addon_service.resource_reference.views import ResourceReferenceViewSet from addon_service.user_reference.views import UserReferenceViewSet @@ -38,5 +39,6 @@ async def status(request): "ResourceReferenceViewSet", "UserReferenceViewSet", "oauth2_callback_view", + "oauth1_callback_view", "status", ) diff --git a/addon_toolkit/constrained_network/http.py b/addon_toolkit/constrained_network/http.py index 1ad34ffd..7a9b8ca7 100644 --- a/addon_toolkit/constrained_network/http.py +++ b/addon_toolkit/constrained_network/http.py @@ -61,7 +61,7 @@ class HttpRequestor(typing.Protocol): def response_info_cls(self) -> type[HttpResponseInfo]: ... # abstract method for subclasses - def do_send( + def _do_send( self, request: HttpRequestInfo ) -> contextlib.AbstractAsyncContextManager[HttpResponseInfo]: ... @@ -79,7 +79,7 @@ async def request( query=(query if isinstance(query, Multidict) else Multidict(query)), headers=(headers if isinstance(headers, Multidict) else Multidict(headers)), ) - async with self.do_send(_request_info) as _response: + async with self._do_send(_request_info) as _response: yield _response # TODO: streaming send/receive (only if/when needed) diff --git a/addon_toolkit/credentials.py b/addon_toolkit/credentials.py index c1f78bf9..4b89b40c 100644 --- a/addon_toolkit/credentials.py +++ b/addon_toolkit/credentials.py @@ -18,7 +18,7 @@ class AccessTokenCredentials(Credentials): access_token: str def iter_headers(self): - yield ("Authorization", f"Bearer {self.access_token}") + yield "Authorization", f"Bearer {self.access_token}" @dataclasses.dataclass(frozen=True, kw_only=True) @@ -27,6 +27,36 @@ class AccessKeySecretKeyCredentials(Credentials): secret_key: str +@dataclasses.dataclass(frozen=True, slots=True) +class OAuth1Credentials(Credentials): + oauth_token: str + oauth_token_secret: str + + def iter_headers(self) -> typing.Iterator[tuple[str, str]]: + """ + This is Zotero specific as other OAuth1.0a clients require request signing, + as per current architecture, we cannot it here. + """ + + yield "Authorization", f"Bearer {self.oauth_token_secret}" + # TODO: implement request signing for OAuth1.0a services that require it + + @classmethod + def from_dict(cls, payload: dict) -> "tuple[OAuth1Credentials, dict]": + """ + This method returns credentials constructed dict and dict with other attributes, + which may contain provider-specific useful info + """ + + return ( + OAuth1Credentials( + oauth_token=payload.pop("oauth_token"), + oauth_token_secret=payload.pop("oauth_token_secret"), + ), + payload, + ) + + @dataclasses.dataclass(frozen=True, kw_only=True) class UsernamePasswordCredentials(Credentials): username: str diff --git a/addon_toolkit/imp.py b/addon_toolkit/imp.py index 3373a1ea..e4b64ea7 100644 --- a/addon_toolkit/imp.py +++ b/addon_toolkit/imp.py @@ -10,7 +10,6 @@ from . import exceptions from .addon_operation_declaration import AddonOperationDeclaration from .capabilities import AddonCapabilities -from .constrained_network import HttpRequestor from .interfaces import AddonInterface from .json_arguments import kwargs_from_json @@ -79,11 +78,6 @@ def get_operation_declaration( raise exceptions.OperationNotImplemented(cls, _operation) return _operation - @classmethod - async def get_external_account_id(cls, network: HttpRequestor) -> str: - """to be implemented by addons which require an external account id""" - return "" - ### # instance methods @@ -99,3 +93,7 @@ async def invoke_operation( return _result invoke_operation__blocking = async_to_sync(invoke_operation) + + async def get_external_account_id(self, auth_result_extras: dict[str, str]) -> str: + """to be implemented by addons which require an external account id""" + return ""