From 9aeec81dede66b000164e6632a80e7adf1210041 Mon Sep 17 00:00:00 2001 From: lewardo Date: Tue, 26 Sep 2023 11:11:18 +0100 Subject: [PATCH] fix temporary xvalue binding issues --- include/algorithms/public/Recur.hpp | 63 +++++++++++++++++++---------- 1 file changed, 41 insertions(+), 22 deletions(-) diff --git a/include/algorithms/public/Recur.hpp b/include/algorithms/public/Recur.hpp index 067d3d14a..b385305df 100644 --- a/include/algorithms/public/Recur.hpp +++ b/include/algorithms/public/Recur.hpp @@ -242,12 +242,16 @@ class Recur CellSeries bottomNodes, topNodes; - bottomNodes.emplace_back(mBottomParams); - bottomNodes[0].forwardFrame(input.row(0), StateType{mBottomParams}); + { + StateType bottomState{mBottomParams}, topState{mTopParams}; - topNodes.emplace_back(mTopParams); - topNodes[0].forwardFrame(bottomNodes[0].getState().output(), - StateType{mTopParams}); + bottomNodes.emplace_back(mBottomParams); + bottomNodes[0].forwardFrame(input.row(0), bottomState); + + topNodes.emplace_back(mTopParams); + topNodes[0].forwardFrame(bottomNodes[0].getState().output(), + topState); + } for (index i = 1; i < input.rows(); ++i) { @@ -258,16 +262,22 @@ class Recur topNodes[i].forwardFrame(bottomNodes[i].getState().output(), topNodes[i - 1].getState()); } + + double loss; - double loss = topNodes.back().backwardFrame(output, StateType{mTopParams}); - bottomNodes.back().backwardFrame(topNodes.back().getState(), - StateType{mBottomParams}); - - for (index i = input.rows() - 2; i >= 0; --i) { - topNodes[i].backwardFrame(topNodes[i + 1].getState()); - bottomNodes[i].backwardFrame(topNodes[i].getState(), - bottomNodes[i + 1].getState()); + StateType bottomState{mBottomParams}, topState{mTopParams}; + + loss = topNodes.back().backwardFrame(output, topState); + bottomNodes.back().backwardFrame(topNodes.back().getState(), + bottomState); + + for (index i = input.rows() - 2; i >= 0; --i) + { + topNodes[i].backwardFrame(topNodes[i + 1].getState()); + bottomNodes[i].backwardFrame(topNodes[i].getState(), + bottomNodes[i + 1].getState()); + } } return loss; @@ -280,12 +290,16 @@ class Recur CellSeries bottomNodes, topNodes; - bottomNodes.emplace_back(mBottomParams); - bottomNodes[0].forwardFrame(data.row(0), StateType{mBottomParams}); + { + StateType bottomState{mBottomParams}, topState{mTopParams}; - topNodes.emplace_back(mTopParams); - topNodes[0].forwardFrame(bottomNodes[0].getState().output(), - StateType{mTopParams}); + bottomNodes.emplace_back(mBottomParams); + bottomNodes[0].forwardFrame(data.row(0), bottomState); + + topNodes.emplace_back(mTopParams); + topNodes[0].forwardFrame(bottomNodes[0].getState().output(), + topState); + } for (index i = 1; i < data.rows() - 1; ++i) { @@ -298,10 +312,15 @@ class Recur } double loss = 0.0; - loss += topNodes.back().backwardFrame(data.row(data.rows() - 1), - StateType{mTopParams}); - bottomNodes.back().backwardFrame(topNodes.back().getState(), - StateType{mBottomParams}); + + { + StateType bottomState{mBottomParams}, topState{mTopParams}; + + loss += topNodes.back().backwardFrame(data.row(data.rows() - 1), + topState); + bottomNodes.back().backwardFrame(topNodes.back().getState(), + bottomState); + } for (index i = data.rows() - 3; i >= 0; --i) {