Skip to content

Commit

Permalink
fix temporary xvalue binding issues
Browse files Browse the repository at this point in the history
  • Loading branch information
lewardo committed Sep 26, 2023
1 parent 33924d6 commit 9aeec81
Showing 1 changed file with 41 additions and 22 deletions.
63 changes: 41 additions & 22 deletions include/algorithms/public/Recur.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -242,12 +242,16 @@ class Recur

CellSeries bottomNodes, topNodes;

bottomNodes.emplace_back(mBottomParams);
bottomNodes[0].forwardFrame(input.row(0), StateType{mBottomParams});
{
StateType bottomState{mBottomParams}, topState{mTopParams};

topNodes.emplace_back(mTopParams);
topNodes[0].forwardFrame(bottomNodes[0].getState().output(),
StateType{mTopParams});
bottomNodes.emplace_back(mBottomParams);
bottomNodes[0].forwardFrame(input.row(0), bottomState);

topNodes.emplace_back(mTopParams);
topNodes[0].forwardFrame(bottomNodes[0].getState().output(),
topState);
}

for (index i = 1; i < input.rows(); ++i)
{
Expand All @@ -258,16 +262,22 @@ class Recur
topNodes[i].forwardFrame(bottomNodes[i].getState().output(),
topNodes[i - 1].getState());
}

double loss;

double loss = topNodes.back().backwardFrame(output, StateType{mTopParams});
bottomNodes.back().backwardFrame(topNodes.back().getState(),
StateType{mBottomParams});

for (index i = input.rows() - 2; i >= 0; --i)
{
topNodes[i].backwardFrame(topNodes[i + 1].getState());
bottomNodes[i].backwardFrame(topNodes[i].getState(),
bottomNodes[i + 1].getState());
StateType bottomState{mBottomParams}, topState{mTopParams};

loss = topNodes.back().backwardFrame(output, topState);
bottomNodes.back().backwardFrame(topNodes.back().getState(),
bottomState);

for (index i = input.rows() - 2; i >= 0; --i)
{
topNodes[i].backwardFrame(topNodes[i + 1].getState());
bottomNodes[i].backwardFrame(topNodes[i].getState(),
bottomNodes[i + 1].getState());
}
}

return loss;
Expand All @@ -280,12 +290,16 @@ class Recur

CellSeries bottomNodes, topNodes;

bottomNodes.emplace_back(mBottomParams);
bottomNodes[0].forwardFrame(data.row(0), StateType{mBottomParams});
{
StateType bottomState{mBottomParams}, topState{mTopParams};

topNodes.emplace_back(mTopParams);
topNodes[0].forwardFrame(bottomNodes[0].getState().output(),
StateType{mTopParams});
bottomNodes.emplace_back(mBottomParams);
bottomNodes[0].forwardFrame(data.row(0), bottomState);

topNodes.emplace_back(mTopParams);
topNodes[0].forwardFrame(bottomNodes[0].getState().output(),
topState);
}

for (index i = 1; i < data.rows() - 1; ++i)
{
Expand All @@ -298,10 +312,15 @@ class Recur
}

double loss = 0.0;
loss += topNodes.back().backwardFrame(data.row(data.rows() - 1),
StateType{mTopParams});
bottomNodes.back().backwardFrame(topNodes.back().getState(),
StateType{mBottomParams});

{
StateType bottomState{mBottomParams}, topState{mTopParams};

loss += topNodes.back().backwardFrame(data.row(data.rows() - 1),
topState);
bottomNodes.back().backwardFrame(topNodes.back().getState(),
bottomState);
}

for (index i = data.rows() - 3; i >= 0; --i)
{
Expand Down

0 comments on commit 9aeec81

Please sign in to comment.