diff --git a/fbgemm_gpu/codegen/genscript/optimizers.py b/fbgemm_gpu/codegen/genscript/optimizers.py index 2e6b6ea544..79c89087b6 100644 --- a/fbgemm_gpu/codegen/genscript/optimizers.py +++ b/fbgemm_gpu/codegen/genscript/optimizers.py @@ -1040,33 +1040,33 @@ def ensemble_rowwise_adagrad() -> Dict[str, Any]: at::acc_type multiplier; at::acc_type coef_ema; - at::acc_type should_ema; - at::acc_type should_swap; + at::acc_type should_ema; + at::acc_type should_swap; if (threadIdx.x == 0) { at::acc_type new_sum_square_grads = momentum2[idx] + g_avg_square; momentum2[idx] = new_sum_square_grads; multiplier = learning_rate / (sqrtf(new_sum_square_grads) + eps); - coef_ema = fabs(momentum); + coef_ema = fabs(momentum); if (step_mode == 1) { // row_counter[idx] records the number of appearances of this row row_counter[idx] += 1.0; - should_ema = ((int64_t)round(fmod(row_counter[idx], step_ema)) == 0); - should_swap = (row_counter[idx] > step_start && (int64_t)round(fmod(row_counter[idx], step_swap)) == 0); + should_ema = floorf(row_counter[idx] / step_ema) - floorf((row_counter[idx]-1.0) / step_ema); + should_swap = floorf(row_counter[idx] / step_swap) - floorf((row_counter[idx]-1.0) / step_swap); } else if (step_mode == 2) { // row_counter[idx] records the step of last ema; prev_iter[idx] records the step of last swap - should_ema = ((iter*1.0 - row_counter[idx]) >= step_ema); - should_swap = (iter*1.0 > step_start && (iter*1.0 - prev_iter[idx]) >= step_swap); - if (should_ema) { - coef_ema = (momentum>0) ? powf(fabs(momentum), (iter*1.0 - row_counter[idx])/max(1.0, step_ema)) : fabs(momentum); + should_ema = floorf(iter*1.0 / step_ema) - floorf(row_counter[idx] / step_ema); + should_swap = floorf(iter*1.0 / step_swap) - floorf(prev_iter[idx] / step_swap); + if (should_ema > 0.5) { + coef_ema = (momentum>0) ? powf(coef_ema, should_ema) : coef_ema; row_counter[idx] = iter*1.0; } - if (should_swap) { + if (iter*1.0 > step_start && should_swap > 0.5) { prev_iter[idx] = iter*1.0; } } else { - should_ema = false; - should_swap = false; + should_ema = 0.0; + should_swap = 0.0; } } multiplier = SHFL_SYNC(multiplier, 0); @@ -1081,7 +1081,7 @@ def ensemble_rowwise_adagrad() -> Dict[str, Any]: weight_new.acc.z = weight_new.acc.z - multiplier * grad.acc.z; weight_new.acc.w = weight_new.acc.w - multiplier * grad.acc.w; - if (should_ema) { // slow table ema + if (should_ema > 0.5) { // slow table ema Vec4T m_t(&momentum1[idx * D + d]); m_t.acc.x = (1.0 - coef_ema) * weight_new.acc.x + coef_ema * m_t.acc.x + (fabs(momentum) - coef_ema) * multiplier * grad.acc.x; m_t.acc.y = (1.0 - coef_ema) * weight_new.acc.y + coef_ema * m_t.acc.y + (fabs(momentum) - coef_ema) * multiplier * grad.acc.y; @@ -1090,12 +1090,12 @@ def ensemble_rowwise_adagrad() -> Dict[str, Any]: m_t.store(&momentum1[idx * D + d]); } - if (should_swap) { // slow-to-fast swap + if (iter*1.0 > step_start && should_swap > 0.5) { // slow-to-fast swap Vec4T m_t(&momentum1[idx * D + d]); - weight_new.acc.x = m_t.acc.x; - weight_new.acc.y = m_t.acc.y; - weight_new.acc.z = m_t.acc.z; - weight_new.acc.w = m_t.acc.w; + weight_new.acc.x = m_t.acc.x * 1.0; + weight_new.acc.y = m_t.acc.y * 1.0; + weight_new.acc.z = m_t.acc.z * 1.0; + weight_new.acc.w = m_t.acc.w * 1.0; } """