Skip to content

Commit

Permalink
Default DB + parsing (#477)
Browse files Browse the repository at this point in the history
  • Loading branch information
AsafMah authored May 9, 2023
1 parent 6b9250e commit 0286495
Show file tree
Hide file tree
Showing 12 changed files with 234 additions and 74 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,14 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## Unreleased
### Added
- Added Initial Catalog (Default Database) parameter to ConnectionStringBuilder
- Added callback parameter to device code
- Added method to manually set the cache for CloudSettings
### Changed
- URLs with a single path segment (e.g. https://test.com/abc) will now be treated as cluster and initial catalog (i.e. the cluster is "https://test.com" and the initial catalog is "abc").
* This is to align our behaviour with the .NET SDK
### Fixed
- Some edge cases in url parsing

## [4.1.4] - 2023-04-16
### Fixed
Expand Down
13 changes: 13 additions & 0 deletions azure-kusto-data/azure/kusto/data/_cloud_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ class CloudSettings:
@classmethod
@distributed_trace(name_of_span="CloudSettings.get_cloud_info", kind=SpanKind.CLIENT)
def get_cloud_info_for_cluster(cls, kusto_uri: str, proxies: Optional[Dict[str, str]] = None) -> CloudInfo:
kusto_uri = cls._normalize_uri(kusto_uri)

# tracing attributes for cloud info
KustoTracingAttributes.set_cloud_info_attributes(kusto_uri)

Expand Down Expand Up @@ -115,3 +117,14 @@ def get_cloud_info_for_cluster(cls, kusto_uri: str, proxies: Optional[Dict[str,
else:
raise KustoServiceError("Kusto returned an invalid cloud metadata response", result)
return cls._cloud_cache[kusto_uri]

@classmethod
def add_to_cache(cls, url: str, cloud_info: CloudInfo):
    """Manually register ``cloud_info`` in the class-level cloud cache.

    :param str url: Cluster URL; it is passed through ``_normalize_uri`` before being used as the cache key.
    :param CloudInfo cloud_info: Value stored for that key, replacing any existing entry.
    """
    # The cache is written only while ``_cloud_cache_lock`` is held.
    with cls._cloud_cache_lock:
        cache_key = cls._normalize_uri(url)
        cls._cloud_cache[cache_key] = cloud_info

@classmethod
def _normalize_uri(cls, kusto_uri):
if not kusto_uri.endswith("/"):
kusto_uri += "/"
return kusto_uri
26 changes: 19 additions & 7 deletions azure-kusto-data/azure/kusto/data/aio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,22 +43,24 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.close()

@aio_documented_by(KustoClientSync.execute)
async def execute(self, database: str, query: str, properties: ClientRequestProperties = None) -> KustoResponseDataSet:
async def execute(self, database: Optional[str], query: str, properties: ClientRequestProperties = None) -> KustoResponseDataSet:
query = query.strip()
if query.startswith("."):
return await self.execute_mgmt(database, query, properties)
return await self.execute_query(database, query, properties)

@distributed_trace_async(name_of_span="KustoClient.query_cmd", kind=SpanKind.CLIENT)
@aio_documented_by(KustoClientSync.execute_query)
async def execute_query(self, database: str, query: str, properties: ClientRequestProperties = None) -> KustoResponseDataSet:
async def execute_query(self, database: Optional[str], query: str, properties: ClientRequestProperties = None) -> KustoResponseDataSet:
database = self._get_database_or_default(database)
KustoTracingAttributes.set_query_attributes(self._kusto_cluster, database, properties)

return await self._execute(self._query_endpoint, database, query, None, KustoClient._query_default_timeout, properties)

@distributed_trace_async(name_of_span="KustoClient.control_cmd", kind=SpanKind.CLIENT)
@aio_documented_by(KustoClientSync.execute_mgmt)
async def execute_mgmt(self, database: str, query: str, properties: ClientRequestProperties = None) -> KustoResponseDataSet:
async def execute_mgmt(self, database: Optional[str], query: str, properties: ClientRequestProperties = None) -> KustoResponseDataSet:
database = self._get_database_or_default(database)
KustoTracingAttributes.set_query_attributes(self._kusto_cluster, database, properties)

return await self._execute(self._mgmt_endpoint, database, query, None, KustoClient._mgmt_default_timeout, properties)
Expand All @@ -67,13 +69,14 @@ async def execute_mgmt(self, database: str, query: str, properties: ClientReques
@aio_documented_by(KustoClientSync.execute_streaming_ingest)
async def execute_streaming_ingest(
self,
database: str,
database: Optional[str],
table: str,
stream: io.IOBase,
stream_format: Union[DataFormat, str],
properties: ClientRequestProperties = None,
mapping_name: str = None,
):
database = self._get_database_or_default(database)
KustoTracingAttributes.set_streaming_ingest_attributes(self._kusto_cluster, database, table, properties)

stream_format = stream_format.kusto_value if isinstance(stream_format, DataFormat) else DataFormat[stream_format.upper()].kusto_value
Expand All @@ -85,16 +88,25 @@ async def execute_streaming_ingest(

@aio_documented_by(KustoClientSync._execute_streaming_query_parsed)
async def _execute_streaming_query_parsed(
self, database: str, query: str, timeout: timedelta = _KustoClientBase._query_default_timeout, properties: Optional[ClientRequestProperties] = None
self,
database: Optional[str],
query: str,
timeout: timedelta = _KustoClientBase._query_default_timeout,
properties: Optional[ClientRequestProperties] = None,
) -> StreamingDataSetEnumerator:
response = await self._execute(self._query_endpoint, database, query, None, timeout, properties, stream_response=True)
return StreamingDataSetEnumerator(JsonTokenReader(response.content))

@distributed_trace_async(name_of_span="KustoClient.streaming_query", kind=SpanKind.CLIENT)
@aio_documented_by(KustoClientSync.execute_streaming_query)
async def execute_streaming_query(
self, database: str, query: str, timeout: timedelta = _KustoClientBase._query_default_timeout, properties: Optional[ClientRequestProperties] = None
self,
database: Optional[str],
query: str,
timeout: timedelta = _KustoClientBase._query_default_timeout,
properties: Optional[ClientRequestProperties] = None,
) -> KustoStreamingResponseDataSet:
database = self._get_database_or_default(database)
KustoTracingAttributes.set_query_attributes(self._kusto_cluster, database, properties)

response = await self._execute_streaming_query_parsed(database, query, timeout, properties)
Expand All @@ -104,7 +116,7 @@ async def execute_streaming_query(
async def _execute(
self,
endpoint: str,
database: str,
database: Optional[str],
query: Optional[str],
payload: Optional[io.IOBase],
timeout: timedelta,
Expand Down
35 changes: 23 additions & 12 deletions azure-kusto-data/azure/kusto/data/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,10 +145,10 @@ def compose_socket_options() -> List[Tuple[int, int, int]]:
else:
return []

def execute(self, database: str, query: str, properties: Optional[ClientRequestProperties] = None) -> KustoResponseDataSet:
def execute(self, database: Optional[str], query: str, properties: Optional[ClientRequestProperties] = None) -> KustoResponseDataSet:
"""
Executes a query or management command.
:param str database: Database against query will be executed.
:param Optional[str] database: Database against query will be executed. If not provided, will default to the "Initial Catalog" value in the connection string
:param str query: Query to be executed.
:param azure.kusto.data.ClientRequestProperties properties: Optional additional properties.
:return: Kusto response data set.
Expand All @@ -160,39 +160,41 @@ def execute(self, database: str, query: str, properties: Optional[ClientRequestP
return self.execute_query(database, query, properties)

@distributed_trace(name_of_span="KustoClient.query_cmd", kind=SpanKind.CLIENT)
def execute_query(self, database: str, query: str, properties: Optional[ClientRequestProperties] = None) -> KustoResponseDataSet:
def execute_query(self, database: Optional[str], query: str, properties: Optional[ClientRequestProperties] = None) -> KustoResponseDataSet:
"""
Execute a KQL query.
To learn more about KQL go to https://docs.microsoft.com/en-us/azure/kusto/query/
:param str database: Database against query will be executed.
:param Optional[str] database: Database against query will be executed. If not provided, will default to the "Initial Catalog" value in the connection string
:param str query: Query to be executed.
:param azure.kusto.data.ClientRequestProperties properties: Optional additional properties.
:return: Kusto response data set.
:rtype: azure.kusto.data.response.KustoResponseDataSet
"""
database = self._get_database_or_default(database)
KustoTracingAttributes.set_query_attributes(self._kusto_cluster, database, properties)

return self._execute(self._query_endpoint, database, query, None, self._query_default_timeout, properties)

@distributed_trace(name_of_span="KustoClient.control_cmd", kind=SpanKind.CLIENT)
def execute_mgmt(self, database: str, query: str, properties: Optional[ClientRequestProperties] = None) -> KustoResponseDataSet:
def execute_mgmt(self, database: Optional[str], query: str, properties: Optional[ClientRequestProperties] = None) -> KustoResponseDataSet:
"""
Execute a KQL control command.
To learn more about KQL control commands go to https://docs.microsoft.com/en-us/azure/kusto/management/
:param str database: Database against query will be executed.
:param Optional[str] database: Database against query will be executed. If not provided, will default to the "Initial Catalog" value in the connection string
:param str query: Query to be executed.
:param azure.kusto.data.ClientRequestProperties properties: Optional additional properties.
:return: Kusto response data set.
:rtype: azure.kusto.data.response.KustoResponseDataSet
"""
database = self._get_database_or_default(database)
KustoTracingAttributes.set_query_attributes(self._kusto_cluster, database, properties)

return self._execute(self._mgmt_endpoint, database, query, None, self._mgmt_default_timeout, properties)

@distributed_trace(name_of_span="KustoClient.streaming_ingest", kind=SpanKind.CLIENT)
def execute_streaming_ingest(
self,
database: str,
database: Optional[str],
table: str,
stream: IO[AnyStr],
stream_format: Union[DataFormat, str],
Expand All @@ -204,13 +206,14 @@ def execute_streaming_ingest(
If the Kusto service is not configured to allow streaming ingestion, this may raise an error
To learn more about streaming ingestion go to:
https://docs.microsoft.com/en-us/azure/data-explorer/ingest-data-streaming
:param str database: Target database.
:param Optional[str] database: Target database. If not provided, will default to the "Initial Catalog" value in the connection string
:param str table: Target table.
:param io.BaseIO stream: stream object which contains the data to ingest.
:param DataFormat stream_format: Format of the data in the stream.
:param ClientRequestProperties properties: additional request properties.
:param str mapping_name: Pre-defined mapping of the table. Required when stream_format is json/avro.
"""
database = self._get_database_or_default(database)
KustoTracingAttributes.set_streaming_ingest_attributes(self._kusto_cluster, database, table, properties)

stream_format = stream_format.kusto_value if isinstance(stream_format, DataFormat) else DataFormat[stream_format.upper()].kusto_value
Expand All @@ -221,21 +224,29 @@ def execute_streaming_ingest(
self._execute(endpoint, database, None, stream, self._streaming_ingest_default_timeout, properties)

def _execute_streaming_query_parsed(
self, database: str, query: str, timeout: timedelta = _KustoClientBase._query_default_timeout, properties: Optional[ClientRequestProperties] = None
self,
database: Optional[str],
query: str,
timeout: timedelta = _KustoClientBase._query_default_timeout,
properties: Optional[ClientRequestProperties] = None,
) -> StreamingDataSetEnumerator:
response = self._execute(self._query_endpoint, database, query, None, timeout, properties, stream_response=True)
response.raw.decode_content = True
return StreamingDataSetEnumerator(JsonTokenReader(response.raw))

@distributed_trace(name_of_span="KustoClient.streaming_query", kind=SpanKind.CLIENT)
def execute_streaming_query(
self, database: str, query: str, timeout: timedelta = _KustoClientBase._query_default_timeout, properties: Optional[ClientRequestProperties] = None
self,
database: Optional[str],
query: str,
timeout: timedelta = _KustoClientBase._query_default_timeout,
properties: Optional[ClientRequestProperties] = None,
) -> KustoStreamingResponseDataSet:
"""
Execute a KQL query without reading it all to memory.
The resulting KustoStreamingResponseDataSet will stream one table at a time, and the rows can be retrieved sequentially.
:param str database: Database against query will be executed.
:param Optional[str] database: Database against query will be executed. If not provided, will default to the "Initial Catalog" value in the connection string
:param str query: Query to be executed.
:param timedelta timeout: timeout for the query to be executed
:param azure.kusto.data.ClientRequestProperties properties: Optional additional properties.
Expand All @@ -248,7 +259,7 @@ def execute_streaming_query(
def _execute(
self,
endpoint: str,
database: str,
database: Optional[str],
query: Optional[str],
payload: Optional[IO[AnyStr]],
timeout: timedelta,
Expand Down
8 changes: 8 additions & 0 deletions azure-kusto-data/azure/kusto/data/client_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ def __init__(self, kcsb: Union[KustoConnectionStringBuilder, str], is_async):
# notice that in this context, federated actually just stands for aad auth, not aad federated auth (legacy code)
self._aad_helper = _AadHelper(self._kcsb, is_async) if self._kcsb.aad_federated_security else None

if not self._kusto_cluster.endswith("/"):
self._kusto_cluster += "/"

# Create a session object for connection pooling
self._mgmt_endpoint = urljoin(self._kusto_cluster, "v1/rest/mgmt")
self._query_endpoint = urljoin(self._kusto_cluster, "v2/rest/query")
Expand All @@ -58,6 +61,11 @@ def __init__(self, kcsb: Union[KustoConnectionStringBuilder, str], is_async):
self.client_details = self._kcsb.client_details
self._is_closed: bool = False

self.default_database = self._kcsb.initial_catalog

def _get_database_or_default(self, database_name: Optional[str]) -> str:
return database_name or self.default_database

def close(self):
if not self._is_closed:
if self._aad_helper is not None:
Expand Down
34 changes: 32 additions & 2 deletions azure-kusto-data/azure/kusto/data/kcsb.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from enum import unique, Enum
from typing import Union, Callable, Coroutine, Optional, Tuple, List, Any
from urllib.parse import urlparse

from ._token_providers import DeviceCallbackType
from ._string_utils import assert_string_is_not_empty
from ._token_providers import DeviceCallbackType
from .client_details import ClientDetails


Expand All @@ -13,8 +14,8 @@ class KustoConnectionStringBuilder:
https://github.com/Azure/azure-kusto-python/blob/master/azure-kusto-data/tests/sample.py
"""

DEFAULT_DATABASE_NAME = "NetDefaultDB"
device_callback: DeviceCallbackType = None

kcsb_invalid_item_error = "%s is not supported as an item in KustoConnectionStringBuilder"

@unique
Expand All @@ -25,6 +26,7 @@ class ValidKeywords(Enum):
"""

data_source = "Data Source"
initial_catalog = "Initial Catalog"
aad_federated_security = "AAD Federated Security"
aad_user_id = "AAD User ID"
password = "Password"
Expand All @@ -49,6 +51,8 @@ def parse(cls, key: str) -> "KustoConnectionStringBuilder.ValidKeywords":
key = key.lower().strip()
if key in ["data source", "addr", "address", "network address", "server"]:
return cls.data_source
if key in ["initial catalog"]:
return cls.initial_catalog
if key in ["aad user id"]:
return cls.aad_user_id
if key in ["password", "pwd"]:
Expand Down Expand Up @@ -105,6 +109,7 @@ def is_str_type(self) -> bool:
self.user_token,
self.login_hint,
self.domain_hint,
self.initial_catalog,
]

def is_dict_type(self) -> bool:
Expand Down Expand Up @@ -146,6 +151,8 @@ def __init__(self, connection_string: str):
value_stripped = value.strip()
if keyword.is_str_type():
self[keyword] = value_stripped.rstrip("/")
if keyword == self.ValidKeywords.data_source:
self._parse_data_source(self.data_source)
elif keyword.is_bool_type():
if value_stripped in ["True", "true"]:
self[keyword] = True
Expand All @@ -154,6 +161,9 @@ def __init__(self, connection_string: str):
else:
raise KeyError("Expected aad federated security to be bool. Recieved %s" % value)

if self.initial_catalog is None:
self.initial_catalog = self.DEFAULT_DATABASE_NAME

def __setitem__(self, key: "Union[KustoConnectionStringBuilder.ValidKeywords, str]", value: Union[str, bool, dict]):
try:
keyword = key if isinstance(key, self.ValidKeywords) else self.ValidKeywords.parse(key)
Expand Down Expand Up @@ -479,6 +489,17 @@ def data_source(self) -> Optional[str]:
"""
return self._internal_dict.get(self.ValidKeywords.data_source)

@property
def initial_catalog(self) -> Optional[str]:
    """The default database to be used for requests.

    Reads the "Initial Catalog" entry of the builder and yields ``None`` when no such entry is stored.
    """
    catalog_key = self.ValidKeywords.initial_catalog
    return self._internal_dict.get(catalog_key)

@initial_catalog.setter
def initial_catalog(self, value: str) -> None:
    """Store ``value`` as the "Initial Catalog" entry of the builder."""
    catalog_key = self.ValidKeywords.initial_catalog
    self._internal_dict[catalog_key] = value

@property
def aad_user_id(self) -> Optional[str]:
"""The username to use for AAD Federated AuthN."""
Expand Down Expand Up @@ -647,3 +668,12 @@ def __repr__(self) -> str:

def _build_connection_string(self, kcsb_as_dict: dict) -> str:
return ";".join(["{0}={1}".format(word.value, kcsb_as_dict[word]) for word in self.ValidKeywords if word in kcsb_as_dict])

# Inspect a "Data Source" URL and, when its path is a single segment, use that segment as the
# initial catalog (e.g. "https://test.com/abc" -> catalog "abc").
# NOTE(review): indentation is not visible in this view, so the nesting of the statements below
# (in particular whether the final data_source rewrite sits inside the `if`) must be confirmed
# against the real file.
def _parse_data_source(self, url: str):
url = urlparse(url)
# Nothing to do when the value has no network location.
if not url.netloc:
return
# Path components with the leading "/" removed; an empty path gives [""].
segments = url.path.lstrip("/").split("/")
# Only a single, non-empty segment counts, and an already-set initial catalog is never overridden.
if len(segments) == 1 and segments[0] and not self.initial_catalog:
self.initial_catalog = segments[0]
# Store the data source with its path removed.
self._internal_dict[self.ValidKeywords.data_source] = url._replace(path="").geturl()
4 changes: 2 additions & 2 deletions azure-kusto-data/tests/aio/test_async_token_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ async def test_cloud_mfa_off(self):
kusto_service_resource_id="https://fakeurl.kusto.windows.net",
first_party_authority_url="",
)
CloudSettings._cloud_cache[FAKE_URI] = cloud
CloudSettings.add_to_cache(FAKE_URI, cloud)
authority = "auth_test"

with UserPassTokenProvider(FAKE_URI, authority, "a", "b", is_async=True) as provider:
Expand All @@ -317,7 +317,7 @@ async def test_cloud_mfa_on(self):
kusto_service_resource_id="https://fakeurl.kusto.windows.net",
first_party_authority_url="",
)
CloudSettings._cloud_cache[FAKE_URI] = cloud
CloudSettings.add_to_cache(FAKE_URI, cloud)
authority = "auth_test"

with UserPassTokenProvider(FAKE_URI, authority, "a", "b", is_async=True) as provider:
Expand Down
Loading

0 comments on commit 0286495

Please sign in to comment.