From 825e1a32e47bd71ad32aac6fca10a37700aac679 Mon Sep 17 00:00:00 2001 From: Minhua Chen Date: Tue, 24 Sep 2024 12:46:04 -0700 Subject: [PATCH] avoid early NE bump by setting coef_ema=0 (#3161) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3161 X-link: https://github.com/facebookresearch/FBGEMM/pull/257 refactor ensemble_rowwise_adagrad Reviewed By: csmiler Differential Revision: D63238676 fbshipit-source-id: e49491f742aa601cc44a16fd77bc02e573897041 --- fbgemm_gpu/codegen/genscript/optimizers.py | 31 ++++++++----------- .../tbe/training/backward_optimizers_test.py | 2 +- 2 files changed, 14 insertions(+), 19 deletions(-) diff --git a/fbgemm_gpu/codegen/genscript/optimizers.py b/fbgemm_gpu/codegen/genscript/optimizers.py index 3dd4a4f7e..acf2af31f 100644 --- a/fbgemm_gpu/codegen/genscript/optimizers.py +++ b/fbgemm_gpu/codegen/genscript/optimizers.py @@ -1047,27 +1047,22 @@ def ensemble_rowwise_adagrad() -> Dict[str, Any]: momentum2[idx] = new_sum_square_grads; multiplier = learning_rate / (sqrtf(new_sum_square_grads) + eps); - coef_ema = fabs(momentum); + coef_ema = (row_counter[idx] > step_start) ? (momentum*1.0) : 0.0; if (step_mode == 1) { - // row_counter[idx] records the number of appearances of this row + // row_counter[idx] tracks the number of appearances of this ID row_counter[idx] += 1.0; should_ema = floorf(row_counter[idx] / step_ema) - floorf((row_counter[idx]-1.0) / step_ema); should_swap = floorf(row_counter[idx] / step_swap) - floorf((row_counter[idx]-1.0) / step_swap); } else if (step_mode == 2) { - // row_counter[idx] records the step of last ema; prev_iter[idx] records the step of last swap - if (momentum > 0) { - should_ema = floorf(iter*1.0 / step_ema) - floorf(row_counter[idx] / step_ema); - should_swap = floorf(iter*1.0 / step_swap) - floorf(prev_iter[idx] / step_swap); - coef_ema = (should_ema > 0.5) ? powf(coef_ema, should_ema) : coef_ema; - } else { - should_ema = floorf((iter*1.0 - row_counter[idx]) / step_ema); - should_swap = floorf((iter*1.0 - prev_iter[idx]) / step_swap); - coef_ema = (should_ema > 0.5) ? powf(coef_ema, (iter*1.0 - row_counter[idx]) / step_ema) : coef_ema; - } + should_ema = floorf((iter*1.0 - row_counter[idx]) / step_ema); + should_swap = floorf((iter*1.0 - prev_iter[idx]) / step_swap); + // row_counter[idx] records the step of last ema if (should_ema > 0.5) { + coef_ema = powf(coef_ema, (iter*1.0 - row_counter[idx]) / step_ema); row_counter[idx] = iter*1.0; } - if (iter*1.0 > step_start && should_swap > 0.5) { + // prev_iter[idx] records the step of last swap + if (should_swap > 0.5) { prev_iter[idx] = iter*1.0; } } else { @@ -1089,14 +1084,14 @@ def ensemble_rowwise_adagrad() -> Dict[str, Any]: if (should_ema > 0.5) { // slow table ema Vec4T m_t(&momentum1[idx * D + d]); - m_t.acc.x = (1.0 - coef_ema) * weight_new.acc.x + coef_ema * m_t.acc.x + (fabs(momentum) - coef_ema) * multiplier * grad.acc.x; - m_t.acc.y = (1.0 - coef_ema) * weight_new.acc.y + coef_ema * m_t.acc.y + (fabs(momentum) - coef_ema) * multiplier * grad.acc.y; - m_t.acc.z = (1.0 - coef_ema) * weight_new.acc.z + coef_ema * m_t.acc.z + (fabs(momentum) - coef_ema) * multiplier * grad.acc.z; - m_t.acc.w = (1.0 - coef_ema) * weight_new.acc.w + coef_ema * m_t.acc.w + (fabs(momentum) - coef_ema) * multiplier * grad.acc.w; + m_t.acc.x = (1.0 - coef_ema) * weight_new.acc.x + coef_ema * m_t.acc.x + (momentum - coef_ema) * multiplier * grad.acc.x; + m_t.acc.y = (1.0 - coef_ema) * weight_new.acc.y + coef_ema * m_t.acc.y + (momentum - coef_ema) * multiplier * grad.acc.y; + m_t.acc.z = (1.0 - coef_ema) * weight_new.acc.z + coef_ema * m_t.acc.z + (momentum - coef_ema) * multiplier * grad.acc.z; + m_t.acc.w = (1.0 - coef_ema) * weight_new.acc.w + coef_ema * m_t.acc.w + (momentum - coef_ema) * multiplier * grad.acc.w; m_t.store(&momentum1[idx * D + d]); } - if (iter*1.0 > step_start && should_swap > 0.5) { // slow-to-fast swap + if (should_swap > 0.5) { // slow-to-fast swap Vec4T m_t(&momentum1[idx * D + d]); weight_new.acc.x = m_t.acc.x * 1.0; weight_new.acc.y = m_t.acc.y * 1.0; diff --git a/fbgemm_gpu/test/tbe/training/backward_optimizers_test.py b/fbgemm_gpu/test/tbe/training/backward_optimizers_test.py index 1365a8a36..adb96daaa 100644 --- a/fbgemm_gpu/test/tbe/training/backward_optimizers_test.py +++ b/fbgemm_gpu/test/tbe/training/backward_optimizers_test.py @@ -311,7 +311,7 @@ def execute_backward_optimizers_( # noqa C901 1e-4, 1.0, 1.0, - 0.0, + -1.0, StepMode.USE_ITER, 0.8, )