Commit

fix SetPosition
add tests for get_position and set_position
shiyu1994 committed Aug 4, 2023
1 parent 757f7cb commit 70fc191
Showing 6 changed files with 1,631 additions and 14 deletions.
8 changes: 4 additions & 4 deletions python-package/lightgbm/basic.py
@@ -1507,8 +1507,6 @@ def __init__(
             sum(group) = n_samples.
             For example, if you have a 100-document dataset with ``group = [10, 20, 40, 10, 10, 10]``, that means that you have 6 groups,
             where the first 10 records are in the first group, records 11-30 are in the second group, records 31-70 are in the third group, etc.
-        position : numpy 1-D array, pandas Series or None, optional (default=None)
-            Position of items used in unbiased learning-to-rank task.
         init_score : list, list of lists (for multi-class task), numpy array, pandas Series, pandas DataFrame (for multi-class task), or None, optional (default=None)
             Init score for Dataset.
         feature_name : list of str, or 'auto', optional (default="auto")
@@ -1528,6 +1526,8 @@ def __init__(
             Other parameters for Dataset.
         free_raw_data : bool, optional (default=True)
             If True, raw data is freed after constructing inner Dataset.
+        position : numpy 1-D array, pandas Series or None, optional (default=None)
+            Position of items used in unbiased learning-to-rank task.
         """
         self._handle: Optional[_DatasetHandle] = None
         self.data = data
@@ -2263,12 +2263,12 @@ def create_valid(
             sum(group) = n_samples.
             For example, if you have a 100-document dataset with ``group = [10, 20, 40, 10, 10, 10]``, that means that you have 6 groups,
             where the first 10 records are in the first group, records 11-30 are in the second group, records 31-70 are in the third group, etc.
-        position : numpy 1-D array, pandas Series or None, optional (default=None)
-            Position of items used in unbiased learning-to-rank task.
         init_score : list, list of lists (for multi-class task), numpy array, pandas Series, pandas DataFrame (for multi-class task), or None, optional (default=None)
             Init score for Dataset.
         params : dict or None, optional (default=None)
             Other parameters for validation Dataset.
+        position : numpy 1-D array, pandas Series or None, optional (default=None)
+            Position of items used in unbiased learning-to-rank task.

         Returns
         -------
22 changes: 12 additions & 10 deletions src/io/metadata.cpp
@@ -540,19 +540,21 @@ void Metadata::SetPosition(const data_size_t* positions, data_size_t len) {

   position_load_from_file_ = false;

-  #pragma omp parallel for schedule(static, 512) if (num_positions_ >= 1024)
+  position_ids_.clear();
+  std::unordered_map<data_size_t, int> map_id2pos;
   for (data_size_t i = 0; i < num_positions_; ++i) {
-    positions_[i] = positions[i];
+    if (map_id2pos.count(positions[i]) == 0) {
+      int pos = static_cast<int>(map_id2pos.size());
+      map_id2pos[positions[i]] = pos;
+      position_ids_.push_back(std::to_string(positions[i]));
+    }
   }

-  position_ids_.clear();
-  std::set<int> position_set;
-  for (int position : positions_) {
-    position_set.insert(position);
-  }
-  Log::Debug("number of unique positions found = %ld", position_set.size());
-  for (int position : position_set) {
-    position_ids_.push_back(std::to_string(position));
+  Log::Debug("number of unique positions found = %ld", position_ids_.size());
+
+  #pragma omp parallel for schedule(static, 512) if (num_positions_ >= 1024)
+  for (data_size_t i = 0; i < num_positions_; ++i) {
+    positions_[i] = map_id2pos.at(positions[i]);
   }
 }
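
Reading note on the hunk above: the patched SetPosition appears to store, for each row, an index into position_ids_ rather than the raw position value, with indices assigned in order of first appearance. A minimal Python sketch of that remapping follows. The helper name remap_positions and the sample input are hypothetical and are not part of LightGBM; the sketch only mirrors the loop structure visible in the diff.

```python
from typing import Dict, List, Sequence, Tuple


def remap_positions(raw_positions: Sequence[int]) -> Tuple[List[int], List[str]]:
    """Hypothetical helper mirroring the patched loops; not a LightGBM API."""
    id_to_index: Dict[int, int] = {}  # plays the role of map_id2pos
    position_ids: List[str] = []      # plays the role of position_ids_

    # First pass: give each distinct raw value the next free index,
    # in order of first appearance, and record its string form.
    for raw in raw_positions:
        if raw not in id_to_index:
            id_to_index[raw] = len(id_to_index)
            position_ids.append(str(raw))

    # Second pass: replace every raw value with its dense index
    # (the C++ code runs this loop under OpenMP).
    indices = [id_to_index[raw] for raw in raw_positions]
    return indices, position_ids


indices, position_ids = remap_positions([3, 1, 3, 7, 1])
print(indices)       # [0, 1, 0, 2, 1]
print(position_ids)  # ['3', '1', '7']
```

In the removed code, by contrast, positions_ kept the raw values while position_ids_ was built from a sorted std::set, so an entry of positions_ was not generally a valid index into position_ids_.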
