Add append_data option to ParquetDataCatalog.write_data
faysou committed Oct 1, 2024
1 parent ecdcf98 commit 6441d98
Showing 4 changed files with 114 additions and 40 deletions.
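In short, ParquetDataCatalog.write_data gains an append_data flag that appends the written objects to an existing Parquet file instead of overwriting it. A minimal usage sketch follows; the catalog path and the bars / more_bars lists are placeholders assumed to be prepared elsewhere (for example via a Databento loader or the test stubs used in the test diff below):

from nautilus_trader.persistence.catalog.parquet import ParquetDataCatalog

catalog = ParquetDataCatalog("/tmp/catalog")  # placeholder path

# First write behaves as before and creates the file.
catalog.write_data(bars)

# A later call with append_data=True adds the new objects to the existing
# file for that data type instead of replacing it.
catalog.write_data(more_bars, append_data=True)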
8 changes: 5 additions & 3 deletions examples/backtest/databento_option_greeks.py
@@ -10,6 +10,8 @@
from nautilus_trader.adapters.databento.data_utils import data_path
from nautilus_trader.adapters.databento.data_utils import databento_data
from nautilus_trader.adapters.databento.data_utils import load_catalog

# from nautilus_trader.adapters.databento.data_utils import init_databento_client
from nautilus_trader.backtest.node import BacktestNode
from nautilus_trader.common.enums import LogColor
from nautilus_trader.config import BacktestDataConfig
@@ -175,8 +177,8 @@ def user_log(self, msg):
# %%
# BacktestEngineConfig

load_greeks = False
stream_data = False
# For saving and then loading the custom greeks data, run once with (False, True) below, then with (True, False); see the sketch after this file's diff.
load_greeks, stream_data = False, False

actors = [
ImportableActorConfig(
@@ -273,7 +275,7 @@ def user_log(self, msg):
engine=engine_config,
data=data,
venues=venues,
chunk_size=10_000, # use None when using load_greeks ?
chunk_size=None, # use None when loading custom data
),
]

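For clarity, the two-run workflow referred to by the load_greeks / stream_data comment above, using the same variable names; this is an illustration, not part of the commit:

# Run 1: compute the greeks during the backtest and stream them to the catalog.
load_greeks, stream_data = False, True

# Run 2: load the greeks saved during run 1 instead of recomputing them
# (with chunk_size=None when loading the custom data, as noted in the config above).
load_greeks, stream_data = True, False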
85 changes: 56 additions & 29 deletions nautilus_trader/adapters/databento/data_utils.py
@@ -64,16 +64,25 @@ def databento_cost(symbols, start_time, end_time, schema, dataset="GLBX.MDP3", **kwargs):
Calculate the cost of retrieving data from the Databento API for the given
parameters.
Args:
symbols (list[str]): The symbols to retrieve data for.
start_time (str): The start time of the data in ISO 8601 format.
end_time (str): The end time of the data in ISO 8601 format.
schema (str): The data schema to retrieve.
dataset (str, optional): The Databento dataset to use, defaults to "GLBX.MDP3".
**kwargs: Additional keyword arguments to pass to the Databento API.
Returns:
float: The estimated cost of retrieving the data.
Parameters
----------
symbols : list of str
The symbols to retrieve data for.
start_time : str
The start time of the data in ISO 8601 format.
end_time : str
The end time of the data in ISO 8601 format.
schema : str
The data schema to retrieve.
dataset : str, optional
The Databento dataset to use, defaults to "GLBX.MDP3".
**kwargs
Additional keyword arguments to pass to the Databento API.
Returns
-------
float
The estimated cost of retrieving the data.
"""
definition_start_date, definition_end_date = databento_definition_dates(start_time)
@@ -98,29 +107,46 @@ def databento_data(
dataset="GLBX.MDP3",
to_catalog=True,
base_path=None,
append_data=False,
**kwargs,
):
"""
Download and save Databento data and definition files, and optionally save the data
to a catalog.
Args:
symbols (list[str]): The symbols to retrieve data for.
start_time (str): The start time of the data in ISO 8601 format.
end_time (str): The end time of the data in ISO 8601 format.
schema (str): The data schema to retrieve, either "definition" or another valid schema.
file_prefix (str): The prefix to use for the downloaded data files.
*folders (str): Additional folders to create in the data path.
dataset (str, optional): The Databento dataset to use, defaults to "GLBX.MDP3".
to_catalog (bool, optional): Whether to save the data to a catalog, defaults to True.
base_path (str, optional): The base path to use for the data folder, defaults to None.
**kwargs: Additional keyword arguments to pass to the Databento API.
Returns:
dict: A dictionary containing the downloaded data and metadata.
Note:
If schema is equal to 'definition' then no data is downloaded or saved to the catalog.
Parameters
----------
symbols : list of str
The symbols to retrieve data for.
start_time : str
The start time of the data in ISO 8601 format.
end_time : str
The end time of the data in ISO 8601 format.
schema : str
The data schema to retrieve, either "definition" or another valid schema.
file_prefix : str
The prefix to use for the downloaded data files.
*folders : str
Additional folders to create in the data path.
dataset : str, optional
The Databento dataset to use, defaults to "GLBX.MDP3".
to_catalog : bool, optional
Whether to save the data to a catalog, defaults to True.
base_path : str, optional
The base path to use for the data folder, defaults to None.
append_data : bool, optional
Whether to append data to an existing catalog, defaults to False.
**kwargs
Additional keyword arguments to pass to the Databento API.
Returns
-------
dict
A dictionary containing the downloaded data and metadata.
Notes
-----
If schema is equal to 'definition' then no data is downloaded or saved to the catalog.
"""
used_path = create_data_folder(*folders, "databento", base_path=base_path)
@@ -185,21 +211,22 @@ def databento_data(
data_file,
*folders,
base_path=base_path,
append_data=append_data,
)
result.update(catalog_data)

return result


def save_data_to_catalog(definition_file, data_file, *folders, base_path=None):
def save_data_to_catalog(definition_file, data_file, *folders, base_path=None, append_data=False):
catalog = load_catalog(*folders, base_path=base_path)

loader = DatabentoDataLoader()
nautilus_definition = loader.from_dbn_file(definition_file, as_legacy_cython=True)
nautilus_data = loader.from_dbn_file(data_file, as_legacy_cython=False)

catalog.write_data(nautilus_definition)
catalog.write_data(nautilus_data)
catalog.write_data(nautilus_data, append_data=append_data)

return {
"catalog": catalog,
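A hedged sketch of how the new flag flows through the databento_data helper; the symbol, dates, schema, file prefix and folder name are placeholders rather than values from the commit, and a configured Databento API key is assumed:

data = databento_data(
    ["ESM4"],             # symbols
    "2024-05-09T10:00",   # start_time
    "2024-05-09T10:05",   # end_time
    "ohlcv-1m",           # schema
    "es_ohlcv",           # file_prefix
    "options_example",    # extra folder captured by *folders
    append_data=True,     # append to the existing catalog file instead of overwriting it
)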
40 changes: 32 additions & 8 deletions nautilus_trader/persistence/catalog/parquet.py
@@ -236,6 +236,7 @@ def write_chunk(
data_cls: type[Data],
instrument_id: str | None = None,
basename_template: str = "part-{i}",
append_data: bool = False,
**kwargs: Any,
) -> None:
if isinstance(data[0], CustomData):
@@ -250,6 +251,7 @@
path=path,
fs=self.fs,
basename_template=basename_template,
append_data=append_data,
)
else:
# Write parquet file
@@ -261,8 +263,7 @@
filesystem=self.fs,
min_rows_per_group=self.min_rows_per_group,
max_rows_per_group=self.max_rows_per_group,
**self.dataset_kwargs,
**kwargs,
**kw,
)

def _fast_write(
@@ -271,20 +272,40 @@ def _fast_write(
path: str,
fs: fsspec.AbstractFileSystem,
basename_template: str,
append_data: bool = False,
) -> None:
name = basename_template.format(i=0)
fs.mkdirs(path, exist_ok=True)
pq.write_table(
table,
where=f"{path}/{name}.parquet",
filesystem=fs,
row_group_size=self.max_rows_per_group,
)
parquet_file = f"{path}/{name}.parquet"

# Following the solution from https://stackoverflow.com/a/70817689
if append_data and Path(parquet_file).exists():
existing_table = pq.read_table(source=parquet_file, pre_buffer=False, memory_map=True)

with pq.ParquetWriter(
where=parquet_file,
schema=existing_table.schema,
filesystem=fs,
write_batch_size=self.max_rows_per_group,
) as pq_writer:

pq_writer.write_table(existing_table)

table = table.cast(existing_table.schema)
pq_writer.write_table(table)
else:
pq.write_table(
table,
where=parquet_file,
filesystem=fs,
row_group_size=self.max_rows_per_group,
)

def write_data(
self,
data: list[Data | Event] | list[NautilusRustDataType],
basename_template: str = "part-{i}",
append_data: bool = False,
**kwargs: Any,
) -> None:
"""
@@ -303,6 +324,8 @@ def obj_to_type(obj: Data) -> type:
The token '{i}' will be replaced with an automatically incremented
integer as files are partitioned.
If not specified, it defaults to 'part-{i}' + the default extension '.parquet'.
append_data : bool, default False
If True, appends the data to an existing file instead of overwriting it.
kwargs : Any
Additional keyword arguments to be passed to the `write_chunk` method.
@@ -352,6 +375,7 @@ def obj_to_type(obj: Data) -> type:
data_cls=name_to_cls[cls_name],
instrument_id=instrument_id,
basename_template=basename_template,
append_data=append_data,
**kwargs,
)

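The append branch in _fast_write above rewrites the whole file through a single ParquetWriter. Below is a standalone sketch of that pattern using plain pyarrow, adapted from the linked Stack Overflow answer; the file path and tables are illustrative only:

import pyarrow as pa
import pyarrow.parquet as pq

path = "/tmp/example.parquet"
day_1 = pa.table({"ts": [1, 2, 3], "price": [10.0, 10.5, 11.0]})
day_2 = pa.table({"ts": [4, 5, 6], "price": [11.2, 11.4, 11.1]})

# The initial write creates the file.
pq.write_table(day_1, where=path)

# "Append" by reading the existing rows and rewriting them together with the
# new rows, reusing the existing schema for both writes.
existing = pq.read_table(path)
with pq.ParquetWriter(where=path, schema=existing.schema) as writer:
    writer.write_table(existing)
    writer.write_table(day_2.cast(existing.schema))

assert pq.read_table(path).num_rows == 6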
21 changes: 21 additions & 0 deletions tests/unit_tests/persistence/test_catalog.py
@@ -266,6 +266,27 @@ def test_catalog_bars_querying_by_bar_type(catalog: ParquetDataCatalog) -> None:
assert len(bars) == len(stub_bars) == 10


def test_catalog_append_data(catalog: ParquetDataCatalog) -> None:
# Arrange
bar_type = TestDataStubs.bartype_adabtc_binance_1min_last()
instrument = TestInstrumentProvider.adabtc_binance()
stub_bars = TestDataStubs.binance_bars_from_csv(
"ADABTC-1m-2021-11-27.csv",
bar_type,
instrument,
)
catalog.write_data(stub_bars)

# Act
catalog.write_data(stub_bars, append_data=True)

# Assert
bars = catalog.bars(bar_types=[str(bar_type)])
all_bars = catalog.bars()
assert len(all_bars) == 20
assert len(bars) == 2 * len(stub_bars) == 20


def test_catalog_bars_querying_by_instrument_id(catalog: ParquetDataCatalog) -> None:
# Arrange
bar_type = TestDataStubs.bartype_adabtc_binance_1min_last()
