Skip to content

Commit

Permalink
complete encoder decoder and transformer model
Browse files Browse the repository at this point in the history
  • Loading branch information
mrityunjay-tripathi committed Aug 22, 2020
1 parent 0a307db commit 0c08f69
Show file tree
Hide file tree
Showing 11 changed files with 1,023 additions and 1 deletion.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
build*
xcode*
.vscode/
.DS_Store
.idea
cmake-build-*
Expand Down
10 changes: 9 additions & 1 deletion models/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
cmake_minimum_required(VERSION 3.1.0 FATAL_ERROR)
project(models)

add_subdirectory(darknet)
# Recurse into each model mlpack provides.
set(DIRS
darknet
transformer
)

foreach(dir ${DIRS})
add_subdirectory(${dir})
endforeach()

# Add directory name to sources.
set(DIR_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/)
Expand Down
9 changes: 9 additions & 0 deletions models/models.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
/**
* @file models.hpp
* @author Mrityunjay Tripathi
*
* This includes various models.
*/

#include "transformer/encoder.hpp"
#include "transformer/decoder.hpp"
20 changes: 20 additions & 0 deletions models/transformer/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
cmake_minimum_required(VERSION 3.1.0 FATAL_ERROR)
project(transformer)

set(DIR_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/)
include_directories("${CMAKE_CURRENT_SOURCE_DIR}/../../")

set(SOURCES
decoder.hpp
decoder_impl.hpp
encoder.hpp
encoder_impl.hpp
transformer.hpp
transformer_impl.hpp
)

foreach(file ${SOURCES})
set(DIR_SRCS ${DIR_SRCS} ${CMAKE_CURRENT_SOURCE_DIR}/${file})
endforeach()

set(DIRS ${DIRS} ${DIR_SRCS} PARENT_SCOPE)
230 changes: 230 additions & 0 deletions models/transformer/decoder.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,230 @@
/**
* @file models/transformer/decoder.hpp
* @author Mikhail Lozhnikov
* @author Mrityunjay Tripathi
*
* Definition of the Transformer Decoder layer.
*
* mlpack is free software; you may redistribute it and/or modify it under the
* terms of the 3-clause BSD license. You should have received a copy of the
* 3-clause BSD license along with mlpack. If not, see
* http://www.opensource.org/licenses/BSD-3-Clause for more information.
*/

#ifndef MODELS_TRANSFORMER_DECODER_HPP
#define MODELS_TRANSFORMER_DECODER_HPP

#include <mlpack/prereqs.hpp>
#include <mlpack/methods/ann/layer/layer_types.hpp>
#include <mlpack/methods/ann/layer/base_layer.hpp>
#include <mlpack/methods/ann/regularizer/no_regularizer.hpp>

namespace mlpack {
namespace ann /** Artificial Neural Network. */ {

/**
* In addition to the two sub-layers in each encoder layer, the decoder inserts
* a third sub-layer, which performs multi-head attention over the output of the
* encoder stack. Similar to the encoder, we employ residual connections around
* each of the sub-layers, followed by layer normalization. We also modify the
* self-attention sub-layer in the decoder stack to prevent positions from
* attending to subsequent positions. This masking, combined with fact that the
* output embeddings are offset by one position, ensures that the predictions
* for position i can depend only on the known outputs at positions less than i.
*
* @tparam ActivationFunction The type of the activation function to be used in
* the position-wise feed forward neural network.
* @tparam RegularizerType The type of regularizer to be applied to layer
* parameters.
* @tparam InputDataType Type of the input data (arma::colvec, arma::mat,
* arma::sp_mat or arma::cube).
* @tparam OutputDataType Type of the output data (arma::colvec, arma::mat,
* arma::sp_mat or arma::cube).
*/
template <
typename ActivationFunction = ReLULayer<>,
typename RegularizerType = NoRegularizer,
typename InputDataType = arma::mat,
typename OutputDataType = arma::mat
>
class TransformerDecoder
{
public:
TransformerDecoder();

/**
* Create the TransformerDecoder object using the specified parameters.
*
* @param numLayers The number of decoder blocks.
* @param tgtSeqLen Target Sequence Length.
* @param srcSeqLen Source Sequence Length.
* @param memoryModule The last Encoder module.
* @param dModel The number of features in the input. Also, same as the
* 'embedDim' in 'MultiheadAttention' layer.
* @param numHeads The number of attention heads.
* @param dimFFN The dimentionality of feedforward network.
* @param dropout The dropout rate.
* @param attentionMask The attention mask used to black-out future sequences.
* @param keyPaddingMask The padding mask used to black-out particular token.
*/
TransformerDecoder(const size_t numLayers,
const size_t tgtSeqLen,
const size_t srcSeqLen,
const size_t dModel = 512,
const size_t numHeads = 8,
const size_t dimFFN = 1024,
const double dropout = 0.1,
const InputDataType& attentionMask = InputDataType(),
const InputDataType& keyPaddingMask = InputDataType());

/**
* Get the Transformer Decoder model.
*/
Sequential<>* Model() { return decoder; }
/**
* Load the network from a local directory.
*
* @param filepath The location of the stored model.
*/
void LoadModel(const std::string& filepath);

/**
* Save the network locally.
*
* @param filepath The location where the model is to be saved.
*/
void SaveModel(const std::string& filepath);

//! Get the key matrix, the output of the Transformer Encoder.
InputDataType const& Key() const { return key; }

//! Modify the key matrix.
InputDataType& Key() { return key; }

private:
/**
* This method adds the attention block to the decoder.
*/
void AttentionBlock()
{
Sequential<>* decoderBlockBottom = new Sequential<>();
decoderBlockBottom->Add<Subview<>>(1, 0, dModel * tgtSeqLen - 1, 0, -1);

// Broadcast the incoming input to decoder
// i.e. query into (query, key, value).
Concat<>* decoderInput = new Concat<>();
decoderInput->Add<IdentityLayer<>>();
decoderInput->Add<IdentityLayer<>>();
decoderInput->Add<IdentityLayer<>>();

// Masked Self attention layer.
Sequential<>* maskedSelfAttention = new Sequential<>();
maskedSelfAttention->Add(decoderInput);
maskedSelfAttention->Add<MultiheadAttention<
InputDataType, OutputDataType, RegularizerType>>(
tgtSeqLen,
tgtSeqLen,
dModel,
numHeads,
attentionMask
);

// Residual connection.
AddMerge<>* residualAdd = new AddMerge<>();
residualAdd->Add(maskedSelfAttention);
residualAdd->Add<IdentityLayer<>>();

decoderBlockBottom->Add(residualAddMerge);

// Add the LayerNorm layer with required parameters.
decoderBlockBottom->Add<LayerNorm<>>(dModel * tgtSeqLen);

// This layer broadcasts the output of encoder i.e. key into (key, value).
Concat<>* broadcastEncoderOutput = new Concat<>();
broadcastEncoderOutput->Add<Subview<>>(1, dModel * tgtSeqLen, -1, 0, -1);
broadcastEncoderOutput->Add<Subview<>>(1, dModel * tgtSeqLen, -1, 0, -1);

// This layer concatenates the output of the bottom decoder block (query)
// and the output of the encoder (key, value).
Concat<>* encoderDecoderAttentionInput = new Concat<>();
encoderDecoderAttentionInput->Add(decoderBlockBottom);
encoderDecoderAttentionInput->Add(broadcastEncoderOutput);

// Encoder-decoder attention.
Sequential<>* encoderDecoderAttention = new Sequential<>();
encoderDecoderAttention->Add(encoderDecoderAttentionInput);
encoderDecoderAttention->Add<MultiheadAttention<
InputDataType, OutputDataType, RegularizerType>>(
tgtSeqLen,
srcSeqLen,
dModel,
numHeads,
InputDatatype(), // No attention mask to encoder-decoder attention.
keyPaddingMask);

// Residual connection.
AddMerge<>* residualAdd = new AddMerge<>();
residualAdd->Add(encoderDecoderAttention);
residualAdd->Add<IdentityLayer<>>();

decoder->Add(residualAdd);
decoder->Add<LayerNorm<>>(dModel * tgtSeqLen);
}

/**
* This method adds the position-wise feed forward network to the decoder.
*/
void PositionWiseFFNBlock()
{
Sequential<>* positionWiseFFN = new Sequential<>();
positionWiseFFN->Add<Linear3D<>>(dModel, dimFFN);
positionWiseFFN->Add<ActivationFunction>();
positionWiseFFN->Add<Linear3D<>>(dimFFN, dModel);
positionWiseFFN->Add<Dropout<>>(dropout);

/* Residual connection. */
AddMerge<>* residualAdd = new AddMerge<>();
residualAdd->Add(positionWiseFFN);
residualAdd->Add<IdentityLayer<>>();
decoder->Add(residualAdd);
}

//! Locally-stored number of decoder layers.
size_t numLayers;

//! Locally-stored target sequence length.
size_t tgtSeqLen;

//! Locally-stored source sequence length.
size_t srcSeqLen;

//! Locally-stored number of input units.
size_t dModel;

//! Locally-stored number of output units.
size_t numHeads;

//! Locally-stored weight object.
size_t dimFFN;

//! Locally-stored weight parameters.
double dropout;

//! Locally-stored attention mask.
InputDataType attentionMask;

//! Locally-stored key padding mask.
InputDataType keyPaddingMask;

//! Locally-stored complete decoder network.
Sequential<InputDataType, OutputDataType, false>* decoder;

}; // class TransformerDecoder

} // namespace ann
} // namespace mlpack

// Include implementation.
#include "decoder_impl.hpp"

#endif
91 changes: 91 additions & 0 deletions models/transformer/decoder_impl.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
/**
* @file models/transformer/decoder_impl.hpp
* @author Mikhail Lozhnikov
* @author Mrityunjay Tripathi
*
* Implementation of the Transformer Decoder class.
*
* mlpack is free software; you may redistribute it and/or modify it under the
* terms of the 3-clause BSD license. You should have received a copy of the
* 3-clause BSD license along with mlpack. If not, see
* http://www.opensource.org/licenses/BSD-3-Clause for more information.
*/

#ifndef MODELS_TRANSFORMER_DECODER_IMPL_HPP
#define MODELS_TRANSFORMER_DECODER_IMPL_HPP

#include "decoder.hpp"

namespace mlpack {
namespace ann /** Artificial Neural Network. */ {

template<typename ActivationFunction, typename RegularizerType,
typename InputDataType, typename OutputDataType>
TransformerDecoder<ActivationFunction, RegularizerType, InputDataType,
OutputDataType>::TransformerDecoder() :
tgtSeqLen(0),
srcSeqLen(0),
memoryModule(NULL),
dModel(0),
numHeads(0),
dimFFN(0),
dropout(0)
{
// Nothing to do here.
}

template<typename ActivationFunction, typename RegularizerType,
typename InputDataType, typename OutputDataType>
TransformerDecoder<ActivationFunction, RegularizerType, InputDataType,
OutputDataType>::TransformerDecoder(
const size_t numLayers,
const size_t tgtSeqLen,
const size_t srcSeqLen,
const size_t dModel,
const size_t numHeads,
const size_t dimFFN,
const double dropout,
const InputDataType& attentionMask,
const InputDataType& keyPaddingMask) :
numLayers(numLayers),
tgtSeqLen(tgtSeqLen),
srcSeqLen(srcSeqLen),
dModel(dModel),
numHeads(numHeads),
dimFFN(dimFFN),
dropout(dropout),
attentionMask(attentionMask),
keyPaddingMask(keyPaddingMask)
{
decoder = new Sequential<InputDataType, OutputDataType, false>();

for (size_t N = 0; N < numLayers; ++N)
{
AttentionBlock();
PositionWiseFFNBlock();
}
}

template<typename ActivationFunction, typename RegularizerType,
typename InputDataType, typename OutputDataType>
void TransformerDecoder<ActivationFunction, RegularizerType,
InputDataType, OutputDataType>::LoadModel(const std::string& filepath)
{
data::Load(filepath, "TransformerDecoder", decoder);
std::cout << "Loaded model" << std::endl;
}

template<typename ActivationFunction, typename RegularizerType,
typename InputDataType, typename OutputDataType>
void TransformerDecoder<ActivationFunction, RegularizerType,
InputDataType, OutputDataType>::SaveModel(const std::string& filepath)
{
std::cout << "Saving model" << std::endl;
data::Save(filepath, "TransformerDecoder", decoder);
std::cout << "Model saved in " << filepath << std::endl;
}

} // namespace ann
} // namespace mlpack

#endif
Loading

0 comments on commit 0c08f69

Please sign in to comment.