Skip to content

Commit

Permalink
add download_csv function
Browse files Browse the repository at this point in the history
  • Loading branch information
bh2smith committed Sep 6, 2023
1 parent b84dd2c commit f3a3bf3
Show file tree
Hide file tree
Showing 7 changed files with 83 additions and 28 deletions.
2 changes: 0 additions & 2 deletions dune_client/api/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,6 @@ def get_execution_results_csv(self, job_id: str) -> ExecutionResultCSV:
if you need metadata information use get_results() or get_status()
"""
route = f"/execution/{job_id}/results/csv"
url = self._route_url(f"/execution/{job_id}/results/csv")
self.logger.debug(f"GET CSV received input url={url}")
response = self._get(route=route, raw=True)
response.raise_for_status()
return ExecutionResultCSV(data=BytesIO(response.content))
Expand Down
31 changes: 18 additions & 13 deletions dune_client/api/extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""

import time
from io import BytesIO
from typing import Union, Optional, Any

from deprecated import deprecated
Expand All @@ -15,7 +16,7 @@
QueryFailed,
ExecutionResultCSV,
)
from dune_client.query import QueryBase
from dune_client.query import QueryBase, parse_query_object_or_id


class ExtendedAPI(ExecutionAPI):
Expand Down Expand Up @@ -78,7 +79,7 @@ def run_query_csv(
job_id = self._refresh(
query, ping_frequency=ping_frequency, performance=performance
)
return self.get_result_csv(job_id)
return self.get_execution_results_csv(job_id)

def run_query_dataframe(
self, query: QueryBase, performance: Optional[str] = None
Expand All @@ -102,19 +103,11 @@ def get_latest_result(self, query: Union[QueryBase, str, int]) -> ResultsRespons
"""
GET the latest results for a query_id without having to execute the query again.
:param query: :class:`Query` object OR query id as string | int
:param query: :class:`Query` object OR query id as string or int
https://dune.com/docs/api/api-reference/latest_results/
https://dune.com/docs/api/api-reference/get-results/latest-results
"""
if isinstance(query, QueryBase):
params = {
f"params.{p.key}": p.to_dict()["value"] for p in query.parameters()
}
query_id = query.query_id
else:
params = None
query_id = int(query)

params, query_id = parse_query_object_or_id(query)
response_json = self._get(
route=f"/query/{query_id}/results",
params=params,
Expand Down Expand Up @@ -187,3 +180,15 @@ def upload_csv(self, table_name: str, data: str, description: str = "") -> bool:
return bool(response_json["success"])
except KeyError as err:
raise DuneError(response_json, "upload_csv response", err) from err

def download_csv(self, query: QueryBase | str | int) -> ExecutionResultCSV:
"""
Almost like an alias for `get_latest_results` but for the csv endpoint.
https://dune.com/docs/api/api-reference/get-results/latest-results
"""
params, query_id = parse_query_object_or_id(query)
response = self._get(
route=f"/query/{query_id}/results/csv", params=params, raw=True
)
response.raise_for_status()
return ExecutionResultCSV(data=BytesIO(response.content))
12 changes: 2 additions & 10 deletions dune_client/client_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
ExecutionState,
)

from dune_client.query import QueryBase
from dune_client.query import QueryBase, parse_query_object_or_id


# pylint: disable=duplicate-code
Expand Down Expand Up @@ -181,15 +181,7 @@ async def get_latest_result(
https://dune.com/docs/api/api-reference/latest_results/
"""
if isinstance(query, QueryBase):
params = {
f"params.{p.key}": p.to_dict()["value"] for p in query.parameters()
}
query_id = query.query_id
else:
params = None
query_id = int(query)

params, query_id = parse_query_object_or_id(query)
response_json = await self._get(
route=f"/query/{query_id}/results",
params=params,
Expand Down
2 changes: 1 addition & 1 deletion dune_client/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ class ExecutionResultCSV:
Representation of a raw `result` in CSV format
this payload can be passed directly to
csv.reader(data) or
pandas.from_csv(data)
pandas.read_csv(data)
"""

data: BytesIO # includes all CSV rows, including the header row.
Expand Down
16 changes: 16 additions & 0 deletions dune_client/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,22 @@
from dune_client.types import QueryParameter


def parse_query_object_or_id(
query: QueryBase | str | int,
) -> tuple[dict[str, str] | None, int]:
"""
Users are allowed to pass QueryBase or ID into some functions.
This method handles both scenarios, returning a pair of the form (params, query_id)
"""
if isinstance(query, QueryBase):
params = {f"params.{p.key}": p.to_dict()["value"] for p in query.parameters()}
query_id = query.query_id
else:
params = None
query_id = int(query)
return params, query_id


@dataclass
class QueryBase:
"""Basic data structure constituting a Dune Analytics Query."""
Expand Down
20 changes: 20 additions & 0 deletions tests/e2e/test_client.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import copy
import csv
import os
import time
import unittest

import dotenv
import pandas

from dune_client.models import (
ExecutionState,
Expand Down Expand Up @@ -186,6 +188,24 @@ def test_upload_csv_success(self):
True,
)

def test_download_csv_success_by_id(self):
client = DuneClient(self.valid_api_key)
result_csv = client.download_csv(self.query.query_id)
self.assertEqual(
pandas.read_csv(result_csv.data).to_csv(index=False),
"text_field,number_field,date_field,list_field\n"
"Word,3.1415926535,2022-05-04 00:00:00.000,Option 1\n",
)

def test_download_csv_success_with_params(self):
client = DuneClient(self.valid_api_key)
result_csv = client.download_csv(self.query)
self.assertEqual(
pandas.read_csv(result_csv.data).to_csv(index=False),
"text_field,number_field,date_field,list_field\n"
"Plain Text,3.1415926535,2022-05-04 00:00:00.000,Option 1\n",
)


@unittest.skip("This is an enterprise only endpoint that can no longer be tested.")
class TestCRUDOps(unittest.TestCase):
Expand Down
28 changes: 26 additions & 2 deletions tests/unit/test_query.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import unittest
from datetime import datetime

from dune_client.query import QueryBase
from dune_client.query import QueryBase, parse_query_object_or_id
from dune_client.types import QueryParameter


class TestQueryMonitor(unittest.TestCase):
class TestQueryBase(unittest.TestCase):
def setUp(self) -> None:
self.date = datetime(year=1985, month=3, day=10)
self.query_params = [
Expand Down Expand Up @@ -60,6 +60,30 @@ def test_hash(self):
query2 = QueryBase(query_id=1, params=[QueryParameter.number_type("num", 1)])
self.assertNotEqual(hash(query1), hash(query2))

def test_parse_object_or_id(self):
expected_params = {
"params.Date": "2021-01-01 12:34:56",
"params.Enum": "option1",
"params.Number": "12",
"params.Text": "plain text",
}
expected_query_id = self.query.query_id
# Query Object
self.assertEqual(
parse_query_object_or_id(self.query), (expected_params, expected_query_id)
)
# Query ID (integer)
expected_params = None
self.assertEqual(
parse_query_object_or_id(self.query.query_id),
(expected_params, expected_query_id),
)
# Query ID (string)
self.assertEqual(
parse_query_object_or_id(str(self.query.query_id)),
(expected_params, expected_query_id),
)


if __name__ == "__main__":
unittest.main()

0 comments on commit f3a3bf3

Please sign in to comment.