diff --git a/.ci/test-python-oldest.sh b/.ci/test-python-oldest.sh
index 09cc24633e15..3c1bcaa3884d 100644
--- a/.ci/test-python-oldest.sh
+++ b/.ci/test-python-oldest.sh
@@ -7,8 +7,10 @@
 #
 echo "installing lightgbm's dependencies"
 pip install \
+  'cffi==1.15.1' \
   'numpy==1.12.0' \
   'pandas==0.24.0' \
+  'pyarrow==12.0.0' \
   'scikit-learn==0.18.2' \
   'scipy==0.19.0' \
|| exit -1
diff --git a/.ci/test.sh b/.ci/test.sh
index b3acc4a670cf..1a1ad4611c09 100755
--- a/.ci/test.sh
+++ b/.ci/test.sh
@@ -121,6 +121,7 @@ fi
 # including python=version[build=*cpython] to ensure that conda doesn't fall back to pypy
 conda create -q -y -n $CONDA_ENV \
+    cffi \
     cloudpickle \
     dask-core \
     distributed \
@@ -129,6 +130,7 @@ conda create -q -y -n $CONDA_ENV \
     numpy \
     pandas \
     psutil \
+    pyarrow \
     pytest \
     ${CONDA_PYTHON_REQUIREMENT} \
     python-graphviz \
diff --git a/.ci/test_windows.ps1 b/.ci/test_windows.ps1
index 413af821e065..6b02aed6ce8b 100644
--- a/.ci/test_windows.ps1
+++ b/.ci/test_windows.ps1
@@ -52,12 +52,14 @@ conda install brotlipy
 conda update -q -y conda
 conda create -q -y -n $env:CONDA_ENV `
+    cffi `
     cloudpickle `
     joblib `
     matplotlib `
     numpy `
     pandas `
     psutil `
+    pyarrow `
     pytest `
     "python=$env:PYTHON_VERSION[build=*cpython]" `
     python-graphviz `
diff --git a/.gitignore b/.gitignore
index d4045d9a4798..3fb53b8f1e2a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -139,7 +139,7 @@ publish/
 # Publish Web Output
 *.[Pp]ublish.xml
 *.azurePubxml
-# TODO: Comment the next line if you want to checkin your web deploy settings 
+# TODO: Comment the next line if you want to checkin your web deploy settings
 # but database connection strings (with potential passwords) will be unencrypted
 *.pubxml
 *.publishproj
@@ -270,6 +270,7 @@ _Pvt_Extensions
 /windows/LightGBM.VC.db
 lightgbm
 /testlightgbm
+!include/LightGBM

 # Created by https://www.gitignore.io/api/python
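Note: the CI pins above (cffi 1.15.1, pyarrow 12.0.0) match the new `arrow` optional-dependency group added to python-package/pyproject.toml further down in this diff. Assuming the extra lands under the name used there, users would opt in with:

    pip install 'lightgbm[arrow]'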
diff --git a/include/LightGBM/arrow.h b/include/LightGBM/arrow.h
new file mode 100644
index 000000000000..3d1c74713bd3
--- /dev/null
+++ b/include/LightGBM/arrow.h
@@ -0,0 +1,256 @@
+/*!
+ * Copyright (c) 2023 Microsoft Corporation. All rights reserved.
+ * Licensed under the MIT License. See LICENSE file in the project root for license information.
+ *
+ * Author: Oliver Borchert
+ */
+
+#ifndef LIGHTGBM_ARROW_H_
+#define LIGHTGBM_ARROW_H_
+
+#include <algorithm>
+#include <cstdint>
+#include <functional>
+#include <iterator>
+#include <limits>
+#include <stdexcept>
+#include <utility>
+#include <vector>
+
+/* -------------------------------------- C DATA INTERFACE ------------------------------------- */
+// The C data interface is taken from
+// https://arrow.apache.org/docs/format/CDataInterface.html#structure-definitions
+// and is available under Apache License 2.0 (https://www.apache.org/licenses/LICENSE-2.0).
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+#define ARROW_FLAG_DICTIONARY_ORDERED 1
+#define ARROW_FLAG_NULLABLE 2
+#define ARROW_FLAG_MAP_KEYS_SORTED 4
+
+struct ArrowSchema {
+  // Array type description
+  const char* format;
+  const char* name;
+  const char* metadata;
+  int64_t flags;
+  int64_t n_children;
+  struct ArrowSchema** children;
+  struct ArrowSchema* dictionary;
+
+  // Release callback
+  void (*release)(struct ArrowSchema*);
+  // Opaque producer-specific data
+  void* private_data;
+};
+
+struct ArrowArray {
+  // Array data description
+  int64_t length;
+  int64_t null_count;
+  int64_t offset;
+  int64_t n_buffers;
+  int64_t n_children;
+  const void** buffers;
+  struct ArrowArray** children;
+  struct ArrowArray* dictionary;
+
+  // Release callback
+  void (*release)(struct ArrowArray*);
+  // Opaque producer-specific data
+  void* private_data;
+};
+
+#ifdef __cplusplus
+}
+#endif
+
+/* --------------------------------------------------------------------------------------------- */
+/*                                         CHUNKED ARRAY                                          */
+/* --------------------------------------------------------------------------------------------- */
+
+namespace LightGBM {
+
+/**
+ * @brief Arrow array-like container for a list of Arrow arrays.
+ */
+class ArrowChunkedArray {
+  /* List of length `n` for `n` chunks containing the individual Arrow arrays. */
+  std::vector<const ArrowArray*> chunks_;
+  /* Schema for all chunks. */
+  const ArrowSchema* schema_;
+  /* List of length `n + 1` for `n` chunks containing the offsets for each chunk. */
+  std::vector<int64_t> chunk_offsets_;
+
+  inline void construct_chunk_offsets() {
+    chunk_offsets_.reserve(chunks_.size() + 1);
+    chunk_offsets_.emplace_back(0);
+    for (size_t k = 0; k < chunks_.size(); ++k) {
+      chunk_offsets_.emplace_back(chunks_[k]->length + chunk_offsets_.back());
+    }
+  }
+
+ public:
+  /**
+   * @brief Construct a new Arrow Chunked Array object.
+   *
+   * @param chunks A list with the chunks.
+   * @param schema The schema for all chunks.
+   */
+  inline ArrowChunkedArray(std::vector<const ArrowArray*> chunks, const ArrowSchema* schema) {
+    chunks_ = chunks;
+    schema_ = schema;
+    construct_chunk_offsets();
+  }
+
+  /**
+   * @brief Construct a new Arrow Chunked Array object.
+   *
+   * @param n_chunks The number of chunks.
+   * @param chunks A C-style array containing the chunks.
+   * @param schema The schema for all chunks.
+   */
+  inline ArrowChunkedArray(int64_t n_chunks,
+                           const struct ArrowArray* chunks,
+                           const struct ArrowSchema* schema) {
+    chunks_.reserve(n_chunks);
+    for (auto k = 0; k < n_chunks; ++k) {
+      chunks_.push_back(&chunks[k]);
+    }
+    schema_ = schema;
+    construct_chunk_offsets();
+  }
+
+  /**
+   * @brief Get the length of the chunked array.
+   * This method returns the cumulative length of all chunks.
+   * Complexity: O(1)
+   *
+   * @return int64_t The number of elements in the chunked array.
+   */
+  inline int64_t get_length() const { return chunk_offsets_.back(); }
+
+  /* ----------------------------------------- ITERATOR ---------------------------------------- */
+  template <typename T>
+  class Iterator {
+    using getter_fn = std::function<T(const ArrowArray*, int)>;
+
+    /* Reference to the chunked array that this iterator iterates over. */
+    const ArrowChunkedArray& array_;
+    /* Function to fetch the value at a certain index from a single chunk. */
+    getter_fn get_;
+    /* The chunk the iterator currently points to. */
+    int64_t ptr_chunk_;
+    /* The index inside the current chunk that the iterator points to. */
+    int64_t ptr_offset_;
+
+   public:
+    using iterator_category = std::random_access_iterator_tag;
+    using difference_type = int64_t;
+    using value_type = T;
+    using pointer = value_type*;
+    using reference = value_type&;
+
+    /**
+     * @brief Construct a new Iterator object.
+     *
+     * @param array Reference to the chunked array to iterate over.
+     * @param get Function to fetch the value at a certain index from a single chunk.
+     * @param ptr_chunk The index of the chunk whose first element the iterator points to.
+     */
+    Iterator(const ArrowChunkedArray& array, getter_fn get, int64_t ptr_chunk);
+
+    T operator*() const;
+    template <typename I>
+    T operator[](I idx) const;
+
+    Iterator<T>& operator++();
+    Iterator<T>& operator--();
+    Iterator<T>& operator+=(int64_t c);
+
+    template <typename V>
+    friend bool operator==(const Iterator<V>& a, const Iterator<V>& b);
+    template <typename V>
+    friend bool operator!=(const Iterator<V>& a, const Iterator<V>& b);
+    template <typename V>
+    friend int64_t operator-(const Iterator<V>& a, const Iterator<V>& b);
+  };
+
+  /**
+   * @brief Obtain an iterator to the beginning of the chunked array.
+   *
+   * @tparam T The value type of the iterator. May be any primitive type.
+   * @return Iterator<T> The iterator.
+   */
+  template <typename T>
+  inline Iterator<T> begin() const;
+
+  /**
+   * @brief Obtain an iterator to the end of the chunked array.
+   *
+   * @tparam T The value type of the iterator. May be any primitive type.
+   * @return Iterator<T> The iterator.
+   */
+  template <typename T>
+  inline Iterator<T> end() const;
+
+  template <typename V>
+  friend int64_t operator-(const Iterator<V>& a, const Iterator<V>& b);
+};
+
+/**
+ * @brief Arrow container for a list of chunked arrays.
+ */
+class ArrowTable {
+  std::vector<ArrowChunkedArray> columns_;
+
+ public:
+  /**
+   * @brief Construct a new Arrow Table object.
+   *
+   * @param n_chunks The number of chunks.
+   * @param chunks A C-style array containing the chunks.
+   * @param schema The schema for all chunks.
+   */
+  inline ArrowTable(int64_t n_chunks, const ArrowArray* chunks, const ArrowSchema* schema) {
+    columns_.reserve(schema->n_children);
+    for (int64_t j = 0; j < schema->n_children; ++j) {
+      std::vector<const ArrowArray*> children_chunks;
+      children_chunks.reserve(n_chunks);
+      for (int64_t k = 0; k < n_chunks; ++k) {
+        children_chunks.push_back(chunks[k].children[j]);
+      }
+      columns_.emplace_back(children_chunks, schema->children[j]);
+    }
+  }
+
+  /**
+   * @brief Get the number of rows in the table.
+   *
+   * @return int64_t The number of rows.
+   */
+  inline int64_t get_num_rows() const { return columns_.front().get_length(); }
+
+  /**
+   * @brief Get the number of columns of this table.
+   *
+   * @return int64_t The column count.
+   */
+  inline int64_t get_num_columns() const { return columns_.size(); }
+
+  /**
+   * @brief Get the column at a particular index.
+   *
+   * @param idx The index of the column, must be in the range `[0, num_columns)`.
+   * @return const ArrowChunkedArray& The chunked array for the child at the provided index.
+   */
+  inline const ArrowChunkedArray& get_column(size_t idx) const { return this->columns_[idx]; }
+};
+
+}  // namespace LightGBM
+
+#include "arrow.tpp"
+
+#endif /* LIGHTGBM_ARROW_H_ */
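As a rough sketch of how the container above is meant to be consumed (the `chunks`/`schema` structures are assumed to come from any Arrow C Data Interface producer, e.g. the test helpers at the end of this diff):

    // Sketch: summing a chunked Arrow column via ArrowChunkedArray.
    // `chunks` points to `n_chunks` ArrowArray structs, `schema` describes them.
    LightGBM::ArrowChunkedArray ca(n_chunks, chunks, schema);
    double sum = 0.0;
    for (auto it = ca.begin<double>(); it != ca.end<double>(); ++it) {
      sum += *it;  // null entries surface as NaN for floating-point value types
    }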
diff --git a/include/LightGBM/arrow.tpp b/include/LightGBM/arrow.tpp
new file mode 100644
index 000000000000..67b481c9497e
--- /dev/null
+++ b/include/LightGBM/arrow.tpp
@@ -0,0 +1,190 @@
+#include <algorithm>
+
+#ifndef ARROW_TPP_
+#define ARROW_TPP_
+
+namespace LightGBM {
+
+/**
+ * @brief Obtain a function to access an index from an Arrow array.
+ *
+ * @tparam T The return type of the function, must be a primitive type.
+ * @param dtype The Arrow format string describing the datatype of the Arrow array.
+ * @return std::function<T(const ArrowArray*, int)> The index accessor function.
+ */
+template <typename T>
+std::function<T(const ArrowArray*, int)> get_index_accessor(const char* dtype);
+
+/* ---------------------------------- ITERATOR INITIALIZATION ---------------------------------- */
+
+template <typename T>
+inline ArrowChunkedArray::Iterator<T> ArrowChunkedArray::begin() const {
+  return ArrowChunkedArray::Iterator<T>(*this, get_index_accessor<T>(schema_->format), 0);
+}
+
+template <typename T>
+inline ArrowChunkedArray::Iterator<T> ArrowChunkedArray::end() const {
+  return ArrowChunkedArray::Iterator<T>(*this, get_index_accessor<T>(schema_->format),
+                                        chunk_offsets_.size() - 1);
+}
+
+/* ---------------------------------- ITERATOR IMPLEMENTATION ---------------------------------- */
+
+template <typename T>
+ArrowChunkedArray::Iterator<T>::Iterator(const ArrowChunkedArray& array,
+                                         getter_fn get,
+                                         int64_t ptr_chunk)
+    : array_(array), get_(get), ptr_chunk_(ptr_chunk) {
+  this->ptr_offset_ = 0;
+}
+
+template <typename T>
+T ArrowChunkedArray::Iterator<T>::operator*() const {
+  auto chunk = array_.chunks_[ptr_chunk_];
+  return static_cast<T>(get_(chunk, ptr_offset_));
+}
+
+template <typename T>
+template <typename I>
+T ArrowChunkedArray::Iterator<T>::operator[](I idx) const {
+  auto it = std::lower_bound(array_.chunk_offsets_.begin(), array_.chunk_offsets_.end(), idx,
+                             [](int64_t a, int64_t b) { return a <= b; });
+
+  auto chunk_idx = std::distance(array_.chunk_offsets_.begin() + 1, it);
+  auto chunk = array_.chunks_[chunk_idx];
+
+  auto ptr_offset = static_cast<int64_t>(idx) - array_.chunk_offsets_[chunk_idx];
+  return static_cast<T>(get_(chunk, ptr_offset));
+}
+
+template <typename T>
+ArrowChunkedArray::Iterator<T>& ArrowChunkedArray::Iterator<T>::operator++() {
+  if (ptr_offset_ + 1 >= array_.chunks_[ptr_chunk_]->length) {
+    ptr_offset_ = 0;
+    ptr_chunk_++;
+  } else {
+    ptr_offset_++;
+  }
+  return *this;
+}
+
+template <typename T>
+ArrowChunkedArray::Iterator<T>& ArrowChunkedArray::Iterator<T>::operator--() {
+  if (ptr_offset_ == 0) {
+    ptr_chunk_--;
+    ptr_offset_ = array_.chunks_[ptr_chunk_]->length - 1;
+  } else {
+    ptr_offset_--;
+  }
+  return *this;
+}
+
+template <typename T>
+ArrowChunkedArray::Iterator<T>& ArrowChunkedArray::Iterator<T>::operator+=(int64_t c) {
+  while (ptr_offset_ + c >= array_.chunks_[ptr_chunk_]->length) {
+    c -= array_.chunks_[ptr_chunk_]->length - ptr_offset_;
+    ptr_offset_ = 0;
+    ptr_chunk_++;
+  }
+  ptr_offset_ += c;
+  return *this;
+}
+
+template <typename T>
+bool operator==(const ArrowChunkedArray::Iterator<T>& a, const ArrowChunkedArray::Iterator<T>& b) {
+  return a.ptr_chunk_ == b.ptr_chunk_ && a.ptr_offset_ == b.ptr_offset_;
+}
+
+template <typename T>
+bool operator!=(const ArrowChunkedArray::Iterator<T>& a, const ArrowChunkedArray::Iterator<T>& b) {
+  return a.ptr_chunk_ != b.ptr_chunk_ || a.ptr_offset_ != b.ptr_offset_;
+}
+
+template <typename T>
+int64_t operator-(const ArrowChunkedArray::Iterator<T>& a,
+                  const ArrowChunkedArray::Iterator<T>& b) {
+  auto full_offset_a = a.array_.chunk_offsets_[a.ptr_chunk_] + a.ptr_offset_;
+  auto full_offset_b = b.array_.chunk_offsets_[b.ptr_chunk_] + b.ptr_offset_;
+  return full_offset_a - full_offset_b;
+}
+
+/* --------------------------------------- INDEX ACCESSOR -------------------------------------- */
+
+/**
+ * @brief The value of "no value" for a primitive type.
+ *
+ * @tparam T The type for which the missing value is defined.
+ * @return T The missing value.
+ */
+template <typename T>
+inline T arrow_primitive_missing_value() {
+  return 0;
+}
+
+template <>
+inline double arrow_primitive_missing_value() {
+  return std::numeric_limits<double>::quiet_NaN();
+}
+
+template <>
+inline float arrow_primitive_missing_value() {
+  return std::numeric_limits<float>::quiet_NaN();
+}
+
+template <typename T, typename V>
+struct ArrayIndexAccessor {
+  V operator()(const ArrowArray* array, size_t idx) {
+    auto buffer_idx = idx + array->offset;
+
+    // For primitive types, buffer at idx 0 provides validity, buffer at idx 1 data, see:
+    // https://arrow.apache.org/docs/format/Columnar.html#buffer-listing-for-each-layout
+    auto validity = static_cast<const char*>(array->buffers[0]);
+
+    // Take return value from data buffer conditional on the validity of the index:
+    //  - The structure of validity bitmasks is taken from here:
+    //    https://arrow.apache.org/docs/format/Columnar.html#validity-bitmaps
+    //  - If the bitmask is NULL, all indices are valid
+    if (validity == nullptr || (validity[buffer_idx / 8] & (1 << (buffer_idx % 8)))) {
+      // In case the index is valid (its bit is set), we take it from the data buffer
+      auto data = static_cast<const T*>(array->buffers[1]);
+      return static_cast<V>(data[buffer_idx]);
+    }
+
+    // In case the index is not valid, we return a default value
+    return arrow_primitive_missing_value<V>();
+  }
+};
+
+template <typename T>
+std::function<T(const ArrowArray*, int)> get_index_accessor(const char* dtype) {
+  // Mapping obtained from:
+  // https://arrow.apache.org/docs/format/CDataInterface.html#data-type-description-format-strings
+  switch (dtype[0]) {
+    case 'c':
+      return ArrayIndexAccessor<int8_t, T>();
+    case 'C':
+      return ArrayIndexAccessor<uint8_t, T>();
+    case 's':
+      return ArrayIndexAccessor<int16_t, T>();
+    case 'S':
+      return ArrayIndexAccessor<uint16_t, T>();
+    case 'i':
+      return ArrayIndexAccessor<int32_t, T>();
+    case 'I':
+      return ArrayIndexAccessor<uint32_t, T>();
+    case 'l':
+      return ArrayIndexAccessor<int64_t, T>();
+    case 'L':
+      return ArrayIndexAccessor<uint64_t, T>();
+    case 'f':
+      return ArrayIndexAccessor<float, T>();
+    case 'g':
+      return ArrayIndexAccessor<double, T>();
+    default:
+      throw std::invalid_argument("unsupported Arrow datatype");
+  }
+}
+
+}  // namespace LightGBM
+
+#endif
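To make the validity-bitmap arithmetic concrete, a worked example with illustrative values only:

    // Logical index 10 maps to bit 10 % 8 = 2 of byte 10 / 8 = 1.
    const unsigned char validity[] = {0xFF, 0xFB};  // 0xFB = 1111 1011: bit 2 cleared
    size_t buffer_idx = 10;
    bool is_valid = validity[buffer_idx / 8] & (1 << (buffer_idx % 8));  // false -> missing value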
diff --git a/include/LightGBM/c_api.h b/include/LightGBM/c_api.h
index bba46a02a492..8727712a5f93 100644
--- a/include/LightGBM/c_api.h
+++ b/include/LightGBM/c_api.h
@@ -13,6 +13,7 @@
 #ifndef LIGHTGBM_C_API_H_
 #define LIGHTGBM_C_API_H_
 
+#include <LightGBM/arrow.h>
 #include <LightGBM/export.h>
 
 #ifdef __cplusplus
@@ -437,6 +438,23 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromMats(int32_t nmat,
                                                  const DatasetHandle reference,
                                                  DatasetHandle* out);
 
+/*!
+ * \brief Create dataset from Arrow.
+ * \param n_chunks The number of Arrow arrays passed to this function
+ * \param chunks Pointer to the list of Arrow arrays
+ * \param schema Pointer to the schema of all Arrow arrays
+ * \param parameters Additional parameters
+ * \param reference Used to align bin mapper with other dataset, nullptr means isn't used
+ * \param[out] out Created dataset
+ * \return 0 when succeed, -1 when failure happens
+ */
+LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromArrow(int64_t n_chunks,
+                                                  const ArrowArray* chunks,
+                                                  const ArrowSchema* schema,
+                                                  const char* parameters,
+                                                  const DatasetHandle reference,
+                                                  DatasetHandle* out);
+
 /*!
  * \brief Create subset of a data.
  * \param handle Handle of full dataset
@@ -537,6 +555,25 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetSetField(DatasetHandle handle,
                                            int num_element,
                                            int type);
 
+/*!
+ * \brief Set vector to a content in info.
+ * \note
+ * - \a group converts input datatype into ``int32``;
+ * - \a label and \a weight convert input datatype into ``float32``;
+ * - \a init_score converts input datatype into ``float64``.
+ * \param handle Handle of dataset
+ * \param field_name Field name, can be \a label, \a weight, \a init_score, \a group
+ * \param n_chunks The number of Arrow arrays passed to this function
+ * \param chunks Pointer to the list of Arrow arrays
+ * \param schema Pointer to the schema of all Arrow arrays
+ * \return 0 when succeed, -1 when failure happens
+ */
+LIGHTGBM_C_EXPORT int LGBM_DatasetSetFieldFromArrow(DatasetHandle handle,
+                                                    const char* field_name,
+                                                    int64_t n_chunks,
+                                                    const ArrowArray* chunks,
+                                                    const ArrowSchema* schema);
+
 /*!
  * \brief Get info vector from dataset.
  * \param handle Handle of dataset
@@ -1380,6 +1417,40 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForMats(BoosterHandle handle,
                                                  int64_t* out_len,
                                                  double* out_result);
 
+/*!
+ * \brief Make prediction for a new dataset.
+ * \note
+ * You should pre-allocate memory for ``out_result``:
+ *   - for normal and raw score, its length is equal to ``num_class * num_data``;
+ *   - for leaf index, its length is equal to ``num_class * num_data * num_iteration``;
+ *   - for feature contributions, its length is equal to ``num_class * num_data * (num_feature + 1)``.
+ * \param handle Handle of booster
+ * \param n_chunks The number of Arrow arrays passed to this function
+ * \param chunks Pointer to the list of Arrow arrays
+ * \param schema Pointer to the schema of all Arrow arrays
+ * \param predict_type What should be predicted
+ *   - ``C_API_PREDICT_NORMAL``: normal prediction, with transform (if needed);
+ *   - ``C_API_PREDICT_RAW_SCORE``: raw score;
+ *   - ``C_API_PREDICT_LEAF_INDEX``: leaf index;
+ *   - ``C_API_PREDICT_CONTRIB``: feature contributions (SHAP values)
+ * \param start_iteration Start index of the iteration to predict
+ * \param num_iteration Number of iteration for prediction, <= 0 means no limit
+ * \param parameter Other parameters for prediction, e.g. early stopping for prediction
+ * \param[out] out_len Length of output result
+ * \param[out] out_result Pointer to array with predictions
+ * \return 0 when succeed, -1 when failure happens
+ */
+LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForArrow(BoosterHandle handle,
+                                                  int64_t n_chunks,
+                                                  const ArrowArray* chunks,
+                                                  const ArrowSchema* schema,
+                                                  int predict_type,
+                                                  int start_iteration,
+                                                  int num_iteration,
+                                                  const char* parameter,
+                                                  int64_t* out_len,
+                                                  double* out_result);
+
 /*!
  * \brief Save model into file.
  * \param handle Handle of booster
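A minimal sketch of driving the new C entry points directly (error handling elided; `chunks`/`schema` and the label structures are assumed to be exported by an Arrow C Data Interface producer):

    DatasetHandle dataset = nullptr;
    int err = LGBM_DatasetCreateFromArrow(n_chunks, chunks, schema,
                                          "max_bin=255", nullptr, &dataset);
    // on err != 0, consult LGBM_GetLastError()
    err = LGBM_DatasetSetFieldFromArrow(dataset, "label",
                                        n_label_chunks, label_chunks, label_schema);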
diff --git a/include/LightGBM/dataset.h b/include/LightGBM/dataset.h
index 825c5c6ebcf8..b73f753e3826 100644
--- a/include/LightGBM/dataset.h
+++ b/include/LightGBM/dataset.h
@@ -5,6 +5,7 @@
 #ifndef LIGHTGBM_DATASET_H_
 #define LIGHTGBM_DATASET_H_
 
+#include <LightGBM/arrow.h>
 #include <LightGBM/config.h>
 #include <LightGBM/feature_group.h>
 #include <LightGBM/meta.h>
@@ -109,17 +110,20 @@ class Metadata {
                   const std::vector<data_size_t>& used_data_indices);
 
   void SetLabel(const label_t* label, data_size_t len);
+  void SetLabel(const ArrowChunkedArray& array);
 
   void SetWeights(const label_t* weights, data_size_t len);
+  void SetWeights(const ArrowChunkedArray& array);
 
   void SetQuery(const data_size_t* query, data_size_t len);
+  void SetQuery(const ArrowChunkedArray& array);
 
   /*!
   * \brief Set initial scores
   * \param init_score Initial scores, this class will manage memory for init_score.
   */
   void SetInitScore(const double* init_score, data_size_t len);
-
+  void SetInitScore(const ArrowChunkedArray& array);
   /*!
  * \brief Save binary data to file
@@ -297,12 +301,24 @@ class Metadata {
   void CalculateQueryBoundaries();
   /*! \brief Insert labels at the given index */
   void InsertLabels(const label_t* labels, data_size_t start_index, data_size_t len);
+  /*! \brief Set labels from pointers to the first element and the end of an iterator. */
+  template <typename It>
+  void SetLabelsFromIterator(It first, It last);
   /*! \brief Insert weights at the given index */
   void InsertWeights(const label_t* weights, data_size_t start_index, data_size_t len);
+  /*! \brief Set weights from pointers to the first element and the end of an iterator. */
+  template <typename It>
+  void SetWeightsFromIterator(It first, It last);
   /*! \brief Insert initial scores at the given index */
   void InsertInitScores(const double* init_scores, data_size_t start_index, data_size_t len, data_size_t source_size);
+  /*! \brief Set init scores from pointers to the first element and the end of an iterator. */
+  template <typename It>
+  void SetInitScoresFromIterator(It first, It last);
   /*! \brief Insert queries at the given index */
   void InsertQueries(const data_size_t* queries, data_size_t start_index, data_size_t len);
+  /*! \brief Set queries from pointers to the first element and the end of an iterator. */
+  template <typename It>
+  void SetQueriesFromIterator(It first, It last);
   /*! \brief Filename of current data */
   std::string data_filename_;
   /*! \brief Number of data */
@@ -502,24 +518,29 @@ class Dataset {
     }
   }
 
-  inline void PushOneRow(int tid, data_size_t row_idx, const std::vector<double>& feature_values) {
-    if (is_finish_load_) { return; }
-    for (size_t i = 0; i < feature_values.size() && i < static_cast<size_t>(num_total_features_); ++i) {
-      int feature_idx = used_feature_map_[i];
-      if (feature_idx >= 0) {
-        const int group = feature2group_[feature_idx];
-        const int sub_feature = feature2subfeature_[feature_idx];
-        feature_groups_[group]->PushData(tid, sub_feature, row_idx, feature_values[i]);
-        if (has_raw_) {
-          int feat_ind = numeric_feature_map_[feature_idx];
-          if (feat_ind >= 0) {
-            raw_data_[feat_ind][row_idx] = static_cast<float>(feature_values[i]);
-          }
+  inline void PushOneValue(int tid, data_size_t row_idx, size_t col_idx, double value) {
+    if (this->is_finish_load_) { return; }
+    auto feature_idx = this->used_feature_map_[col_idx];
+    if (feature_idx >= 0) {
+      auto group = this->feature2group_[feature_idx];
+      auto sub_feature = this->feature2subfeature_[feature_idx];
+      this->feature_groups_[group]->PushData(tid, sub_feature, row_idx, value);
+      if (this->has_raw_) {
+        auto feat_ind = numeric_feature_map_[feature_idx];
+        if (feat_ind >= 0) {
+          raw_data_[feat_ind][row_idx] = static_cast<float>(value);
         }
       }
     }
   }
 
+  inline void PushOneRow(int tid, data_size_t row_idx, const std::vector<double>& feature_values) {
+    for (size_t i = 0; i < feature_values.size() && i < static_cast<size_t>(num_total_features_); ++i) {
+      this->PushOneValue(tid, row_idx, i, feature_values[i]);
+    }
+  }
+
   inline void PushOneRow(int tid, data_size_t row_idx, const std::vector<std::pair<int, double>>& feature_values) {
     if (is_finish_load_) { return; }
     std::vector<bool> is_feature_added(num_features_, false);
@@ -606,6 +627,8 @@ class Dataset {
 
   LIGHTGBM_EXPORT void FinishLoad();
 
+  bool SetFieldFromArrow(const char* field_name, const ArrowChunkedArray& ca);
+
   LIGHTGBM_EXPORT bool SetFloatField(const char* field_name, const float* field_data, data_size_t num_element);
 
   LIGHTGBM_EXPORT bool SetDoubleField(const char* field_name, const double* field_data, data_size_t num_element);
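The PushOneValue/PushOneRow split above exists so that the Arrow loader can stream values column by column instead of materializing rows; schematically (a sketch, not the exact call site — see LGBM_DatasetCreateFromArrow in src/c_api.cpp below):

    for (int64_t j = 0; j < table.get_num_columns(); ++j) {
      const auto& col = table.get_column(j);
      data_size_t row_idx = 0;
      for (auto it = col.begin<double>(), end = col.end<double>(); it != end; ++it) {
        dataset->PushOneValue(/*tid=*/0, row_idx++, j, *it);
      }
    }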
diff --git a/python-package/lightgbm/arrow.py b/python-package/lightgbm/arrow.py
new file mode 100644
index 000000000000..f89c4a7e3fca
--- /dev/null
+++ b/python-package/lightgbm/arrow.py
@@ -0,0 +1,56 @@
+# coding: utf-8
+"""Utilities for handling Arrow in LightGBM."""
+from contextlib import contextmanager
+from dataclasses import dataclass
+from typing import Iterator, Union
+
+import pyarrow as pa
+from pyarrow.cffi import ffi
+
+
+@dataclass
+class ArrowCArray:
+    """Simple wrapper around the C representation of an Arrow type."""
+
+    n_chunks: int
+    chunks: ffi.CData
+    schema: ffi.CData
+
+    @property
+    def chunks_ptr(self) -> int:
+        """Returns the address of the pointer to the list of chunks making up the array."""
+        return int(ffi.cast("uintptr_t", ffi.addressof(self.chunks[0])))
+
+    @property
+    def schema_ptr(self) -> int:
+        """Returns the address of the pointer to the schema of the array."""
+        return int(ffi.cast("uintptr_t", self.schema))
+
+
+@contextmanager
+def export_arrow_to_c(data: Union[pa.Table, pa.Array, pa.ChunkedArray]) -> Iterator[ArrowCArray]:
+    """Export an Arrow type to its C representation."""
+    # Obtain objects to export
+    if isinstance(data, pa.Table):
+        export_objects = data.to_batches()
+    elif isinstance(data, pa.Array):
+        export_objects = [data]
+    elif isinstance(data, pa.ChunkedArray):
+        export_objects = data.chunks
+    else:
+        raise ValueError(f"data of type '{type(data)}' cannot be exported to Arrow")
+
+    # Prepare export
+    chunks = ffi.new(f"struct ArrowArray[{len(export_objects)}]")
+    schema = ffi.new("struct ArrowSchema*")
+
+    # Export all objects
+    for i, obj in enumerate(export_objects):
+        chunk_ptr = int(ffi.cast("uintptr_t", ffi.addressof(chunks[i])))
+        if i == 0:
+            schema_ptr = int(ffi.cast("uintptr_t", schema))
+            obj._export_to_c(chunk_ptr, schema_ptr)
+        else:
+            obj._export_to_c(chunk_ptr)
+
+    yield ArrowCArray(len(chunks), chunks, schema)
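For orientation, the helper is consumed like this (a sketch; `_export_to_c` is pyarrow's private C Data Interface export hook, so this module should remain the only place that touches it):

    import pyarrow as pa
    from lightgbm.arrow import export_arrow_to_c

    table = pa.table({"x": [1.0, 2.0], "y": [3.0, 4.0]})
    with export_arrow_to_c(table) as c_array:
        # plain integers that can be handed to the C API as void pointers
        print(c_array.n_chunks, c_array.chunks_ptr, c_array.schema_ptr)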
table.""" + return isinstance(data, pa_Table) + + def _data_to_2d_numpy( data: Any, dtype: "np.typing.DTypeLike", @@ -1043,6 +1071,13 @@ def predict( num_iteration=num_iteration, predict_type=predict_type ) + elif _is_pyarrow_table(data): + preds, nrow = self.__pred_for_arrow_table( + table=data, + start_iteration=start_iteration, + num_iteration=num_iteration, + predict_type=predict_type + ) elif isinstance(data, list): try: data = np.array(data) @@ -1497,6 +1532,45 @@ def __pred_for_csc( raise ValueError("Wrong length for predict results") return preds, nrow + def __pred_for_arrow_table( + self, + table: pa_Table, + start_iteration: int, + num_iteration: int, + predict_type: int + ) -> Tuple[np.ndarray, int]: + """Predict for a PyArrow table.""" + # Check that the input is valid: we only handle numbers (for now) + if not all(pa.types.is_integer(t) or pa.types.is_floating(t) for t in table.schema.types): + raise ValueError("Arrow table may only have integer or floating point datatypes") + + # Prepare prediction output array + n_preds = self.__get_num_preds( + start_iteration=start_iteration, + num_iteration=num_iteration, + nrow=table.num_rows, + predict_type=predict_type + ) + preds = np.empty(n_preds, dtype=np.float64) + out_num_preds = ctypes.c_int64(0) + + # Export Arrow table to C and run prediction + with export_arrow_to_c(table) as c_array: + _safe_call(_LIB.LGBM_BoosterPredictForArrow( + self._handle, + ctypes.c_int64(c_array.n_chunks), + ctypes.c_void_p(c_array.chunks_ptr), + ctypes.c_void_p(c_array.schema_ptr), + ctypes.c_int(predict_type), + ctypes.c_int(start_iteration), + ctypes.c_int(num_iteration), + _c_str(self.pred_parameter), + ctypes.byref(out_num_preds), + preds.ctypes.data_as(ctypes.POINTER(ctypes.c_double)))) + if n_preds != out_num_preds.value: + raise ValueError("Wrong length for predict results") + return preds, table.num_rows + def current_iteration(self) -> int: """Get the index of the current iteration. @@ -1532,26 +1606,26 @@ def __init__( Parameters ---------- - data : str, pathlib.Path, numpy array, pandas DataFrame, H2O DataTable's Frame, scipy.sparse, Sequence, list of Sequence or list of numpy array + data : str, pathlib.Path, numpy array, pandas DataFrame, H2O DataTable's Frame, scipy.sparse, Sequence, list of Sequence, list of numpy array or pyarrow Table Data source of Dataset. If str or pathlib.Path, it represents the path to a text file (CSV, TSV, or LibSVM) or a LightGBM Dataset binary file. - label : list, numpy 1-D array, pandas Series / one-column DataFrame or None, optional (default=None) + label : list, numpy 1-D array, pandas Series / one-column DataFrame, pyarrow Array or None, optional (default=None) Label of the data. reference : Dataset or None, optional (default=None) If this is Dataset for validation, training data should be used as reference. - weight : list, numpy 1-D array, pandas Series or None, optional (default=None) + weight : list, numpy 1-D array, pandas Series, pyarrow Array or None, optional (default=None) Weight for each instance. Weights should be non-negative. - group : list, numpy 1-D array, pandas Series or None, optional (default=None) + group : list, numpy 1-D array, pandas Series, pyarrow Array or None, optional (default=None) Group/query data. Only used in the learning-to-rank task. sum(group) = n_samples. 
@@ -1532,26 +1606,26 @@ def __init__(
 
         Parameters
         ----------
-        data : str, pathlib.Path, numpy array, pandas DataFrame, H2O DataTable's Frame, scipy.sparse, Sequence, list of Sequence or list of numpy array
+        data : str, pathlib.Path, numpy array, pandas DataFrame, H2O DataTable's Frame, scipy.sparse, Sequence, list of Sequence, list of numpy array or pyarrow Table
            Data source of Dataset.
            If str or pathlib.Path, it represents the path to a text file (CSV, TSV, or LibSVM) or a LightGBM Dataset binary file.
-        label : list, numpy 1-D array, pandas Series / one-column DataFrame or None, optional (default=None)
+        label : list, numpy 1-D array, pandas Series / one-column DataFrame, pyarrow Array or None, optional (default=None)
            Label of the data.
        reference : Dataset or None, optional (default=None)
            If this is Dataset for validation, training data should be used as reference.
-        weight : list, numpy 1-D array, pandas Series or None, optional (default=None)
+        weight : list, numpy 1-D array, pandas Series, pyarrow Array or None, optional (default=None)
            Weight for each instance. Weights should be non-negative.
-        group : list, numpy 1-D array, pandas Series or None, optional (default=None)
+        group : list, numpy 1-D array, pandas Series, pyarrow Array or None, optional (default=None)
            Group/query data.
            Only used in the learning-to-rank task.
            sum(group) = n_samples.
            For example, if you have a 100-document dataset with ``group = [10, 20, 40, 10, 10, 10]``, that means that you have 6 groups,
            where the first 10 records are in the first group, records 11-30 are in the second group, records 31-70 are in the third group, etc.
-        init_score : list, list of lists (for multi-class task), numpy array, pandas Series, pandas DataFrame (for multi-class task), or None, optional (default=None)
+        init_score : list, list of lists (for multi-class task), numpy array, pandas Series, pandas DataFrame (for multi-class task), pyarrow Array, pyarrow Table (for multi-class task), or None, optional (default=None)
            Init score for Dataset.
        feature_name : list of str, or 'auto', optional (default="auto")
            Feature names.
-            If 'auto' and data is pandas DataFrame, data columns names are used.
+            If 'auto' and data is pandas DataFrame or pyarrow Table, data columns names are used.
        categorical_feature : list of str or int, or 'auto', optional (default="auto")
            Categorical features.
            If list of int, interpreted as indices.
@@ -1901,6 +1975,9 @@ def _lazy_init(
             self.__init_from_csc(data, params_str, ref_dataset)
         elif isinstance(data, np.ndarray):
             self.__init_from_np2d(data, params_str, ref_dataset)
+        elif _is_pyarrow_table(data):
+            self.__init_from_pyarrow_table(data, params_str, ref_dataset)
+            feature_name = data.column_names
         elif isinstance(data, list) and len(data) > 0:
             if all(isinstance(x, np.ndarray) for x in data):
                 self.__init_from_list_np2d(data, params_str, ref_dataset)
@@ -2159,6 +2236,29 @@ def __init_from_csc(
             ctypes.byref(self._handle)))
         return self
 
+    def __init_from_pyarrow_table(
+        self,
+        table: pa_Table,
+        params_str: str,
+        ref_dataset: Optional[_DatasetHandle]
+    ) -> "Dataset":
+        """Initialize data from a PyArrow table."""
+        # Check that the input is valid: we only handle numbers (for now)
+        if not all(pa.types.is_integer(t) or pa.types.is_floating(t) for t in table.schema.types):
+            raise ValueError("Arrow table may only have integer or floating point datatypes")
+
+        # Export Arrow table to C
+        with export_arrow_to_c(table) as c_array:
+            self._handle = ctypes.c_void_p()
+            _safe_call(_LIB.LGBM_DatasetCreateFromArrow(
+                ctypes.c_int64(c_array.n_chunks),
+                ctypes.c_void_p(c_array.chunks_ptr),
+                ctypes.c_void_p(c_array.schema_ptr),
+                _c_str(params_str),
+                ref_dataset,
+                ctypes.byref(self._handle)))
+        return self
+
     @staticmethod
     def _compare_params_for_warning(
         params: Optional[Dict[str, Any]],
@@ -2394,7 +2494,7 @@ def _reverse_update_params(self) -> "Dataset":
     def set_field(
         self,
         field_name: str,
-        data: Optional[Union[List[List[float]], List[List[int]], List[float], List[int], np.ndarray, pd_Series, pd_DataFrame]]
+        data: Optional[Union[List[List[float]], List[List[int]], List[float], List[int], np.ndarray, pd_Series, pd_DataFrame, pa_Table, pa_Array, pa_ChunkedArray]]
     ) -> "Dataset":
         """Set property into the Dataset.
 
@@ -2402,7 +2502,7 @@ def set_field(
         ----------
         field_name : str
             The field name of the information.
-        data : list, list of lists (for multi-class task), numpy array, pandas Series, pandas DataFrame (for multi-class task), or None
+        data : list, list of lists (for multi-class task), numpy array, pandas Series, pandas DataFrame (for multi-class task), pyarrow Table, pyarrow Array, or None
             The data to be set.
 
         Returns
@@ -2421,6 +2521,31 @@ def set_field(
                 ctypes.c_int(0),
                 ctypes.c_int(_FIELD_TYPE_MAPPER[field_name])))
             return self
+
+        # If the data is Arrow data, we can just pass it to C
+        if _is_pyarrow_type(data):
+            # If a table is being passed, we concatenate the columns. This is only valid for
+            # 'init_score'.
+            if isinstance(data, pa_Table):
+                if field_name != "init_score":
+                    raise ValueError("pyarrow table provided for field other than init_score")
+                data = pa.chunked_array([
+                    chunk for array in data.columns for chunk in array.chunks
+                ])
+
+            with export_arrow_to_c(data) as c_array:
+                _safe_call(_LIB.LGBM_DatasetSetFieldFromArrow(
+                    self._handle,
+                    _c_str(field_name),
+                    ctypes.c_int64(c_array.n_chunks),
+                    ctypes.c_void_p(c_array.chunks_ptr),
+                    ctypes.c_void_p(c_array.schema_ptr),
+                ))
+
+            self.version += 1
+            return self
+
+        # Otherwise, we have to do some more work
         dtype: "np.typing.DTypeLike"
         if field_name == 'init_score':
             dtype = np.float64
@@ -2624,7 +2749,7 @@ def set_label(self, label: Optional[_LGBM_LabelType]) -> "Dataset":
 
         Parameters
         ----------
-        label : list, numpy 1-D array, pandas Series / one-column DataFrame or None
+        label : list, numpy 1-D array, pandas Series / one-column DataFrame, pyarrow Array or None
             The label information to be set into Dataset.
 
         Returns
@@ -2649,6 +2774,8 @@ def set_label(self, label: Optional[_LGBM_LabelType]) -> "Dataset":
                 # data has nullable dtypes, but we can specify na_value argument and copy will be made
                 label = label.to_numpy(dtype=np.float32, na_value=np.nan)
             label_array = np.ravel(label)
+        elif _is_pyarrow_type(label):
+            label_array = label
         else:
             label_array = _list_to_1d_numpy(label, dtype=np.float32, name='label')
         self.set_field('label', label_array)
@@ -2663,7 +2790,7 @@ def set_weight(
 
         Parameters
         ----------
-        weight : list, numpy 1-D array, pandas Series or None
+        weight : list, numpy 1-D array, pandas Series, pyarrow Array or None
             Weight to be set for each data point. Weights should be non-negative.
 
         Returns
@@ -2671,11 +2798,19 @@ def set_weight(
         self : Dataset
             Dataset with set weight.
         """
-        if weight is not None and np.all(weight == 1):
-            weight = None
+        # Check if the weight contains values other than one
+        if weight is not None:
+            if _is_pyarrow_type(weight):
+                if pa.compute.all(pa.compute.equal(weight, 1)).as_py():
+                    weight = None
+            elif np.all(weight == 1):
+                weight = None
         self.weight = weight
+
+        # Set field
         if self._handle is not None and weight is not None:
-            weight = _list_to_1d_numpy(weight, dtype=np.float32, name='weight')
+            if not _is_pyarrow_type(weight):
+                weight = _list_to_1d_numpy(weight, dtype=np.float32, name='weight')
             self.set_field('weight', weight)
             self.weight = self.get_field('weight')  # original values can be modified at cpp side
         return self
@@ -2688,7 +2823,7 @@ def set_init_score(
 
         Parameters
         ----------
-        init_score : list, list of lists (for multi-class task), numpy array, pandas Series, pandas DataFrame (for multi-class task), or None
+        init_score : list, list of lists (for multi-class task), numpy array, pandas Series, pandas DataFrame (for multi-class task), pyarrow Array, pyarrow Table (for multi-class task), or None
             Init score for Booster.
 
         Returns
@@ -2710,7 +2845,7 @@ def set_group(
 
         Parameters
        ----------
-        group : list, numpy 1-D array, pandas Series or None
+        group : list, numpy 1-D array, pandas Series, pyarrow Array or None
            Group/query data.
            Only used in the learning-to-rank task.
            sum(group) = n_samples.
@@ -2724,7 +2859,8 @@ def set_group(
         """
         self.group = group
         if self._handle is not None and group is not None:
-            group = _list_to_1d_numpy(group, dtype=np.int32, name='group')
+            if not _is_pyarrow_type(group):
+                group = _list_to_1d_numpy(group, dtype=np.int32, name='group')
             self.set_field('group', group)
         return self
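Taken together, the intended user-facing flow looks roughly like this (a sketch with made-up column names):

    import pyarrow as pa
    import lightgbm as lgb

    data = pa.table({"feat1": [1.0, 2.0, 3.0], "feat2": [0.5, 0.1, 0.9]})
    label = pa.chunked_array([[0.0], [1.0, 0.0]])  # chunk layout does not matter
    ds = lgb.Dataset(data, label=label)
    ds.construct()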
diff --git a/python-package/lightgbm/compat.py b/python-package/lightgbm/compat.py
index 0a55ccd1e421..c5ffa34a6907 100644
--- a/python-package/lightgbm/compat.py
+++ b/python-package/lightgbm/compat.py
@@ -185,6 +185,33 @@ class dask_Series:  # type: ignore
         def __init__(self, *args, **kwargs):
             pass
 
+"""pyarrow"""
+try:
+    from pyarrow import Array as pa_Array
+    from pyarrow import ChunkedArray as pa_ChunkedArray
+    from pyarrow import Table as pa_Table
+    PYARROW_INSTALLED = True
+except ImportError:
+    PYARROW_INSTALLED = False
+
+    class pa_Table:  # type: ignore
+        """Dummy class for pa.Table."""
+
+        def __init__(self, *args, **kwargs):
+            pass
+
+    class pa_Array:  # type: ignore
+        """Dummy class for pa.Array."""
+
+        def __init__(self, *args, **kwargs):
+            pass
+
+    class pa_ChunkedArray:  # type: ignore
+        """Dummy class for pa.ChunkedArray."""
+
+        def __init__(self, *args, **kwargs):
+            pass
+
 """cpu_count()"""
 try:
     from joblib import cpu_count
diff --git a/python-package/pyproject.toml b/python-package/pyproject.toml
index d3ff28286bb9..126d9fa4fe7f 100644
--- a/python-package/pyproject.toml
+++ b/python-package/pyproject.toml
@@ -32,6 +32,10 @@ requires-python = ">=3.6"
 version = "4.0.0.99"
 
 [project.optional-dependencies]
+arrow = [
+    "cffi>=1.15.1",
+    "pyarrow>=12.0.0"
+]
 dask = [
     "dask[array,dataframe,distributed]>=2.0.0",
     "pandas>=0.24.0"
diff --git a/src/c_api.cpp b/src/c_api.cpp
index 442247d7a9dd..2f5250e8eaec 100644
--- a/src/c_api.cpp
+++ b/src/c_api.cpp
@@ -4,6 +4,7 @@
  */
 #include <LightGBM/c_api.h>
 
+#include <LightGBM/arrow.h>
 #include <LightGBM/boosting.h>
 #include <LightGBM/config.h>
 #include <LightGBM/dataset.h>
@@ -832,6 +833,8 @@ class Booster {
 
 // explicitly declare symbols from LightGBM namespace
 using LightGBM::AllgatherFunction;
+using LightGBM::ArrowChunkedArray;
+using LightGBM::ArrowTable;
 using LightGBM::Booster;
 using LightGBM::Common::CheckElementsIntervalClosed;
 using LightGBM::Common::RemoveQuotationSymbol;
@@ -1567,6 +1570,98 @@ int LGBM_DatasetCreateFromCSC(const void* col_ptr,
   API_END();
 }
 
+int LGBM_DatasetCreateFromArrow(int64_t n_chunks,
+                                const ArrowArray* chunks,
+                                const ArrowSchema* schema,
+                                const char* parameters,
+                                const DatasetHandle reference,
+                                DatasetHandle* out) {
+  API_BEGIN();
+
+  auto param = Config::Str2Map(parameters);
+  Config config;
+  config.Set(param);
+  OMP_SET_NUM_THREADS(config.num_threads);
+
+  std::unique_ptr<Dataset> ret;
+
+  // Prepare the Arrow data
+  ArrowTable table(n_chunks, chunks, schema);
+
+  // Initialize the dataset
+  if (reference == nullptr) {
+    // If there is no reference dataset, we first sample indices
+    auto sample_indices = CreateSampleIndices(static_cast<data_size_t>(table.get_num_rows()), config);
+    auto sample_count = static_cast<data_size_t>(sample_indices.size());
+    std::vector<std::vector<double>> sample_values(table.get_num_columns());
+    std::vector<std::vector<int>> sample_idx(table.get_num_columns());
+
+    // Then, we obtain sample values by parallelizing across columns
+    OMP_INIT_EX();
+    #pragma omp parallel for schedule(static)
+    for (int64_t j = 0; j < table.get_num_columns(); ++j) {
+      OMP_LOOP_EX_BEGIN();
+
+      // Values need to be copied from the record batches.
+      sample_values[j].reserve(sample_indices.size());
+      sample_idx[j].reserve(sample_indices.size());
+
+      // The chunks are iterated over in the inner loop as columns can be treated independently.
+      int last_idx = 0;
+      int i = 0;
+      auto it = table.get_column(j).begin<double>();
+      for (auto idx : sample_indices) {
+        std::advance(it, idx - last_idx);
+        auto v = *it;
+        if (std::fabs(v) > kZeroThreshold || std::isnan(v)) {
+          sample_values[j].emplace_back(v);
+          sample_idx[j].emplace_back(i);
+        }
+        last_idx = idx;
+        i++;
+      }
+      OMP_LOOP_EX_END();
+    }
+    OMP_THROW_EX();
+
+    // Finally, we initialize a loader from the sampled values
+    DatasetLoader loader(config, nullptr, 1, nullptr);
+    ret.reset(loader.ConstructFromSampleData(Vector2Ptr<double>(&sample_values).data(),
+                                             Vector2Ptr<int>(&sample_idx).data(),
+                                             table.get_num_columns(),
+                                             VectorSize<double>(sample_values).data(),
+                                             sample_count,
+                                             table.get_num_rows(),
+                                             table.get_num_rows()));
+  } else {
+    ret.reset(new Dataset(static_cast<data_size_t>(table.get_num_rows())));
+    ret->CreateValid(reinterpret_cast<const Dataset*>(reference));
+    if (ret->has_raw()) {
+      ret->ResizeRaw(static_cast<data_size_t>(table.get_num_rows()));
+    }
+  }
+
+  // After sampling and properly initializing all bins, we can add our data to the dataset. Here,
+  // we parallelize across columns.
+  OMP_INIT_EX();
+  #pragma omp parallel for schedule(static)
+  for (int64_t j = 0; j < table.get_num_columns(); ++j) {
+    OMP_LOOP_EX_BEGIN();
+    const int tid = omp_get_thread_num();
+    data_size_t idx = 0;
+    auto column = table.get_column(j);
+    for (auto it = column.begin<double>(), end = column.end<double>(); it != end; ++it) {
+      ret->PushOneValue(tid, idx++, j, *it);
+    }
+    OMP_LOOP_EX_END();
+  }
+  OMP_THROW_EX();
+
+  ret->FinishLoad();
+  *out = ret.release();
+  API_END();
+}
+
 int LGBM_DatasetGetSubset(
     const DatasetHandle handle,
     const int32_t* used_row_indices,
@@ -1686,6 +1781,21 @@ int LGBM_DatasetSetField(DatasetHandle handle,
   API_END();
 }
 
+int LGBM_DatasetSetFieldFromArrow(DatasetHandle handle,
+                                  const char* field_name,
+                                  int64_t n_chunks,
+                                  const ArrowArray* chunks,
+                                  const ArrowSchema* schema) {
+  API_BEGIN();
+  auto dataset = reinterpret_cast<Dataset*>(handle);
+  ArrowChunkedArray ca(n_chunks, chunks, schema);
+  auto is_success = dataset->SetFieldFromArrow(field_name, ca);
+  if (!is_success) {
+    Log::Fatal("Input field not found");
+  }
+  API_END();
+}
+
 int LGBM_DatasetGetField(DatasetHandle handle,
                          const char* field_name,
                          int* out_len,
@@ -2458,6 +2568,57 @@ int LGBM_BoosterPredictForMats(BoosterHandle handle,
   API_END();
 }
 
+int LGBM_BoosterPredictForArrow(BoosterHandle handle,
+                                int64_t n_chunks,
+                                const ArrowArray* chunks,
+                                const ArrowSchema* schema,
+                                int predict_type,
+                                int start_iteration,
+                                int num_iteration,
+                                const char* parameter,
+                                int64_t* out_len,
+                                double* out_result) {
+  API_BEGIN();
+
+  // Apply the configuration
+  auto param = Config::Str2Map(parameter);
+  Config config;
+  config.Set(param);
+  OMP_SET_NUM_THREADS(config.num_threads);
+
+  // Set up chunked array and iterators for all columns
+  ArrowTable table(n_chunks, chunks, schema);
+  std::vector<ArrowChunkedArray::Iterator<double>> its;
+  its.reserve(table.get_num_columns());
+  for (int64_t j = 0; j < table.get_num_columns(); ++j) {
+    its.emplace_back(table.get_column(j).begin<double>());
+  }
+
+  // Build row function
+  auto num_columns = table.get_num_columns();
+  auto row_fn = [num_columns, &its] (int row_idx) {
+    std::vector<std::pair<int, double>> result;
+    result.reserve(num_columns);
+    for (int64_t j = 0; j < num_columns; ++j) {
+      result.emplace_back(static_cast<int>(j), its[j][row_idx]);
+    }
+    return result;
+  };
+
+  // Run prediction
+  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
+  ref_booster->Predict(start_iteration,
+                       num_iteration,
+                       predict_type,
+                       static_cast<int32_t>(table.get_num_rows()),
+                       static_cast<int32_t>(table.get_num_columns()),
+                       row_fn,
+                       config,
+                       out_result,
+                       out_len);
+  API_END();
+}
+
 int LGBM_BoosterSaveModel(BoosterHandle handle,
                           int start_iteration,
                           int num_iteration,
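A corresponding prediction sketch at the C level (assumes a trained booster handle and Arrow structures from a producer; `out_result` sized per the c_api.h docs above):

    int64_t out_len = 0;
    int err = LGBM_BoosterPredictForArrow(booster, n_chunks, chunks, schema,
                                          C_API_PREDICT_NORMAL, 0, -1, "",
                                          &out_len, out_result);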
diff --git a/src/io/dataset.cpp b/src/io/dataset.cpp
index 5b23f01ec3a0..98f58e1d08d0 100644
--- a/src/io/dataset.cpp
+++ b/src/io/dataset.cpp
@@ -897,6 +897,23 @@ void Dataset::CopySubrow(const Dataset* fullset,
 #endif  // USE_CUDA
 }
 
+bool Dataset::SetFieldFromArrow(const char* field_name, const ArrowChunkedArray& ca) {
+  std::string name(field_name);
+  name = Common::Trim(name);
+  if (name == std::string("label") || name == std::string("target")) {
+    metadata_.SetLabel(ca);
+  } else if (name == std::string("weight") || name == std::string("weights")) {
+    metadata_.SetWeights(ca);
+  } else if (name == std::string("init_score")) {
+    metadata_.SetInitScore(ca);
+  } else if (name == std::string("query") || name == std::string("group")) {
+    metadata_.SetQuery(ca);
+  } else {
+    return false;
+  }
+  return true;
+}
+
 bool Dataset::SetFloatField(const char* field_name, const float* field_data,
                             data_size_t num_element) {
   std::string name(field_name);
diff --git a/src/io/metadata.cpp b/src/io/metadata.cpp
index 2a589fa24ef8..3e182fbf7a9c 100644
--- a/src/io/metadata.cpp
+++ b/src/io/metadata.cpp
@@ -325,32 +325,44 @@ void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector<data_size_t>& used_data_indices) {
 
-void Metadata::SetInitScore(const double* init_score, data_size_t len) {
+template <typename It>
+void Metadata::SetInitScoresFromIterator(It first, It last) {
   std::lock_guard<std::mutex> lock(mutex_);
-  // save to nullptr
-  if (init_score == nullptr || len == 0) {
+  // Clear init scores on empty input
+  if (last - first == 0) {
     init_score_.clear();
     num_init_score_ = 0;
     return;
   }
-  if ((len % num_data_) != 0) {
+  if (((last - first) % num_data_) != 0) {
     Log::Fatal("Initial score size doesn't match data size");
   }
-  if (init_score_.empty()) { init_score_.resize(len); }
-  num_init_score_ = len;
+  if (init_score_.empty()) {
+    init_score_.resize(last - first);
+  }
+  num_init_score_ = last - first;
 
 #pragma omp parallel for schedule(static, 512) if (num_init_score_ >= 1024)
   for (int64_t i = 0; i < num_init_score_; ++i) {
-    init_score_[i] = Common::AvoidInf(init_score[i]);
+    init_score_[i] = Common::AvoidInf(first[i]);
   }
   init_score_load_from_file_ = false;
+
 #ifdef USE_CUDA
   if (cuda_metadata_ != nullptr) {
-    cuda_metadata_->SetInitScore(init_score_.data(), len);
+    cuda_metadata_->SetInitScore(init_score_.data(), init_score_.size());
   }
 #endif  // USE_CUDA
 }
 
+void Metadata::SetInitScore(const double* init_score, data_size_t len) {
+  SetInitScoresFromIterator(init_score, init_score + len);
+}
+
+void Metadata::SetInitScore(const ArrowChunkedArray& array) {
+  SetInitScoresFromIterator(array.begin<double>(), array.end<double>());
+}
+
 void Metadata::InsertInitScores(const double* init_scores, data_size_t start_index, data_size_t len, data_size_t source_size) {
   if (num_init_score_ <= 0) {
     Log::Fatal("Inserting initial score data into dataset with no initial scores");
@@ -373,27 +385,39 @@ void Metadata::InsertInitScores(const double* init_scores, data_size_t start_ind
   // CUDA is handled after all insertions are complete
 }
 
-void Metadata::SetLabel(const label_t* label, data_size_t len) {
+template <typename It>
+void Metadata::SetLabelsFromIterator(It first, It last) {
   std::lock_guard<std::mutex> lock(mutex_);
-  if (label == nullptr) {
-    Log::Fatal("label cannot be nullptr");
+  if (num_data_ != last - first) {
+    Log::Fatal("Length of labels differs from #data");
   }
-  if (num_data_ != len) {
-    Log::Fatal("Length of label is not same with #data");
+  if (label_.empty()) {
+    label_.resize(num_data_);
   }
-  if (label_.empty()) { label_.resize(num_data_); }
 
 #pragma omp parallel for schedule(static, 512) if (num_data_ >= 1024)
   for (data_size_t i = 0; i < num_data_; ++i) {
-    label_[i] = Common::AvoidInf(label[i]);
+    label_[i] = Common::AvoidInf(first[i]);
   }
+
 #ifdef USE_CUDA
   if (cuda_metadata_ != nullptr) {
-    cuda_metadata_->SetLabel(label_.data(), len);
+    cuda_metadata_->SetLabel(label_.data(), label_.size());
   }
 #endif  // USE_CUDA
 }
 
+void Metadata::SetLabel(const label_t* label, data_size_t len) {
+  if (label == nullptr) {
+    Log::Fatal("label cannot be nullptr");
+  }
+  SetLabelsFromIterator(label, label + len);
+}
+
+void Metadata::SetLabel(const ArrowChunkedArray& array) {
+  SetLabelsFromIterator(array.begin<label_t>(), array.end<label_t>());
+}
+
 void Metadata::InsertLabels(const label_t* labels, data_size_t start_index, data_size_t len) {
   if (labels == nullptr) {
     Log::Fatal("label cannot be nullptr");
@@ -408,33 +432,45 @@ void Metadata::InsertLabels(const label_t* labels, data_size_t start_index, data
   // CUDA is handled after all insertions are complete
 }
 
-void Metadata::SetWeights(const label_t* weights, data_size_t len) {
+template <typename It>
+void Metadata::SetWeightsFromIterator(It first, It last) {
   std::lock_guard<std::mutex> lock(mutex_);
-  // save to nullptr
-  if (weights == nullptr || len == 0) {
+  // Clear weights on empty input
+  if (last - first == 0) {
     weights_.clear();
     num_weights_ = 0;
     return;
   }
-  if (num_data_ != len) {
-    Log::Fatal("Length of weights is not same with #data");
+  if (num_data_ != last - first) {
+    Log::Fatal("Length of weights differs from #data");
+  }
+  if (weights_.empty()) {
+    weights_.resize(num_data_);
   }
-  if (weights_.empty()) { weights_.resize(num_data_); }
   num_weights_ = num_data_;
 
 #pragma omp parallel for schedule(static, 512) if (num_weights_ >= 1024)
   for (data_size_t i = 0; i < num_weights_; ++i) {
-    weights_[i] = Common::AvoidInf(weights[i]);
+    weights_[i] = Common::AvoidInf(first[i]);
   }
   CalculateQueryWeights();
   weight_load_from_file_ = false;
+
 #ifdef USE_CUDA
   if (cuda_metadata_ != nullptr) {
-    cuda_metadata_->SetWeights(weights_.data(), len);
+    cuda_metadata_->SetWeights(weights_.data(), weights_.size());
   }
 #endif  // USE_CUDA
 }
 
+void Metadata::SetWeights(const label_t* weights, data_size_t len) {
+  SetWeightsFromIterator(weights, weights + len);
+}
+
+void Metadata::SetWeights(const ArrowChunkedArray& array) {
+  SetWeightsFromIterator(array.begin<label_t>(), array.end<label_t>());
+}
+
 void Metadata::InsertWeights(const label_t* weights, data_size_t start_index, data_size_t len) {
   if (!weights) {
     Log::Fatal("Passed null weights");
@@ -453,30 +489,34 @@ void Metadata::InsertWeights(const label_t* weights, data_size_t start_index, da
   // CUDA is handled after all insertions are complete
 }
 
-void Metadata::SetQuery(const data_size_t* query, data_size_t len) {
+template <typename It>
+void Metadata::SetQueriesFromIterator(It first, It last) {
   std::lock_guard<std::mutex> lock(mutex_);
-  // save to nullptr
-  if (query == nullptr || len == 0) {
+  // Clear queries on empty input
+  if (last - first == 0) {
     query_boundaries_.clear();
     num_queries_ = 0;
     return;
   }
+
   data_size_t sum = 0;
 #pragma omp parallel for schedule(static) reduction(+:sum)
-  for (data_size_t i = 0; i < len; ++i) {
-    sum += query[i];
+  for (data_size_t i = 0; i < last - first; ++i) {
+    sum += first[i];
   }
   if (num_data_ != sum) {
-    Log::Fatal("Sum of query counts is not same with #data");
+    Log::Fatal("Sum of query counts differs from #data");
   }
-  num_queries_ = len;
+  num_queries_ = last - first;
+
   query_boundaries_.resize(num_queries_ + 1);
   query_boundaries_[0] = 0;
   for (data_size_t i = 0; i < num_queries_; ++i) {
-    query_boundaries_[i + 1] = query_boundaries_[i] + query[i];
+    query_boundaries_[i + 1] = query_boundaries_[i] + first[i];
   }
   CalculateQueryWeights();
   query_load_from_file_ = false;
+
 #ifdef USE_CUDA
   if (cuda_metadata_ != nullptr) {
     if (query_weights_.size() > 0) {
@@ -489,6 +529,14 @@ void Metadata::SetQuery(const data_size_t* query, data_size_t len) {
 #endif  // USE_CUDA
 }
 
+void Metadata::SetQuery(const data_size_t* query, data_size_t len) {
+  SetQueriesFromIterator(query, query + len);
+}
+
+void Metadata::SetQuery(const ArrowChunkedArray& array) {
+  SetQueriesFromIterator(array.begin<data_size_t>(), array.end<data_size_t>());
+}
+
 void Metadata::InsertQueries(const data_size_t* queries, data_size_t start_index, data_size_t len) {
   if (!queries) {
     Log::Fatal("Passed null queries");
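For clarity, the query-count handling above is a prefix sum; a worked example:

    // Query counts {10, 20, 40, 10, 10, 10} yield boundaries {0, 10, 30, 70, 80, 90, 100},
    // i.e. query i spans rows [boundaries[i], boundaries[i + 1]).
    std::vector<data_size_t> counts = {10, 20, 40, 10, 10, 10};
    std::vector<data_size_t> bounds(counts.size() + 1, 0);
    for (size_t i = 0; i < counts.size(); ++i) {
      bounds[i + 1] = bounds[i] + counts[i];
    }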
diff --git a/tests/cpp_tests/test_arrow.cpp b/tests/cpp_tests/test_arrow.cpp
new file mode 100644
index 000000000000..52fe695b0824
--- /dev/null
+++ b/tests/cpp_tests/test_arrow.cpp
@@ -0,0 +1,208 @@
+/*!
+ * Copyright (c) 2023 Microsoft Corporation. All rights reserved.
+ * Licensed under the MIT License. See LICENSE file in the project root for license information.
+ *
+ * Author: Oliver Borchert
+ */
+
+#include <gtest/gtest.h>
+#include <cmath>
+#include "../include/LightGBM/arrow.h"
+
+using LightGBM::ArrowChunkedArray;
+using LightGBM::ArrowTable;
+
+class ArrowChunkedArrayTest : public testing::Test {
+ protected:
+  void SetUp() override {}
+
+  ArrowArray created_nested_array(const std::vector<ArrowArray*>& arrays) {
+    ArrowArray arr;
+    arr.buffers = nullptr;
+    arr.children = (ArrowArray**)arrays.data();
+    arr.dictionary = nullptr;
+    arr.length = arrays[0]->length;
+    arr.n_buffers = 0;
+    arr.n_children = arrays.size();
+    arr.null_count = 0;
+    arr.offset = 0;
+    arr.private_data = nullptr;
+    arr.release = nullptr;
+    return arr;
+  }
+
+  template <typename T>
+  ArrowArray create_primitive_array(const std::vector<T>& values,
+                                    int64_t offset = 0,
+                                    std::vector<size_t> null_indices = {}) {
+    // NOTE: Arrow arrays have 64-bit alignment but we can safely ignore this in tests
+    // 1) Create validity bitmap: a set bit marks a valid entry, a cleared bit a null entry
+    char* validity = nullptr;
+    if (!null_indices.empty()) {
+      validity = static_cast<char*>(calloc(values.size() + sizeof(char) - 1, sizeof(char)));
+      for (size_t i = 0; i < values.size(); ++i) {
+        if (std::find(null_indices.begin(), null_indices.end(), i) == null_indices.end()) {
+          validity[i / 8] |= (1 << (i % 8));
+        }
+      }
+    }
+
+    // 2) Create buffers
+    const void** buffers = (const void**)malloc(sizeof(void*) * 2);
+    buffers[0] = validity;
+    buffers[1] = values.data() + offset;
+
+    // Create arrow array
+    ArrowArray arr;
+    arr.buffers = buffers;
+    arr.children = nullptr;
+    arr.dictionary = nullptr;
+    arr.length = values.size() - offset;
+    arr.n_buffers = 2;
+    arr.n_children = 0;
+    arr.null_count = 0;
+    arr.offset = 0;
+    arr.private_data = nullptr;
+    arr.release = [](ArrowArray* arr) {
+      if (arr->buffers[0] != nullptr)
+        free((void*)(arr->buffers[0]));
+      free((void*)(arr->buffers));
+    };
+    return arr;
+  }
+
+  ArrowSchema create_nested_schema(const std::vector<ArrowSchema*>& arrays) {
+    ArrowSchema schema;
+    schema.format = "+s";
+    schema.name = nullptr;
+    schema.metadata = nullptr;
+    schema.flags = 0;
+    schema.n_children = arrays.size();
+    schema.children = (ArrowSchema**)arrays.data();
+    schema.dictionary = nullptr;
+    schema.private_data = nullptr;
+    schema.release = nullptr;
+    return schema;
+  }
+
+  template <typename T>
+  ArrowSchema create_primitive_schema() {
+    throw std::logic_error("not implemented");
+  }
+
+  template <>
+  ArrowSchema create_primitive_schema<float>() {
+    ArrowSchema schema;
+    schema.format = "f";
+    schema.name = nullptr;
+    schema.metadata = nullptr;
+    schema.flags = 0;
+    schema.n_children = 0;
+    schema.children = nullptr;
+    schema.dictionary = nullptr;
+    schema.private_data = nullptr;
+    schema.release = nullptr;
+    return schema;
+  }
+};
+
+TEST_F(ArrowChunkedArrayTest, GetLength) {
+  std::vector<float> dat1 = {1, 2};
+  auto arr1 = create_primitive_array(dat1);
+
+  ArrowChunkedArray ca1(1, &arr1, nullptr);
+  ASSERT_EQ(ca1.get_length(), 2);
+
+  std::vector<float> dat2 = {3, 4, 5, 6};
+  auto arr2 = create_primitive_array(dat2);
+  ArrowArray arrs[2] = {arr1, arr2};
+  ArrowChunkedArray ca2(2, arrs, nullptr);
+  ASSERT_EQ(ca2.get_length(), 6);
+
+  arr1.release(&arr1);
+  arr2.release(&arr2);
+}
+
+TEST_F(ArrowChunkedArrayTest, GetColumns) {
+  std::vector<float> dat1 = {1, 2, 3};
+  auto arr1 = create_primitive_array(dat1);
+  std::vector<float> dat2 = {4, 5, 6};
+  auto arr2 = create_primitive_array(dat2);
+  std::vector<ArrowArray*> arrs = {&arr1, &arr2};
+  auto arr = created_nested_array(arrs);
+
+  auto schema1 = create_primitive_schema<float>();
+  auto schema2 = create_primitive_schema<float>();
+  std::vector<ArrowSchema*> schemas = {&schema1, &schema2};
+  auto schema = create_nested_schema(schemas);
+
+  ArrowTable table(1, &arr, &schema);
+  ASSERT_EQ(table.get_num_rows(), 3);
+  ASSERT_EQ(table.get_num_columns(), 2);
+
+  auto ca1 = table.get_column(0);
+  ASSERT_EQ(ca1.get_length(), 3);
+  ASSERT_EQ(*ca1.begin<float>(), 1);
+
+  auto ca2 = table.get_column(1);
+  ASSERT_EQ(ca2.get_length(), 3);
+  ASSERT_EQ(*ca2.begin<float>(), 4);
+
+  arr1.release(&arr1);
+  arr2.release(&arr2);
+}
+
+TEST_F(ArrowChunkedArrayTest, IteratorArithmetic) {
+  std::vector<float> dat1 = {1, 2};
+  auto arr1 = create_primitive_array(dat1);
+  std::vector<float> dat2 = {3, 4, 5, 6};
+  auto arr2 = create_primitive_array(dat2);
+  std::vector<float> dat3 = {7};
+  auto arr3 = create_primitive_array(dat3);
+  auto schema = create_primitive_schema<float>();
+
+  ArrowArray arrs[3] = {arr1, arr2, arr3};
+  ArrowChunkedArray ca(3, arrs, &schema);
+
+  // Arithmetic
+  auto it = ca.begin<float>();
+  ASSERT_EQ(*it, 1);
+  ++it;
+  ASSERT_EQ(*it, 2);
+  ++it;
+  ASSERT_EQ(*it, 3);
+  it += 2;
+  ASSERT_EQ(*it, 5);
+  it += 2;
+  ASSERT_EQ(*it, 7);
+  --it;
+  ASSERT_EQ(*it, 6);
+
+  // Subscripts
+  ASSERT_EQ(it[0], 1);
+  ASSERT_EQ(it[1], 2);
+  ASSERT_EQ(it[2], 3);
+  ASSERT_EQ(it[6], 7);
+
+  // End
+  auto end = ca.end<float>();
+  ASSERT_EQ(end - it, 2);
+  ASSERT_EQ(end - ca.begin<float>(), 7);
+
+  arr1.release(&arr1);
+  arr2.release(&arr2);
+  arr3.release(&arr3);
+}
+
+TEST_F(ArrowChunkedArrayTest, OffsetAndValidity) {
+  std::vector<float> dat = {0, 1, 2, 3, 4, 5, 6};
+  auto arr = create_primitive_array(dat, 2, {0, 1});
+  auto schema = create_primitive_schema<float>();
+  ArrowChunkedArray ca(1, &arr, &schema);
+
+  auto it = ca.begin<float>();
+  ASSERT_TRUE(std::isnan(*it));
+  ASSERT_TRUE(std::isnan(*(++it)));
+  ASSERT_EQ(it[2], 4);
+  ASSERT_EQ(it[4], 6);
+
+  arr.release(&arr);
+}
diff --git a/tests/python_package_test/test_arrow.py b/tests/python_package_test/test_arrow.py
new file mode 100644
index 000000000000..723a713ea56d
--- /dev/null
+++ b/tests/python_package_test/test_arrow.py
@@ -0,0 +1,187 @@
+# coding: utf-8
+import filecmp
+import tempfile
+from pathlib import Path
+from typing import Any, Dict
+
+import numpy as np
+import pyarrow as pa
+import pytest
+
+import lightgbm as lgb
+
+# ----------------------------------------------------------------------------------------------- #
+#                                            UTILITIES                                             #
+# ----------------------------------------------------------------------------------------------- #
+
+
+def generate_simple_arrow_table() -> pa.Table:
+    columns = [
+        pa.chunked_array([[1, 2, 3, 4, 5]], type=pa.uint8()),
+        pa.chunked_array([[1, 2, 3, 4, 5]], type=pa.int8()),
+        pa.chunked_array([[1, 2, 3, 4, 5]], type=pa.uint16()),
+        pa.chunked_array([[1, 2, 3, 4, 5]], type=pa.int16()),
+        pa.chunked_array([[1, 2, 3, 4, 5]], type=pa.uint32()),
+        pa.chunked_array([[1, 2, 3, 4, 5]], type=pa.int32()),
+        pa.chunked_array([[1, 2, 3, 4, 5]], type=pa.uint64()),
+        pa.chunked_array([[1, 2, 3, 4, 5]], type=pa.int64()),
+        pa.chunked_array([[1, 2, 3, 4, 5]], type=pa.float32()),
+        pa.chunked_array([[1, 2, 3, 4, 5]], type=pa.float64()),
+    ]
+    return pa.Table.from_arrays(columns, names=[str(i) for i in range(len(columns))])
+
+
+def generate_dummy_arrow_table() -> pa.Table:
+    col1 = pa.chunked_array([[1, 2, 3], [4, 5]], type=pa.uint8())
+    col2 = pa.chunked_array([[0.5, 0.6], [0.1, 0.8, 1.5]], type=pa.float32())
+    return pa.Table.from_arrays([col1, col2], names=["a", "b"])
+
+
+def generate_random_arrow_table(num_columns: int, num_datapoints: int, seed: int) -> pa.Table:
+    columns = [generate_random_arrow_array(num_datapoints, seed + i) for i in range(num_columns)]
+    names = [str(i) for i in range(num_columns)]
+    return pa.Table.from_arrays(columns, names=names)
+
+
+def generate_random_arrow_array(num_datapoints: int, seed: int) -> pa.ChunkedArray:
+    generator = np.random.default_rng(seed)
+    data = generator.standard_normal(num_datapoints)
+
+    # Set random missing values (NumPy coerces `None` to NaN in float arrays)
+    indices = generator.choice(len(data), size=num_datapoints // 10)
+    data[indices] = None
+
+    # Split data into random chunks
+    n_chunks = generator.integers(1, num_datapoints // 3)
+    split_points = np.sort(generator.choice(np.arange(1, num_datapoints), n_chunks, replace=False))
+    split_points = np.concatenate([[0], split_points, [num_datapoints]])
+    chunks = [data[split_points[i] : split_points[i + 1]] for i in range(len(split_points) - 1)]
+    chunks = [chunk for chunk in chunks if len(chunk) > 0]
+
+    # Turn chunks into a single chunked array
+    return pa.chunked_array(chunks, type=pa.float32())
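+
+
+# NOTE: the NaN entries injected above are missing *values*, not Arrow nulls: the resulting
+# chunked arrays report `null_count == 0`. LightGBM treats NaN as a missing value by default,
+# so the datasets built from these tables still exercise missing-value handling end to end.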
+
+
+def dummy_dataset_params() -> Dict[str, Any]:
+    return {
+        "min_data_in_bin": 1,
+        "min_data_in_leaf": 1,
+    }
+
+
+def arrays_equal(lhs: np.ndarray, rhs: np.ndarray) -> bool:
+    return lhs.dtype == rhs.dtype and np.array_equal(lhs, rhs)
+
+
+# ----------------------------------------------------------------------------------------------- #
+#                                            UNIT TESTS                                            #
+# ----------------------------------------------------------------------------------------------- #
+
+# ------------------------------------------- DATASET ------------------------------------------- #
+
+
+def test_dataset_construct_smoke():
+    data = generate_random_arrow_table(10, 10000, 42)
+    label = generate_random_arrow_array(10000, 43)
+    weight = generate_random_arrow_array(10000, 44)
+    init_scores = generate_random_arrow_array(10000, 45)
+
+    dataset = lgb.Dataset(data, label=label, weight=weight, init_score=init_scores)
+    dataset.construct()
+
+
+@pytest.mark.parametrize(
+    ("arrow_table", "dataset_params"),
+    [
+        (generate_simple_arrow_table(), dummy_dataset_params()),
+        (generate_dummy_arrow_table(), dummy_dataset_params()),
+        (generate_random_arrow_table(3, 1000, 42), {}),
+        (generate_random_arrow_table(100, 10000, 43), {}),
+    ],
+)
+def test_dataset_construct_fuzzy(arrow_table: pa.Table, dataset_params: Dict[str, Any]):
+    arrow_dataset = lgb.Dataset(arrow_table, params=dataset_params)
+    arrow_dataset.construct()
+
+    pandas_dataset = lgb.Dataset(arrow_table.to_pandas(), params=dataset_params)
+    pandas_dataset.construct()
+
+    with tempfile.TemporaryDirectory() as t:
+        tmpdir = Path(t)
+        arrow_dataset._dump_text(tmpdir / "arrow.txt")
+        pandas_dataset._dump_text(tmpdir / "pandas.txt")
+        assert filecmp.cmp(tmpdir / "arrow.txt", tmpdir / "pandas.txt")
+
+
+def test_dataset_construct_labels():
+    data = generate_dummy_arrow_table()
+    labels = pa.chunked_array([[0], [1, 0, 0, 1]], type=pa.uint8())
+    dataset = lgb.Dataset(data, label=labels, params=dummy_dataset_params())
+    dataset.construct()
+
+    expected = np.array([0, 1, 0, 0, 1], dtype=np.float32)
+    assert arrays_equal(expected, dataset.get_label())
+
+
+def test_dataset_construct_weights():
+    data = generate_dummy_arrow_table()
+    weights = pa.chunked_array([[0.3], [0.6, 1.0, 4.0], [2.5]], type=pa.float32())
+    dataset = lgb.Dataset(data, weight=weights, params=dummy_dataset_params())
+    dataset.construct()
+
+    expected = np.array([0.3, 0.6, 1.0, 4.0, 2.5], dtype=np.float32)
+    assert arrays_equal(expected, dataset.get_weight())
+
+
+def test_dataset_construct_groups():
+    data = generate_dummy_arrow_table()
+    groups = pa.chunked_array([[2], [1, 2]], type=pa.uint8())
+    dataset = lgb.Dataset(data, group=groups, params=dummy_dataset_params())
+    dataset.construct()
+
+    expected = np.array([0, 2, 3, 5], dtype=np.int32)
+    assert arrays_equal(expected, dataset.get_field("group"))
+
+
+def test_dataset_construct_init_scores_1d():
+    data = generate_dummy_arrow_table()
+    init_scores = pa.chunked_array([[1.0, 2.0], [1.0, 1.0, 1.0]], type=pa.float32())
+    dataset = lgb.Dataset(data, init_score=init_scores, params=dummy_dataset_params())
+    dataset.construct()
+
+    expected = np.array([1.0, 2.0, 1.0, 1.0, 1.0], dtype=np.float64)
+    assert arrays_equal(expected, dataset.get_init_score())
+
+
+def test_dataset_construct_init_scores_2d():
+    data = generate_dummy_arrow_table()
+    init_scores = pa.Table.from_arrays(
+        [
+            pa.chunked_array([[1.0, 2.0], [1.0, 1.0, 1.0]], type=pa.float32()),
+            pa.array([3.5, 3.5, 3.5, 3.5, 3.5], type=pa.float32()),
+        ],
+        names=["a", "b"],
+    )
+    dataset = lgb.Dataset(data, init_score=init_scores, params=dummy_dataset_params())
+    dataset.construct()
+
+    expected = np.array(
+        [[1.0, 3.5], [2.0, 3.5], [1.0, 3.5], [1.0, 3.5], [1.0, 3.5]], dtype=np.float64
+    )
+    assert arrays_equal(expected, dataset.get_init_score())
+
+
+# ------------------------------------------ PREDICTION ----------------------------------------- #
+
+
+def test_predict():
+    data = generate_random_arrow_table(10, 10000, 42)
+    labels = generate_random_arrow_array(10000, 43)
+    dataset = lgb.Dataset(data, label=labels, params=dummy_dataset_params())
+    booster = lgb.train({}, dataset, num_boost_round=1)
+
+    out_arrow = booster.predict(data)
+    out_pandas = booster.predict(data.to_pandas())
+    assert arrays_equal(out_arrow, out_pandas)
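
Taken together, these tests pin down the user-facing contract of this PR. The following is a minimal usage sketch, assuming only the behaviour exercised above (`lgb.Dataset` accepting `pa.Table`/`pa.ChunkedArray` inputs and `Booster.predict` accepting a `pa.Table`); parameter choices are illustrative:

```python
import numpy as np
import pyarrow as pa

import lightgbm as lgb

# Build a feature table and a label array without ever materializing pandas objects.
rng = np.random.default_rng(0)
table = pa.Table.from_arrays(
    [pa.array(rng.standard_normal(1000)) for _ in range(3)],
    names=["f0", "f1", "f2"],
)
label = pa.chunked_array([rng.integers(0, 2, 1000).astype(np.float64)])

# Arrow data flows into the Dataset (and into predict) via the Arrow C data interface,
# so no copy through pandas/numpy is required on the way in.
dataset = lgb.Dataset(table, label=label)
booster = lgb.train({"objective": "binary", "verbosity": -1}, dataset, num_boost_round=5)

predictions = booster.predict(table)  # returns a numpy array of probabilities
assert predictions.shape == (1000,)
```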