diff --git a/benches/perpetual_benchmarks.rs b/benches/perpetual_benchmarks.rs index cf24016..fc077c0 100644 --- a/benches/perpetual_benchmarks.rs +++ b/benches/perpetual_benchmarks.rs @@ -21,7 +21,7 @@ pub fn tree_benchmarks(c: &mut Criterion) { fs::read_to_string("resources/performance_100k_samp_seed0.csv").expect("Something went wrong reading the file"); let y: Vec = file.lines().map(|x| x.parse::().unwrap()).collect(); let yhat = vec![0.5; y.len()]; - let (g, h) = LogLoss::calc_grad_hess(&y, &yhat, None, None); + let (mut g, mut h) = LogLoss::calc_grad_hess(&y, &yhat, None, None); let loss = LogLoss::calc_loss(&y, &yhat, None, None); let v: Vec = vec![10.; 300000]; @@ -56,8 +56,8 @@ pub fn tree_benchmarks(c: &mut Criterion) { data.index.to_owned(), &col_index, &bindata.cuts, - &g, - h.as_deref(), + &mut g, + h.as_deref_mut(), &splitter, true, Some(f32::MAX), @@ -82,8 +82,8 @@ pub fn tree_benchmarks(c: &mut Criterion) { black_box(data.index.to_owned()), black_box(&col_index), black_box(&bindata.cuts), - black_box(&g), - black_box(h.as_deref()), + black_box(&mut g), + black_box(h.as_deref_mut()), black_box(&splitter), black_box(false), Some(f32::MAX), @@ -108,8 +108,8 @@ pub fn tree_benchmarks(c: &mut Criterion) { black_box(data.index.to_owned()), black_box(&[1, 3, 4]), black_box(&bindata.cuts), - black_box(&g), - black_box(h.as_deref()), + black_box(&mut g), + black_box(h.as_deref_mut()), black_box(&splitter), black_box(false), Some(f32::MAX), diff --git a/python-package/examples/performance_benchmark.ipynb b/python-package/examples/performance_benchmark.ipynb index 233a4b8..70a359e 100644 --- a/python-package/examples/performance_benchmark.ipynb +++ b/python-package/examples/performance_benchmark.ipynb @@ -174,7 +174,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "[I 2024-07-10 12:37:35,963] A new study created in memory with name: no-name-4c6a881c-9102-44d9-b018-3af8d37cb2ae\n" + "[I 2024-07-11 01:26:49,412] A new study created in memory with name: no-name-4a28ade5-6c54-459c-aaa7-2f3e0bd9c040\n" ] } ], @@ -192,15 +192,15 @@ "name": "stderr", "output_type": "stream", "text": [ - "[I 2024-07-10 12:37:36,403] Trial 0 finished with value: 1.0644386016870766 and parameters: {'learning_rate': 0.4073657656436648, 'min_split_gain': 0.0019204079494910193, 'reg_alpha': 0.685655809011563, 'reg_lambda': 0.019448941142879615, 'colsample_bytree': 0.7581830596778167, 'subsample': 0.3728715964643011, 'subsample_freq': 10, 'max_depth': 3, 'num_leaves': 260, 'min_child_samples': 44}. Best is trial 0 with value: 1.0644386016870766.\n" + "[I 2024-07-11 01:26:50,073] Trial 0 finished with value: 1.0644386016870766 and parameters: {'learning_rate': 0.4073657656436648, 'min_split_gain': 0.0019204079494910193, 'reg_alpha': 0.685655809011563, 'reg_lambda': 0.019448941142879615, 'colsample_bytree': 0.7581830596778167, 'subsample': 0.3728715964643011, 'subsample_freq': 10, 'max_depth': 3, 'num_leaves': 260, 'min_child_samples': 44}. 
Best is trial 0 with value: 1.0644386016870766.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "CPU times: total: 422 ms\n", - "Wall time: 404 ms\n" + "CPU times: total: 469 ms\n", + "Wall time: 617 ms\n" ] } ], @@ -410,14 +410,14 @@ "name": "stdout", "output_type": "stream", "text": [ - "CPU times: total: 21.6 s\n", - "Wall time: 20.7 s\n" + "CPU times: total: 8.58 s\n", + "Wall time: 8.52 s\n" ] }, { "data": { "text/plain": [ - "" + "" ] }, "execution_count": 20, @@ -427,7 +427,7 @@ ], "source": [ "%%time\n", - "model.fit(X_train, y_train, budget=1.5)" + "model.fit(X_train, y_train, budget=1.0)" ] }, { @@ -439,7 +439,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Test mse: 0.192352\n" + "Test mse: 0.198443\n" ] } ], @@ -459,7 +459,7 @@ { "data": { "text/plain": [ - "244" + "106" ] }, "execution_count": 22, @@ -483,20 +483,20 @@ "}\n", "\n", "\n", - "| Perpetual budget | Seed | Perpetual mse | Perpetual cpu time |\n", - "| ---------------- | ---- | ------------- | ------------------ |\n", - "| 1.0 | 0 | 0.187273 | 9.23 |\n", - "| 1.0 | 1 | 0.189911 | 10.5 |\n", - "| 1.0 | 2 | 0.194937 | 11.0 |\n", - "| 1.0 | 3 | 0.182932 | 9.77 |\n", - "| 1.0 | 4 | 0.198443 | 9.88 |\n", - "| 1.0 | avg | 0.190699 | 10.1 |\n", - "| 1.5 | 0 | 0.185843 | 28.6 |\n", - "| 1.5 | 1 | 0.188146 | 26.8 |\n", - "| 1.5 | 2 | 0.190484 | 26.6 |\n", - "| 1.5 | 3 | 0.178708 | 25.1 |\n", - "| 1.5 | 4 | 0.192352 | 21.6 |\n", - "| 1.5 | avg | 0.187107 | 25.7 |" + "| Perpetual budget | Seed | Perpetual mse | Perpetual cpu time | cpu time improved |\n", + "| ---------------- | ---- | ------------- | ------------------ | ----------------- |\n", + "| 1.0 | 0 | 0.187273 | 9.23 | 9.28 |\n", + "| 1.0 | 1 | 0.189911 | 10.5 | 9.69 |\n", + "| 1.0 | 2 | 0.194937 | 11.0 | 11.0 |\n", + "| 1.0 | 3 | 0.182932 | 9.77 | 10.5 |\n", + "| 1.0 | 4 | 0.198443 | 9.88 | 8.58 |\n", + "| 1.0 | avg | 0.190699 | 10.1 | 9.81 |\n", + "| 1.5 | 0 | 0.185843 | 28.6 | 27.2 |\n", + "| 1.5 | 1 | 0.188146 | 26.8 | 25.5 |\n", + "| 1.5 | 2 | 0.190484 | 26.6 | 25.2 |\n", + "| 1.5 | 3 | 0.178708 | 25.1 | 23.1 |\n", + "| 1.5 | 4 | 0.192352 | 21.6 | 20.8 |\n", + "| 1.5 | avg | 0.187107 | 25.7 | 24.4 |" ] }, { diff --git a/src/booster.rs b/src/booster.rs index cd0defa..50d673e 100644 --- a/src/booster.rs +++ b/src/booster.rs @@ -324,7 +324,6 @@ impl PerpetualBooster { self.base_score = calc_init_callables(&self.objective)(y, sample_weight, alpha); yhat = vec![self.base_score; y.len()]; } else { - // self.lumber(data, y, sample_weight, alpha)?; yhat = self.predict(data, self.parallel, None); } @@ -385,8 +384,8 @@ impl PerpetualBooster { data.index.to_owned(), &col_index, &binned_data.cuts, - &grad, - hess.as_deref(), + &mut grad, + hess.as_deref_mut(), splitter, self.parallel, tld, diff --git a/src/histogram.rs b/src/histogram.rs index 6adb4a4..8c827c4 100644 --- a/src/histogram.rs +++ b/src/histogram.rs @@ -223,6 +223,8 @@ impl HistogramMatrix { #[allow(clippy::too_many_arguments)] pub fn update( &mut self, + start: usize, + stop: usize, data: &Matrix, cuts: &JaggedMatrix, grad: &[f32], @@ -242,24 +244,19 @@ impl HistogramMatrix { let (sorted_grad, sorted_hess) = match hess { Some(hess) => { if !sort { - (grad.to_vec(), Some(hess.to_vec())) + (grad, Some(hess)) } else { - let mut n_grad = Vec::with_capacity(index.len()); - let mut n_hess = Vec::with_capacity(index.len()); - for i in index { - let i_ = *i; - n_grad.push(grad[i_]); - n_hess.push(hess[i_]); - } - (n_grad, Some(n_hess)) + let g = &grad[start..stop]; + let h = 
&hess[start..stop]; + (g, Some(h)) } } None => { if !sort { - (grad.to_vec(), None::>) + (grad, None) } else { - let n_grad = index.iter().map(|i| grad[*i]).collect::>(); - (n_grad, None::>) + let g = &grad[start..stop]; + (g, None) } } }; @@ -271,7 +268,7 @@ impl HistogramMatrix { cuts.get_col(*col), &sorted_grad, sorted_hess.as_deref(), - index, + &index[start..stop], ); }); } @@ -454,7 +451,18 @@ mod tests { let mut hist_init = HistogramMatrix::empty(&bdata, &b.cuts, &col_index, false, false); let hist_col = 1; - hist_init.update(&bdata, &b.cuts, &g, h.as_deref(), &bdata.index, &col_index, true, false); + hist_init.update( + 0, + bdata.index.len(), + &bdata, + &b.cuts, + &g, + h.as_deref(), + &bdata.index, + &col_index, + true, + false, + ); let mut f = bdata.get_col(hist_col).to_owned(); diff --git a/src/partial_dependence.rs b/src/partial_dependence.rs index 1272fe6..01f567d 100644 --- a/src/partial_dependence.rs +++ b/src/partial_dependence.rs @@ -94,7 +94,7 @@ mod tests { let file = fs::read_to_string("resources/performance.csv").expect("Something went wrong reading the file"); let y: Vec = file.lines().map(|x| x.parse::().unwrap()).collect(); let yhat = vec![0.5; y.len()]; - let (g, h) = LogLoss::calc_grad_hess(&y, &yhat, None, None); + let (mut g, mut h) = LogLoss::calc_grad_hess(&y, &yhat, None, None); let loss = LogLoss::calc_loss(&y, &yhat, None, None); let data = Matrix::new(&data_vec, 891, 5); @@ -120,8 +120,8 @@ mod tests { data.index.to_owned(), &col_index, &b.cuts, - &g, - h.as_deref(), + &mut g, + h.as_deref_mut(), &splitter, true, Some(f32::MAX), diff --git a/src/splitter.rs b/src/splitter.rs index 61269c2..0638c82 100644 --- a/src/splitter.rs +++ b/src/splitter.rs @@ -6,7 +6,7 @@ use crate::node::{NodeType, SplittableNode}; use crate::tree::Tree; use crate::utils::{ between, bound_to_parent, constrained_weight, constrained_weight_const_hess, cull_gain, gain_given_weight, - gain_given_weight_const_hess, pivot_on_split, pivot_on_split_exclude_missing, + gain_given_weight_const_hess, pivot_on_split, pivot_on_split_const_hess, pivot_on_split_exclude_missing, }; use hashbrown::HashMap; use std::collections::HashSet; @@ -543,8 +543,8 @@ pub trait Splitter { col_index: &[usize], data: &Matrix, cuts: &JaggedMatrix, - grad: &[f32], - hess: Option<&[f32]>, + grad: &mut [f32], + hess: Option<&mut [f32]>, parallel: bool, hist_tree: &mut HashMap, cat_index: Option<&[u64]>, @@ -561,8 +561,8 @@ pub trait Splitter { col_index: &[usize], data: &Matrix, cuts: &JaggedMatrix, - grad: &[f32], - hess: Option<&[f32]>, + grad: &mut [f32], + hess: Option<&mut [f32]>, parallel: bool, is_const_hess: bool, hist_tree: &mut HashMap, @@ -892,8 +892,8 @@ impl Splitter for MissingBranchSplitter { col_index: &[usize], data: &Matrix, cuts: &JaggedMatrix, - grad: &[f32], - hess: Option<&[f32]>, + grad: &mut [f32], + hess: Option<&mut [f32]>, parallel: bool, hist_map: &mut HashMap, cat_index: Option<&[u64]>, @@ -967,11 +967,13 @@ impl Splitter for MissingBranchSplitter { if max_ == 1 { let right_hist = hist_map.get_mut(&right_child).unwrap(); right_hist.update( + split_idx, + node.stop_idx, data, cuts, grad, - hess, - &index[split_idx..node.stop_idx], + hess.as_deref(), + &index, col_index, parallel, true, @@ -980,11 +982,13 @@ impl Splitter for MissingBranchSplitter { } else { let left_hist = hist_map.get_mut(&left_child).unwrap(); left_hist.update( + missing_split_idx, + split_idx, data, cuts, grad, - hess, - &index[missing_split_idx..split_idx], + hess.as_deref(), + &index, col_index, parallel, true, @@ 
-996,22 +1000,26 @@ impl Splitter for MissingBranchSplitter { // levels histograms. let left_hist = hist_map.get_mut(&left_child).unwrap(); left_hist.update( + missing_split_idx, + split_idx, data, cuts, grad, - hess, - &index[missing_split_idx..split_idx], + hess.as_deref(), + &index, col_index, parallel, true, ); let right_hist = hist_map.get_mut(&right_child).unwrap(); right_hist.update( + split_idx, + node.stop_idx, data, cuts, grad, - hess, - &index[split_idx..node.stop_idx], + hess.as_deref(), + &index, col_index, parallel, true, @@ -1020,22 +1028,26 @@ impl Splitter for MissingBranchSplitter { } else if max_ == 1 { let miss_hist = hist_map.get_mut(&missing_child).unwrap(); miss_hist.update( + node.start_idx, + missing_split_idx, data, cuts, grad, - hess, - &index[node.start_idx..missing_split_idx], + hess.as_deref(), + &index, col_index, parallel, true, ); let right_hist = hist_map.get_mut(&right_child).unwrap(); right_hist.update( + split_idx, + node.stop_idx, data, cuts, grad, - hess, - &index[split_idx..node.stop_idx], + hess.as_deref(), + &index, col_index, parallel, true, @@ -1045,22 +1057,26 @@ impl Splitter for MissingBranchSplitter { // right is the largest let miss_hist = hist_map.get_mut(&missing_child).unwrap(); miss_hist.update( + node.start_idx, + missing_split_idx, data, cuts, grad, - hess, - &index[node.start_idx..missing_split_idx], + hess.as_deref(), + &index, col_index, parallel, true, ); let left_hist = hist_map.get_mut(&left_child).unwrap(); left_hist.update( + missing_split_idx, + split_idx, data, cuts, grad, - hess, - &index[missing_split_idx..split_idx], + hess.as_deref(), + &index, col_index, parallel, true, @@ -1378,8 +1394,8 @@ impl Splitter for MissingImputerSplitter { col_index: &[usize], data: &Matrix, cuts: &JaggedMatrix, - grad: &[f32], - hess: Option<&[f32]>, + grad: &mut [f32], + mut hess: Option<&mut [f32]>, parallel: bool, hist_map: &mut HashMap, cat_index: Option<&[u64]>, @@ -1403,13 +1419,31 @@ impl Splitter for MissingImputerSplitter { // separate missing branch. 
         //
         // This function mutates index by swapping indices based on split bin
-        let mut split_idx = pivot_on_split(
-            &mut index[node.start_idx..node.stop_idx],
-            data.get_col(split_info.split_feature),
-            split_info.split_bin,
-            missing_right,
-            left_cat.as_deref(),
-        );
+        let mut split_idx: usize;
+        if hess.is_none() {
+            split_idx = pivot_on_split_const_hess(
+                node.start_idx,
+                node.stop_idx,
+                index,
+                grad,
+                data.get_col(split_info.split_feature),
+                split_info.split_bin,
+                missing_right,
+                left_cat.as_deref(),
+            );
+        } else {
+            split_idx = pivot_on_split(
+                node.start_idx,
+                node.stop_idx,
+                index,
+                grad,
+                &mut hess.as_mut().unwrap(),
+                data.get_col(split_info.split_feature),
+                split_info.split_bin,
+                missing_right,
+                left_cat.as_deref(),
+            );
+        }
         // Calculate histograms
         let total_recs = node.stop_idx - node.start_idx;
@@ -1433,11 +1467,13 @@
         if n_left < n_right {
             let left_hist = hist_map.get_mut(&left_child).unwrap();
             left_hist.update(
+                node.start_idx,
+                split_idx,
                 data,
                 cuts,
                 grad,
-                hess,
-                &index[node.start_idx..split_idx],
+                hess.as_deref(),
+                index,
                 col_index,
                 parallel,
                 true,
@@ -1446,11 +1482,13 @@
         } else {
             let right_hist = hist_map.get_mut(&right_child).unwrap();
             right_hist.update(
+                split_idx,
+                node.stop_idx,
                 data,
                 cuts,
                 grad,
-                hess,
-                &index[split_idx..node.stop_idx],
+                hess.as_deref(),
+                index,
                 col_index,
                 parallel,
                 true,
@@ -1533,6 +1571,7 @@ fn get_categories(
 mod tests {
     use super::*;
     use crate::binning::bin_matrix;
+    use crate::constants::N_NODES_ALLOCATED;
     use crate::data::Matrix;
     use crate::node::SplittableNode;
     use crate::objective::{LogLoss, ObjectiveFunction, SquaredLoss};
@@ -1573,10 +1612,21 @@
         let col_index: Vec<usize> = (0..data.cols).collect();
         let mut hist_init = HistogramMatrix::empty(&bdata, &b.cuts, &col_index, true, false);
-        hist_init.update(&bdata, &b.cuts, &grad, hess.as_deref(), &index, &col_index, true, false);
-        let hist_capacity = 100;
-        let mut hist_map: HashMap<usize, HistogramMatrix> = HashMap::with_capacity(hist_capacity);
-        for i in 0..hist_capacity {
+        hist_init.update(
+            0,
+            index.len(),
+            &bdata,
+            &b.cuts,
+            &grad,
+            hess.as_deref(),
+            &index,
+            &col_index,
+            true,
+            false,
+        );
+
+        let mut hist_map: HashMap<usize, HistogramMatrix> = HashMap::with_capacity(N_NODES_ALLOCATED);
+        for i in 0..N_NODES_ALLOCATED {
hist_map.insert(i, hist_init.clone()); } diff --git a/src/tree.rs b/src/tree.rs index 9ab227c..1cd76cc 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -51,8 +51,8 @@ impl Tree { mut index: Vec, col_index: &[usize], cuts: &JaggedMatrix, - grad: &[f32], - hess: Option<&[f32]>, + grad: &mut [f32], + mut hess: Option<&mut [f32]>, splitter: &T, parallel: bool, target_loss_decrement: Option, @@ -75,7 +75,18 @@ impl Tree { } let root_hist = hist_tree.get_mut(&0).unwrap(); - root_hist.update(data, cuts, grad, hess, &index, col_index, parallel, false); + root_hist.update( + 0, + index.len(), + data, + cuts, + grad, + hess.as_deref(), + &index, + col_index, + parallel, + false, + ); if let Some(c_index) = cat_index { let histograms = unsafe { hist_tree.get_many_unchecked_mut([&0, &1, &2]).unwrap() }; @@ -127,7 +138,7 @@ impl Tree { data, cuts, grad, - hess, + hess.as_deref_mut(), parallel, is_const_hess, &mut hist_tree, @@ -602,7 +613,7 @@ mod tests { let file = fs::read_to_string("resources/performance.csv").expect("Something went wrong reading the file"); let y: Vec = file.lines().map(|x| x.parse::().unwrap()).collect(); let yhat = vec![0.5; y.len()]; - let (g, h) = LogLoss::calc_grad_hess(&y, &yhat, None, None); + let (mut g, mut h) = LogLoss::calc_grad_hess(&y, &yhat, None, None); let loss = LogLoss::calc_loss(&y, &yhat, None, None); let data = Matrix::new(&data_vec, 891, 5); @@ -633,8 +644,8 @@ mod tests { index, &col_index, &b.cuts, - &g, - h.as_deref(), + &mut g, + h.as_deref_mut(), &splitter, true, Some(f32::MAX), @@ -658,7 +669,7 @@ mod tests { let file = fs::read_to_string("resources/performance.csv").expect("Something went wrong reading the file"); let y: Vec = file.lines().map(|x| x.parse::().unwrap()).collect(); let yhat = vec![0.5; y.len()]; - let (g, h) = LogLoss::calc_grad_hess(&y, &yhat, None, None); + let (mut g, mut h) = LogLoss::calc_grad_hess(&y, &yhat, None, None); let loss = LogLoss::calc_loss(&y, &yhat, None, None); let data = Matrix::new(&data_vec, 891, 5); @@ -685,8 +696,8 @@ mod tests { data.index.to_owned(), &col_index, &b.cuts, - &g, - h.as_deref(), + &mut g, + h.as_deref_mut(), &splitter, true, Some(f32::MAX), @@ -745,7 +756,7 @@ mod tests { let file = fs::read_to_string("resources/performance.csv").expect("Something went wrong reading the file"); let y: Vec = file.lines().map(|x| x.parse::().unwrap()).collect(); let yhat = vec![0.5; y.len()]; - let (g, h) = LogLoss::calc_grad_hess(&y, &yhat, None, None); + let (mut g, mut h) = LogLoss::calc_grad_hess(&y, &yhat, None, None); let loss = LogLoss::calc_loss(&y, &yhat, None, None); let data = Matrix::new(&data_vec, 891, 5); @@ -773,8 +784,8 @@ mod tests { data.index.to_owned(), &col_index, &b.cuts, - &g, - h.as_deref(), + &mut g, + h.as_deref_mut(), &splitter, false, Some(f32::MAX), @@ -803,7 +814,7 @@ mod tests { let file = fs::read_to_string("resources/performance.csv").expect("Something went wrong reading the file"); let y: Vec = file.lines().map(|x| x.parse::().unwrap()).collect(); let yhat = vec![0.5; y.len()]; - let (g, h) = LogLoss::calc_grad_hess(&y, &yhat, None, None); + let (mut g, h) = LogLoss::calc_grad_hess(&y, &yhat, None, None); let loss = LogLoss::calc_loss(&y, &yhat, None, None); println!("GRADIENT -- {:?}", h); @@ -833,7 +844,7 @@ mod tests { data.index.to_owned(), &col_index, &b.cuts, - &g, + &mut g, None, &splitter, true, @@ -892,7 +903,7 @@ mod tests { let file = fs::read_to_string("resources/performance.csv").expect("Something went wrong reading the file"); let y: Vec = file.lines().map(|x| 
x.parse::().unwrap()).collect(); let yhat = vec![0.5; y.len()]; - let (g, h) = LogLoss::calc_grad_hess(&y, &yhat, None, None); + let (mut g, mut h) = LogLoss::calc_grad_hess(&y, &yhat, None, None); let loss = LogLoss::calc_loss(&y, &yhat, None, None); let data = Matrix::new(&data_vec, 891, 5); @@ -919,8 +930,8 @@ mod tests { data.index.to_owned(), &col_index, &b.cuts, - &g, - h.as_deref(), + &mut g, + h.as_deref_mut(), &splitter, true, Some(f32::MAX), @@ -1028,7 +1039,7 @@ mod tests { ); let y_test_avg = y_test.iter().sum::() / y_test.len() as f64; let yhat = vec![y_test_avg; y_test.len()]; - let (grad, hess) = SquaredLoss::calc_grad_hess(&y_test, &yhat, None, None); + let (mut g, mut h) = SquaredLoss::calc_grad_hess(&y_test, &yhat, None, None); let loss = SquaredLoss::calc_loss(&y_test, &yhat, None, None); let splitter = MissingImputerSplitter { @@ -1055,8 +1066,8 @@ mod tests { data.index.to_owned(), &col_index, &b.cuts, - &grad, - hess.as_deref(), + &mut g, + h.as_deref_mut(), &splitter, true, Some(f32::MAX), @@ -1093,7 +1104,7 @@ mod tests { let y_avg = y.iter().sum::() / y.len() as f64; let yhat = vec![y_avg; y.len()]; - let (grad, hess) = LogLoss::calc_grad_hess(&y, &yhat, None, None); + let (mut grad, mut hess) = LogLoss::calc_grad_hess(&y, &yhat, None, None); let loss = LogLoss::calc_loss(&y, &yhat, None, None); let splitter = MissingImputerSplitter { @@ -1110,7 +1121,18 @@ mod tests { let col_index: Vec = (0..data.cols).collect(); let mut hist_node = HistogramMatrix::empty(&bdata, &b.cuts, &col_index, false, false); - hist_node.update(&bdata, &b.cuts, &grad, hess.as_deref(), &index, &col_index, true, false); + hist_node.update( + 0, + index.len(), + &bdata, + &b.cuts, + &grad, + hess.as_deref(), + &index, + &col_index, + true, + false, + ); let mut hist_tree: BrownHashMap = BrownHashMap::with_capacity(N_NODES_ALLOCATED); for i in 0..N_NODES_ALLOCATED { hist_tree.insert(i, hist_node.clone()); @@ -1122,8 +1144,8 @@ mod tests { data.index.to_owned(), &col_index, &b.cuts, - &grad, - hess.as_deref(), + &mut grad, + hess.as_deref_mut(), &splitter, true, Some(f32::MAX), diff --git a/src/utils.rs b/src/utils.rs index 21aa7bd..00d55d5 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -376,15 +376,72 @@ pub fn map_bin>(x: &[T], v: &T, missing: &T) -> Option { /// to the right of the split value. 
 #[inline]
 pub fn pivot_on_split(
-    index: &mut [usize],
+    start: usize,
+    stop: usize,
+    ind: &mut [usize],
+    grad: &mut [f32],
+    hess: &mut [f32],
+    feature: &[u16],
+    split_value: u16,
+    missing_right: bool,
+    left_cat: Option<&[u16]>,
+) -> usize {
+    let index = &mut ind[start..stop];
+    let g = &mut grad[start..stop];
+    let h = &mut hess[start..stop];
+
+    let length = index.len();
+    let mut last_idx = length - 1;
+    let mut rv = None;
+
+    for i in 0..length {
+        loop {
+            match missing_compare(&split_value, feature[index[i]], missing_right, left_cat) {
+                Ordering::Less | Ordering::Equal => {
+                    if last_idx <= i {
+                        rv = Some(i);
+                        break;
+                    }
+                    index.swap(i, last_idx);
+                    g.swap(i, last_idx);
+                    h.swap(i, last_idx);
+                    if last_idx == 0 {
+                        rv = Some(0);
+                        break;
+                    }
+                    last_idx -= 1;
+                }
+                Ordering::Greater => break,
+            }
+        }
+        if i >= last_idx {
+            break;
+        }
+    }
+    match rv {
+        Some(r) => r,
+        None => last_idx + 1,
+    }
+}
+
+#[inline]
+pub fn pivot_on_split_const_hess(
+    start: usize,
+    stop: usize,
+    ind: &mut [usize],
+    grad: &mut [f32],
     feature: &[u16],
     split_value: u16,
     missing_right: bool,
     left_cat: Option<&[u16]>,
 ) -> usize {
+    let index = &mut ind[start..stop];
+    let g = &mut grad[start..stop];
+
     let length = index.len();
     let mut last_idx = length - 1;
     let mut rv = None;
+
     for i in 0..length {
         loop {
             match missing_compare(&split_value, feature[index[i]], missing_right, left_cat) {
@@ -394,6 +451,7 @@ pub fn pivot_on_split(
                         break;
                     }
                     index.swap(i, last_idx);
+                    g.swap(i, last_idx);
                     if last_idx == 0 {
                         rv = Some(0);
                         break;
@@ -626,8 +684,10 @@ mod tests {
         }
         let mut idx = vec![2, 6, 9, 5, 8, 13, 11, 7];
+        let mut grad = vec![0.2, 0.6, 0.9, 0.5, 0.8, 0.1, 0.1, 0.7];
+        let mut hess = vec![1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0];
         let f = vec![15, 10, 10, 11, 3, 18, 9, 3, 5, 2, 6, 13, 19, 14];
-        let split_i = pivot_on_split(&mut idx, &f, 10, true, None);
+        let split_i = pivot_on_split(0, idx.len(), &mut idx, &mut grad, &mut hess, &f, 10, true, None);
         println!("split_i: {}", split_i);
         println!("idx: {:?}", idx);
         println!("sorted: {:?}", idx.iter().map(|i| f[*i]).collect::<Vec<u16>>());
@@ -635,6 +695,8 @@
         let missing_right = true;
         let mut idx = vec![2, 6, 9, 5, 8, 13, 11, 7];
+        let mut grad = vec![0.2, 0.6, 0.9, 0.5, 0.8, 0.1, 0.1, 0.7];
+        let mut hess = vec![1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0];
         let f = vec![15, 10, 10, 11, 3, 18, 9, 3, 5, 2, 6, 13, 19, 14];
         idx.sort_by_key(|i| {
             if f[*i] == 0 {
@@ -655,7 +717,7 @@
         let mut idx = vec![2, 6, 9, 5, 8, 13, 11, 7];
         let f = vec![15, 10, 10, 11, 3, 18, 0, 9, 3, 5, 2, 6, 13, 19, 14];
-        let split_i = pivot_on_split(&mut idx, &f, 10, false, None);
+        let split_i = pivot_on_split(0, idx.len(), &mut idx, &mut grad, &mut hess, &f, 10, false, None);
         println!("{}", split_i);
         println!("{:?}", idx);
         println!("{:?}", idx.iter().map(|i| f[*i]).collect::<Vec<u16>>());
@@ -663,24 +725,32 @@
         // Test Minimum value...
let mut idx = vec![2, 6, 9, 5, 8, 13, 11, 7]; + let mut grad = vec![0.2, 0.6, 0.9, 0.5, 0.8, 0.1, 0.1, 0.7]; + let mut hess = vec![1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]; let f = vec![15, 10, 10, 11, 3, 18, 0, 9, 3, 5, 2, 6, 13, 19, 14]; - let split_i = pivot_on_split(&mut idx, &f, 1, true, None); + let split_i = pivot_on_split(0, idx.len(), &mut idx, &mut grad, &mut hess, &f, 1, true, None); pivot_assert(&f, &idx, split_i, true, 1); let mut idx = vec![2, 6, 9, 5, 8, 13, 11, 7]; + let mut grad = vec![0.2, 0.6, 0.9, 0.5, 0.8, 0.1, 0.1, 0.7]; + let mut hess = vec![1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]; let f = vec![15, 10, 10, 11, 3, 18, 0, 9, 3, 5, 2, 6, 13, 19, 14]; - let split_i = pivot_on_split(&mut idx, &f, 1, false, None); + let split_i = pivot_on_split(0, idx.len(), &mut idx, &mut grad, &mut hess, &f, 1, false, None); pivot_assert(&f, &idx, split_i, false, 1); // Test Maximum value... let mut idx = vec![2, 6, 9, 5, 8, 13, 11, 7]; + let mut grad = vec![0.2, 0.6, 0.9, 0.5, 0.8, 0.1, 0.1, 0.7]; + let mut hess = vec![1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]; let f = vec![15, 10, 10, 11, 3, 18, 0, 9, 3, 5, 2, 6, 13, 19, 14]; - let split_i = pivot_on_split(&mut idx, &f, 19, true, None); + let split_i = pivot_on_split(0, idx.len(), &mut idx, &mut grad, &mut hess, &f, 19, true, None); pivot_assert(&f, &idx, split_i, true, 19); let mut idx = vec![2, 6, 9, 5, 8, 13, 11, 7]; + let mut grad = vec![0.2, 0.6, 0.9, 0.5, 0.8, 0.1, 0.1, 0.7]; + let mut hess = vec![1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]; let f = vec![15, 10, 10, 11, 3, 18, 0, 9, 3, 5, 2, 6, 13, 19, 14]; - let split_i = pivot_on_split(&mut idx, &f, 19, false, None); + let split_i = pivot_on_split(0, idx.len(), &mut idx, &mut grad, &mut hess, &f, 19, false, None); pivot_assert(&f, &idx, split_i, false, 19); // Random tests... right... @@ -689,16 +759,18 @@ mod tests { let mut rng = StdRng::seed_from_u64(0); let f = (0..100).map(|_| rng.gen_range(0..15)).collect::>(); let mut idx = index.choose_multiple(&mut rng, 73).copied().collect::>(); - let split_i = pivot_on_split(&mut idx, &f, 7, true, None); + let mut grad = idx.iter().map(|i| *i as f32).collect::>(); + let mut hess = grad.clone(); + let split_i = pivot_on_split(0, idx.len(), &mut idx, &mut grad, &mut hess, &f, 7, true, None); pivot_assert(&f, &idx, split_i, true, 7); // Already sorted... - let split_i = pivot_on_split(&mut idx, &f, 7, true, None); + let split_i = pivot_on_split(0, idx.len(), &mut idx, &mut grad, &mut hess, &f, 7, true, None); pivot_assert(&f, &idx, split_i, true, 7); // Reversed idx.reverse(); - let split_i = pivot_on_split(&mut idx, &f, 7, true, None); + let split_i = pivot_on_split(0, idx.len(), &mut idx, &mut grad, &mut hess, &f, 7, true, None); pivot_assert(&f, &idx, split_i, true, 7); // Without missing... @@ -706,7 +778,9 @@ mod tests { let mut rng = StdRng::seed_from_u64(0); let f = (0..100).map(|_| rng.gen_range(1..15)).collect::>(); let mut idx = index.choose_multiple(&mut rng, 73).copied().collect::>(); - let split_i = pivot_on_split(&mut idx, &f, 5, true, None); + let mut grad = idx.iter().map(|i| *i as f32).collect::>(); + let mut hess = grad.clone(); + let split_i = pivot_on_split(0, idx.len(), &mut idx, &mut grad, &mut hess, &f, 5, true, None); pivot_assert(&f, &idx, split_i, true, 5); // Using max... 
@@ -715,7 +789,9 @@ mod tests { let f = (0..100).map(|_| rng.gen_range(0..15)).collect::>(); let mut idx = index.choose_multiple(&mut rng, 73).copied().collect::>(); let sv = idx.iter().map(|i| f[*i]).max().unwrap(); - let split_i = pivot_on_split(&mut idx, &f, sv, true, None); + let mut grad = idx.iter().map(|i| *i as f32).collect::>(); + let mut hess = grad.clone(); + let split_i = pivot_on_split(0, idx.len(), &mut idx, &mut grad, &mut hess, &f, sv, true, None); pivot_assert(&f, &idx, split_i, true, sv); // Using non-0 minimum... @@ -724,7 +800,9 @@ mod tests { let f = (0..100).map(|_| rng.gen_range(0..15)).collect::>(); let mut idx = index.choose_multiple(&mut rng, 73).copied().collect::>(); let sv = idx.iter().filter(|i| f[**i] > 0).map(|i| f[*i]).min().unwrap(); - let split_i = pivot_on_split(&mut idx, &f, sv, true, None); + let mut grad = idx.iter().map(|i| *i as f32).collect::>(); + let mut hess = grad.clone(); + let split_i = pivot_on_split(0, idx.len(), &mut idx, &mut grad, &mut hess, &f, sv, true, None); pivot_assert(&f, &idx, split_i, true, sv); // Using non-0 minimum with no missing... @@ -733,7 +811,9 @@ mod tests { let f = (0..100).map(|_| rng.gen_range(1..15)).collect::>(); let mut idx = index.choose_multiple(&mut rng, 73).copied().collect::>(); let sv = idx.iter().filter(|i| f[**i] > 0).map(|i| f[*i]).min().unwrap(); - let split_i = pivot_on_split(&mut idx, &f, sv, true, None); + let mut grad = idx.iter().map(|i| *i as f32).collect::>(); + let mut hess = grad.clone(); + let split_i = pivot_on_split(0, idx.len(), &mut idx, &mut grad, &mut hess, &f, sv, true, None); pivot_assert(&f, &idx, split_i, true, sv); // Left @@ -741,16 +821,18 @@ mod tests { let mut rng = StdRng::seed_from_u64(0); let f = (0..100).map(|_| rng.gen_range(0..15)).collect::>(); let mut idx = index.choose_multiple(&mut rng, 73).copied().collect::>(); - let split_i = pivot_on_split(&mut idx, &f, 7, false, None); + let mut grad = idx.iter().map(|i| *i as f32).collect::>(); + let mut hess = grad.clone(); + let split_i = pivot_on_split(0, idx.len(), &mut idx, &mut grad, &mut hess, &f, 7, false, None); pivot_assert(&f, &idx, split_i, false, 7); // Already sorted... - let split_i = pivot_on_split(&mut idx, &f, 7, false, None); + let split_i = pivot_on_split(0, idx.len(), &mut idx, &mut grad, &mut hess, &f, 7, false, None); pivot_assert(&f, &idx, split_i, false, 7); // Reversed idx.reverse(); - let split_i = pivot_on_split(&mut idx, &f, 7, false, None); + let split_i = pivot_on_split(0, idx.len(), &mut idx, &mut grad, &mut hess, &f, 7, false, None); pivot_assert(&f, &idx, split_i, false, 7); // Without missing... @@ -758,7 +840,9 @@ mod tests { let mut rng = StdRng::seed_from_u64(0); let f = (0..100).map(|_| rng.gen_range(1..15)).collect::>(); let mut idx = index.choose_multiple(&mut rng, 73).copied().collect::>(); - let split_i = pivot_on_split(&mut idx, &f, 5, false, None); + let mut grad = idx.iter().map(|i| *i as f32).collect::>(); + let mut hess = grad.clone(); + let split_i = pivot_on_split(0, idx.len(), &mut idx, &mut grad, &mut hess, &f, 5, false, None); pivot_assert(&f, &idx, split_i, false, 5); // Using max... 
@@ -767,7 +851,9 @@ mod tests { let f = (0..100).map(|_| rng.gen_range(0..15)).collect::>(); let mut idx = index.choose_multiple(&mut rng, 73).copied().collect::>(); let sv = idx.iter().map(|i| f[*i]).max().unwrap(); - let split_i = pivot_on_split(&mut idx, &f, sv, false, None); + let mut grad = idx.iter().map(|i| *i as f32).collect::>(); + let mut hess = grad.clone(); + let split_i = pivot_on_split(0, idx.len(), &mut idx, &mut grad, &mut hess, &f, sv, false, None); pivot_assert(&f, &idx, split_i, false, sv); // Using non-0 minimum... @@ -776,7 +862,9 @@ mod tests { let f = (0..100).map(|_| rng.gen_range(0..15)).collect::>(); let mut idx = index.choose_multiple(&mut rng, 73).copied().collect::>(); let sv = idx.iter().filter(|i| f[**i] > 0).map(|i| f[*i]).min().unwrap(); - let split_i = pivot_on_split(&mut idx, &f, sv, false, None); + let mut grad = idx.iter().map(|i| *i as f32).collect::>(); + let mut hess = grad.clone(); + let split_i = pivot_on_split(0, idx.len(), &mut idx, &mut grad, &mut hess, &f, sv, false, None); pivot_assert(&f, &idx, split_i, false, sv); // Using non-0 minimum with no missing... @@ -785,7 +873,9 @@ mod tests { let f = (0..100).map(|_| rng.gen_range(1..15)).collect::>(); let mut idx = index.choose_multiple(&mut rng, 73).copied().collect::>(); let sv = idx.iter().filter(|i| f[**i] > 0).map(|i| f[*i]).min().unwrap(); - let split_i = pivot_on_split(&mut idx, &f, sv, false, None); + let mut grad = idx.iter().map(|i| *i as f32).collect::>(); + let mut hess = grad.clone(); + let split_i = pivot_on_split(0, idx.len(), &mut idx, &mut grad, &mut hess, &f, sv, false, None); pivot_assert(&f, &idx, split_i, false, sv); }
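
The diff above threads `start`/`stop` plus mutable `grad`/`hess` buffers through `pivot_on_split` and `HistogramMatrix::update`, so the gradient and hessian arrays are partitioned in lockstep with the sample index; child histograms can then read contiguous `&grad[start..stop]` slices instead of gathering values through the index and allocating sorted copies. Below is a minimal, self-contained sketch of that idea, with missing-value and categorical handling omitted; the function and variable names are illustrative only, not the crate's API.

```rust
/// Partition `index[start..stop]` so samples whose binned feature value is
/// below `split_bin` come first, swapping `grad` and `hess` in lockstep so
/// the per-sample buffers stay aligned with the index. Returns the split
/// position relative to `start`.
fn pivot_in_place(
    start: usize,
    stop: usize,
    index: &mut [usize],
    grad: &mut [f32],
    hess: &mut [f32],
    feature: &[u16],
    split_bin: u16,
) -> usize {
    let idx = &mut index[start..stop];
    let g = &mut grad[start..stop];
    let h = &mut hess[start..stop];

    let mut left = 0;
    for i in 0..idx.len() {
        if feature[idx[i]] < split_bin {
            // Swap all three buffers together so grad/hess stay aligned
            // with the reordered index.
            idx.swap(i, left);
            g.swap(i, left);
            h.swap(i, left);
            left += 1;
        }
    }
    left
}

fn main() {
    let feature = vec![5u16, 9, 1, 7, 3, 8];
    let mut index: Vec<usize> = (0..feature.len()).collect();
    let mut grad = vec![0.0f32, 0.1, 0.2, 0.3, 0.4, 0.5];
    let mut hess = vec![1.0f32; feature.len()];

    let split = pivot_in_place(0, index.len(), &mut index, &mut grad, &mut hess, &feature, 7);

    // Left-child samples are now index[..split]; their gradients are the
    // contiguous slice grad[..split], so a histogram can be built without a
    // per-sample gather or a temporary sorted copy.
    assert!(index[..split].iter().all(|&i| feature[i] < 7));
    assert!(index[split..].iter().all(|&i| feature[i] >= 7));
    println!("split = {split}, left grads = {:?}", &grad[..split]);
}
```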