Skip to content

Commit

Permalink
Added device code callback param (#475)
Browse files Browse the repository at this point in the history
  • Loading branch information
AsafMah authored May 2, 2023
1 parent cce9862 commit 6b9250e
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 60 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@ All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## Unreleased
### Added
- Added callback parameter to device code

## [4.1.4] - 2023-04-16
### Fixed
- Unicode headers are now espaced using '?', to align with the service
Expand Down
100 changes: 47 additions & 53 deletions azure-kusto-data/azure/kusto/data/_token_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,32 @@
# Licensed under the MIT License
import abc
import asyncio
import inspect
import time
import webbrowser
from datetime import datetime
from threading import Lock
from typing import Callable, Coroutine, List, Optional, Any

from azure.core.exceptions import ClientAuthenticationError
from azure.core.tracing import SpanKind
from azure.core.tracing.decorator import distributed_trace
from azure.core.tracing.decorator_async import distributed_trace_async
from azure.core.tracing import SpanKind
from azure.identity import AzureCliCredential, ManagedIdentityCredential
from azure.identity import AzureCliCredential, ManagedIdentityCredential, DeviceCodeCredential
from msal import ConfidentialClientApplication, PublicClientApplication

from ._cloud_settings import CloudInfo, CloudSettings
from ._telemetry import KustoTracing
from .exceptions import KustoAioSyntaxError, KustoAsyncUsageError, KustoClientError

DeviceCallbackType = Callable[[str, str, datetime], None]
"""A callback enabling control of how authentication
instructions are presented. Must accept arguments (``verification_uri``, ``user_code``, ``expires_on``):
- ``verification_uri`` (str) the URL the user must visit
- ``user_code`` (str) the code the user must enter there
- ``expires_on`` (datetime.datetime) the UTC time at which the code will expire
If this argument isn't provided, the credential will print instructions to stdout."""

try:
from asgiref.sync import sync_to_async
except ImportError:
Expand Down Expand Up @@ -485,55 +495,6 @@ def _get_token_from_cache_impl(self) -> dict:
return self._valid_token_or_none(token)


class DeviceLoginTokenProvider(CloudInfoTokenProvider):
"""Acquire a token from MSAL with Device Login flow"""

def __init__(self, kusto_uri: str, authority_id: str, device_code_callback=None, is_async: bool = False):
super().__init__(kusto_uri, is_async)
self._msal_client = None
self._auth = authority_id
self._account = None
self._device_code_callback = device_code_callback

@staticmethod
def name() -> str:
return "DeviceLoginTokenProvider"

def _context_impl(self) -> dict:
return {"authority": self._cloud_info.authority_uri(self._auth), "client_id": self._cloud_info.kusto_client_app_id}

def _init_impl(self):
self._msal_client = PublicClientApplication(
client_id=self._cloud_info.kusto_client_app_id, authority=self._cloud_info.authority_uri(self._auth), proxies=self._proxy_dict
)

def _get_token_impl(self) -> Optional[dict]:
flow = self._msal_client.initiate_device_flow(scopes=self._scopes)
try:
if self._device_code_callback:
self._device_code_callback(flow[TokenConstants.MSAL_DEVICE_MSG])
else:
print(flow[TokenConstants.MSAL_DEVICE_MSG])

webbrowser.open(flow[TokenConstants.MSAL_DEVICE_URI])
except KeyError:
raise KustoClientError("Failed to initiate device code flow")

token = self._msal_client.acquire_token_by_device_flow(flow)

# Keep the account for silent login
if self._valid_token_or_none(token) is not None:
accounts = self._msal_client.get_accounts()
if len(accounts) == 1:
self._account = accounts[0]

return self._valid_token_or_throw(token)

def _get_token_from_cache_impl(self) -> dict:
token = self._msal_client.acquire_token_silent(scopes=self._scopes, account=self._account)
return self._valid_token_or_none(token)


class InteractiveLoginTokenProvider(CloudInfoTokenProvider):
"""Acquire a token from MSAL with Device Login flow"""

Expand Down Expand Up @@ -694,7 +655,11 @@ def _get_token_impl(self) -> Optional[dict]:
return {TokenConstants.MSAL_TOKEN_TYPE: TokenConstants.BEARER_TYPE, TokenConstants.MSAL_ACCESS_TOKEN: t.token}

async def _get_token_impl_async(self) -> Optional[dict]:
t = await self.credential.get_token(self._scopes[0])
# check if get_token is async
if inspect.iscoroutinefunction(self.credential.get_token):
t = await self.credential.get_token(self._scopes[0])
else:
t = await sync_to_async(self.credential.get_token)(self._scopes[0])
return {TokenConstants.MSAL_TOKEN_TYPE: TokenConstants.BEARER_TYPE, TokenConstants.MSAL_ACCESS_TOKEN: t.token}

def _get_token_from_cache_impl(self) -> Optional[dict]:
Expand All @@ -708,3 +673,32 @@ def close(self):
self.credential.close()
self.credential = None
self.credential_from_login_endpoint = None


class DeviceLoginTokenProvider(AzureIdentityTokenCredentialProvider):
"""Acquire a token from MSAL with Device Login flow"""

def __init__(self, kusto_uri: str, authority_id: str, device_code_callback: DeviceCallbackType = None, is_async: bool = False):
self._msal_client = None
self._auth = authority_id
self._account = None
self._device_code_callback = device_code_callback

def credential_from_login_endpoint(endpoint: str):
cred = DeviceCodeCredential(
authority=endpoint,
tenant_id=self._auth,
client_id=self._cloud_info.kusto_client_app_id,
prompt_callback=self._device_code_callback,
)

return cred

super().__init__(kusto_uri, is_async, credential_from_login_endpoint=credential_from_login_endpoint)

@staticmethod
def name() -> str:
return "DeviceLoginTokenProvider"

def _context_impl(self) -> dict:
return {"authority": self._cloud_info.authority_uri(self._auth), "client_id": self._cloud_info.kusto_client_app_id}
12 changes: 11 additions & 1 deletion azure-kusto-data/azure/kusto/data/kcsb.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from enum import unique, Enum
from typing import Union, Callable, Coroutine, Optional, Tuple, List, Any

from ._token_providers import DeviceCallbackType
from ._string_utils import assert_string_is_not_empty
from .client_details import ClientDetails

Expand All @@ -12,6 +13,8 @@ class KustoConnectionStringBuilder:
https://github.com/Azure/azure-kusto-python/blob/master/azure-kusto-data/tests/sample.py
"""

device_callback: DeviceCallbackType = None

kcsb_invalid_item_error = "%s is not supported as an item in KustoConnectionStringBuilder"

@unique
Expand Down Expand Up @@ -312,16 +315,23 @@ def with_aad_application_token_authentication(cls, connection_string: str, appli
return kcsb

@classmethod
def with_aad_device_authentication(cls, connection_string: str, authority_id: str = "organizations") -> "KustoConnectionStringBuilder":
def with_aad_device_authentication(
cls, connection_string: str, authority_id: str = "organizations", callback: DeviceCallbackType = None
) -> "KustoConnectionStringBuilder":
"""
Creates a KustoConnection string builder that will authenticate with AAD application and
password.
:param str connection_string: Kusto connection string should be of the format: https://<clusterName>.kusto.windows.net
:param str authority_id: optional param. defaults to "organizations"
:param DeviceCallbackType callback: options callback function to be called when authentication is required, accepts three parameters:
- ``verification_uri`` (str) the URL the user must visit
- ``user_code`` (str) the code the user must enter there
- ``expires_on`` (datetime.datetime) the UTC time at which the code will expire
"""
kcsb = cls(connection_string)
kcsb[kcsb.ValidKeywords.aad_federated_security] = True
kcsb[kcsb.ValidKeywords.authority_id] = authority_id
kcsb.device_callback = callback

return kcsb

Expand Down
2 changes: 1 addition & 1 deletion azure-kusto-data/azure/kusto/data/security.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def __init__(self, kcsb: "KustoConnectionStringBuilder", is_async: bool):
credential_from_login_endpoint=kcsb.credential_from_login_endpoint,
)
else: # TODO - next breaking change - remove this as default, make no auth the default
self.token_provider = DeviceLoginTokenProvider(self.kusto_uri, kcsb.authority_id, is_async=is_async)
self.token_provider = DeviceLoginTokenProvider(self.kusto_uri, kcsb.authority_id, kcsb.device_callback, is_async=is_async)

def acquire_authorization_header(self):
try:
Expand Down
4 changes: 2 additions & 2 deletions azure-kusto-data/tests/aio/test_async_token_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,9 +211,9 @@ async def test_device_auth_provider(self):
print(" *** Skipped User Device Flow Test ***")
return

def callback(x):
def callback(x, x2, x3):
# break here if you debug this test, and get the code from 'x'
print(x)
print(f"Please go to {x} and enter code {x2} to authenticate, expires in {x3}")

with DeviceLoginTokenProvider(KUSTO_URI, "organizations", callback, is_async=True) as provider:
token = await provider.get_token_async()
Expand Down
6 changes: 3 additions & 3 deletions azure-kusto-data/tests/test_token_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from threading import Thread

from asgiref.sync import async_to_sync

from azure.identity import ClientSecretCredential, DefaultAzureCredential

from azure.kusto.data._token_providers import *

KUSTO_URI = "https://sdkse2etest.eastus.kusto.windows.net"
Expand Down Expand Up @@ -232,9 +232,9 @@ def test_device_auth_provider():
print(" *** Skipped User Device Flow Test ***")
return

def callback(x):
def callback(x, x2, x3):
# break here if you debug this test, and get the code from 'x'
print(x)
print(f"Please go to {x} and enter code {x2} to authenticate, expires in {x3}")

with DeviceLoginTokenProvider(KUSTO_URI, "organizations", callback) as provider:
token = provider.get_token()
Expand Down

0 comments on commit 6b9250e

Please sign in to comment.