Skip to content

Commit

Permalink
Merge pull request #470 from DagsHub/enhancement/user-api
Browse files Browse the repository at this point in the history
Add a UserAPI class
  • Loading branch information
kbolashev authored Apr 25, 2024
2 parents 0143c86 + 9ec07bf commit c4c82bc
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 23 deletions.
18 changes: 0 additions & 18 deletions dagshub/auth/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,21 +407,3 @@ def add_oauth_token(host: Optional[str] = None, referrer: Optional[str] = None,
host = host or config.host
token = oauth.oauth_flow(host, referrer=referrer)
_get_token_storage(**kwargs).add_token(token, host, skip_validation=True)


def get_user_of_token(token: Union[str, DagshubTokenABC], host: Optional[str] = None) -> str:
"""
Returns the username of the user with the token
"""
host = host or config.host
check_url = multi_urljoin(host, "api/v1/user")
if type(token) is str:
auth = HTTPBearerAuth(token)
else:
auth = token
resp = http_request("GET", check_url, auth=auth)

if resp.status_code == 200:
return resp.json()["login"]
else:
raise RuntimeError(f"Got HTTP status {resp.status_code} while trying to get user: {resp.content}")
5 changes: 2 additions & 3 deletions dagshub/colab/login.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from dagshub.auth import add_oauth_token, get_token
from dagshub.auth.tokens import get_user_of_token
from dagshub.common.api import RepoAPI
from dagshub.common.api import RepoAPI, UserAPI
from dagshub.common.api.repo import RepoNotFoundError
from dagshub.upload import create_repo

Expand All @@ -20,7 +19,7 @@ def login() -> str:
add_oauth_token(referrer="colab")
token = get_token()

username = get_user_of_token(token)
username = UserAPI.get_user_from_token(token).username

colab_repo = RepoAPI(f"{username}/{COLAB_REPO_NAME}")
try:
Expand Down
5 changes: 3 additions & 2 deletions dagshub/common/api/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from dagshub.common.api.repo import RepoAPI
from .repo import RepoAPI
from .user import UserAPI

__all__ = [RepoAPI.__name__]
__all__ = [RepoAPI.__name__, UserAPI.__name__]
75 changes: 75 additions & 0 deletions dagshub/common/api/user.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from functools import cached_property
from typing import Optional, Any, Union

import dacite

from dagshub.auth import get_authenticator, get_token
from dagshub.auth.token_auth import HTTPBearerAuth
from dagshub.common import config
from dagshub.common.api.responses import UserAPIResponse
from dagshub.common.helpers import http_request
from dagshub.common.util import multi_urljoin


class UserNotFoundError(Exception):
pass


class UserAPI:
def __init__(self, user: Union[str, UserAPIResponse], host: Optional[str] = None, auth: Optional[Any] = None):
self._user_info: Optional[UserAPIResponse] = None
if isinstance(user, UserAPIResponse):
self._user_info = user
self._username = user.username
else:
self._username = user
self.host = host if host is not None else config.host

if auth is None:
self.auth = get_authenticator(host=host)
else:
self.auth = auth

@staticmethod
def get_user_from_token(token_or_authenticator: Union[str, Any], host: Optional[str] = None) -> "UserAPI":
if host is None:
host = config.host
user_url = multi_urljoin(host, "api/v1/user")
if isinstance(token_or_authenticator, str):
auth = HTTPBearerAuth(token_or_authenticator)
else:
auth = token_or_authenticator
resp = http_request("GET", user_url, auth=auth)
if resp.status_code == 404:
raise UserNotFoundError
if resp.status_code != 200:
raise RuntimeError(f"Got HTTP status {resp.status_code} while trying to get user: {resp.content}")
user_info = dacite.from_dict(UserAPIResponse, resp.json())
return UserAPI(user=user_info, host=host, auth=auth)

@staticmethod
def get_current_user(host: Optional[str] = None) -> "UserAPI":
return UserAPI.get_user_from_token(get_token(host=host), host=host)

@property
def username(self) -> str:
return self.user_info.username

@property
def user_id(self) -> int:
return self.user_info.id

@cached_property
def user_info(self) -> UserAPIResponse:
if self._user_info is not None:
return self._user_info
user_url = multi_urljoin(self.host, "api/v1/users", self._username)
resp = http_request("GET", user_url, auth=self.auth)
if resp.status_code == 404:
raise UserNotFoundError
if resp.status_code != 200:
raise RuntimeError(
f"Got HTTP status {resp.status_code} while trying to get user {self._username}: {resp.content}"
)
user_info = dacite.from_dict(UserAPIResponse, resp.json())
return user_info

0 comments on commit c4c82bc

Please sign in to comment.