diff --git a/include/compute/add/addop.h b/include/compute/add/addop.h deleted file mode 100644 index a065c2b..0000000 --- a/include/compute/add/addop.h +++ /dev/null @@ -1,58 +0,0 @@ -/** - * @file addop.h - * @author Daniel Nichols - * @version 1.0 - * @date 2019-02-18 - * - * @copyright Copyright (c) 2019 - */ -#pragma once -#include -#include "compute/operation.h" -#include "geadd_internal.h" -#include "tensor/tensor.h" - -namespace magmadnn { -namespace op { - -/** An addition operation on two tensors. - * @tparam T - */ -template -class AddOp : public Operation { - public: - /** Creates an Add Operation, which adds two tensors together. - * @param a a tensor - * @param b a tensor - * @param copy copy into new tensor - * @param needs_grad if this needs a gradient - */ - AddOp(Operation *a, Operation *b, bool copy = true, bool needs_grad = true); - - std::string to_string() { return "(" + a->to_string() + " + " + b->to_string() + ")"; } - - protected: - Tensor *_eval(bool recompute = true); - Tensor &_grad(Operation *consumer, Operation *var, const Tensor &grad); - - Operation *a; - Operation *b; - - Tensor *a_tensor; - Tensor *b_tensor; - - bool copy; -}; - -/** Returns a new add operation (@see AddOp). - * @tparam T - * @param a - * @param b - * @param copy If copy is true then it returns a new tensor, if false then b=a+b. - * @return AddOp* - */ -template -AddOp *add(Operation *a, Operation *b, bool copy = true, bool needs_grad = true); - -} // namespace op -} // namespace magmadnn diff --git a/include/compute/add/geadd_internal.h b/include/compute/add/geadd_internal.h deleted file mode 100644 index fe5aef8..0000000 --- a/include/compute/add/geadd_internal.h +++ /dev/null @@ -1,72 +0,0 @@ -/** - * @file geadd_internal.h - * @author Daniel Nichols - * @version 1.0 - * @date 2019-02-22 - * - * @copyright Copyright (c) 2019 - */ -#pragma once -#include "cblas.h" -#include "tensor/tensor.h" - -namespace magmadnn { -namespace internal { - -/** Returns true if A, B, C are valid parameters to geadd_full. - * @tparam T - * @param A a tensor - * @param B a tensor - * @param C a tensor - * @return true - * @return false - */ -template -bool geadd_check(Tensor *A, Tensor *B, Tensor *C); - -/** Computes C = alpha*A + beta*B All tensors must have the same memory type and shape/size. - * @tparam T int, float, or double - * @param alpha scaling value - * @param A a tensor - * @param beta scaling value - * @param B a tensor - * @param C a tensor - */ -template -void geadd_full(T alpha, Tensor *A, T beta, Tensor *B, Tensor *C); - -#if defined(_HAS_CUDA_) -/** Computes C=alpha*A + beta*B - * @tparam T int, float, double - * @param alpha scaling value - * @param A tensor - * @param beta scaling value - * @param B tensor - * @param C tensor - */ -template -void geadd_full_device(T alpha, Tensor *A, T beta, Tensor *B, Tensor *C); -#endif - -/** - * @tparam T numeric - * @param alpha - * @param x - * @param out - */ -template -void tensor_scalar_add_full(T alpha, Tensor *x, Tensor *out); - -#if defined(_HAS_CUDA_) -/** - * @tparam T numeric - * @param alpha - * @param x - * @param out - */ -template -void tensor_scalar_add_full_device(T alpha, Tensor *x, Tensor *out); -#endif - -} // namespace internal -} // namespace magmadnn diff --git a/include/compute/batchnorm/batchnormop.h b/include/compute/batchnorm/batchnormop.h index a42a005..c8af2a8 100644 --- a/include/compute/batchnorm/batchnormop.h +++ b/include/compute/batchnorm/batchnormop.h @@ -2,6 +2,7 @@ #pragma once #include +#include "compute/compute_graph.h" #include "compute/operation.h" #include "math/batchnorm.h" #include "tensor/tensor.h" @@ -13,31 +14,27 @@ namespace magmadnn { namespace op { -template -class BatchNormOp : public Operation { +class BatchNormOp : public Operation { public: - BatchNormOp(Operation *input, bool needs_grad = true); + BatchNormOp(Operation *input); - virtual ~BatchNormOp(); - - std::string to_string() { return "BatchNorm(" + input->to_string() + ")"; } + std::string to_string() const override { return "BatchNorm(" + input->to_string() + ")"; } protected: - Tensor *_eval(bool recompute); - Tensor *_grad(Operation *consumer, Operation *var, Tensor *grad); + Tensor &_eval(bool recompute) override; + Tensor &_grad(Operation *consumer, Operation *var, const Tensor &grad) override; - Operation *input; - Tensor *input_tensor; + Operation *input; unsigned int num_calls; - Tensor *bn_scale; - Tensor *bn_scale_diff; - Tensor *bn_bias; - Tensor *bn_bias_diff; - Tensor *running_mean; - Tensor *running_variance; - Tensor *saved_mean; - Tensor *saved_variance; + Tensor bn_scale; + Tensor bn_scale_diff; + Tensor bn_bias; + Tensor bn_bias_diff; + Tensor running_mean; + Tensor running_variance; + Tensor saved_mean; + Tensor saved_variance; #if defined(_HAS_CUDA_) void init_settings(); @@ -48,8 +45,7 @@ class BatchNormOp : public Operation { bool copy; }; -template -BatchNormOp *batchnorm(Operation *input, bool needs_grad = true); +inline Operation *batchnorm(Operation *input) { return default_graph.add_operation(input); } } // namespace op } // namespace magmadnn \ No newline at end of file diff --git a/include/compute/binaryop/binaryop.h b/include/compute/binaryop/binaryop.h new file mode 100644 index 0000000..4f650dd --- /dev/null +++ b/include/compute/binaryop/binaryop.h @@ -0,0 +1,62 @@ +/** + * @file binaryop.h + * @author Daniel Nichols + * @version 0.1 + * @date 2019-08-08 + * + * @copyright Copyright (c) 2019 + */ +#pragma once + +#include "compute/compute_graph.h" +#include "compute/operation.h" + +#include "magmadnn_device_types.h" + +#include "math/binary_math_operations.h" +#include "math/launch_math_kernel.h" + +namespace magmadnn { +namespace op { + +template +class BinaryOp : public Operation { + public: + BinaryOp(Operation *x, Operation *y) { + this->use_tensor_settings(x, true); + + this->output_tensor_ = Tensor(this->output_shape_, this->dtype_, {NONE}, this->mem_type_); + } + + std::string to_string() const override { return "BIN_OP(" + x->to_string() + ", " + y->to_string() + ")"; } + + protected: + Tensor &_eval(bool recompute = true) override { + Tensor &x_tensor = x->eval(recompute); + Tensor &y_tensor = y->eval(recompute); + + FOR_ALL_DEVICE_TYPES(getDeviceType(this->mem_type_), DEV_TYPE, { + ::magmadnn::math::ParallelLauncher(x_tensor, y_tensor, this->output_tensor_); + }) + } + + Tensor &_grad(Operation *consumer, Operation *var, const Tensor &grad) override {} + + Operation *x, *y; +}; + +#define MAKE_BINARY(name) \ + inline Operation *name(Operation *a, Operation *b) { \ + return default_graph.add_operation>(a, b); \ + } + +MAKE_BINARY(add) +MAKE_BINARY(sub) +MAKE_BINARY(product) +MAKE_BINARY(div) +MAKE_BINARY(pow) + +#undef MAKE_BINARY + +} // namespace op +} // namespace magmadnn \ No newline at end of file diff --git a/include/compute/compute_graph.h b/include/compute/compute_graph.h index aa40575..8f98eb0 100644 --- a/include/compute/compute_graph.h +++ b/include/compute/compute_graph.h @@ -31,7 +31,7 @@ class Graph { template inline Operation* Graph::add_operation(Args... args) { // std::unique_ptr tmp_ptr{new op_type(args)}; - std::unique_ptr tmp_ptr = ::magmadnn::internal::make_unique(args); + std::unique_ptr tmp_ptr = ::magmadnn::internal::make_unique(args...); /* use std::move to transfer ownership */ this->nodes.push_back(std::move(tmp_ptr)); diff --git a/include/compute/conv2dforward/conv2dforward_internal.h b/include/compute/conv2dforward/conv2dforward_internal.h deleted file mode 100644 index 32dce07..0000000 --- a/include/compute/conv2dforward/conv2dforward_internal.h +++ /dev/null @@ -1,13 +0,0 @@ - -#pragma once - -#include "tensor/tensor.h" - -namespace magmadnn { -namespace internal { - -template -void conv2dforward_full(Tensor *in, Tensor *out); - -} // namespace internal -} // namespace magmadnn \ No newline at end of file diff --git a/include/compute/conv2dforward/conv2dforwardop.h b/include/compute/conv2dforward/conv2dforwardop.h index b083b76..e50ce0c 100644 --- a/include/compute/conv2dforward/conv2dforwardop.h +++ b/include/compute/conv2dforward/conv2dforwardop.h @@ -1,7 +1,7 @@ #pragma once -#include "compute/conv2dforward/conv2dforward_internal.h" +#include "compute/compute_graph.h" #include "compute/operation.h" #include "math/conv2d.h" #include "tensor/tensor.h" @@ -9,25 +9,24 @@ namespace magmadnn { namespace op { -template -class Conv2DForwardOp : public Operation { +class Conv2DForwardOp : public Operation { public: - Conv2DForwardOp(Operation *input, Operation *filter, int pad_h = 0, int pad_w = 0, int vertical_stride = 1, + Conv2DForwardOp(Operation *input, Operation *filter, int pad_h = 0, int pad_w = 0, int vertical_stride = 1, int horizontal_stride = 1, int dilation_h = 1, int dilation_w = 1, bool use_cross_correlation = true, bool needs_grad = true); ~Conv2DForwardOp(); - std::string to_string() { return "Conv2DForward(" + input->to_string() + ")"; } + std::string to_string() const override { return "Conv2DForward(" + input->to_string() + ")"; } protected: - Tensor *_eval(bool recompute); - Tensor &_grad(Operation *consumer, Operation *var, const Tensor &grad); + Tensor &_eval(bool recompute) override; + Tensor &_grad(Operation *consumer, Operation *var, const Tensor &grad) override; void init_settings(); void calculate_and_set_output_shape(); - Operation *input, *filter; - Tensor *input_tensor, *filter_tensor; + Operation *input, *filter; + // Tensor *input_tensor, *filter_tensor; int pad_h, pad_w, vertical_stride, horizontal_stride, dilation_h, dilation_w; bool use_cross_correlation; @@ -37,10 +36,12 @@ class Conv2DForwardOp : public Operation { #endif }; -template -Conv2DForwardOp *conv2dforward(Operation *input, Operation *filter, int pad_h = 0, int pad_w = 0, - int vertical_stride = 1, int horizontal_stride = 1, int dilation_h = 1, - int dilation_w = 1, bool use_cross_correlation = true, bool needs_grad = true); +inline Operation *conv2dforward(Operation *input, Operation *filter, int pad_h = 0, int pad_w = 0, + int vertical_stride = 1, int horizontal_stride = 1, int dilation_h = 1, + int dilation_w = 1, bool use_cross_correlation = true, bool needs_grad = true) { + return default_graph.add_operation(input, filter, pad_h, pad_w, vertical_stride, horizontal_stride, + dilation_h, dilation_w, use_cross_correlation, needs_grad); +} } // namespace op } // namespace magmadnn \ No newline at end of file diff --git a/include/compute/crossentropy/crossentropy_internal.h b/include/compute/crossentropy/crossentropy_internal.h deleted file mode 100644 index 0436720..0000000 --- a/include/compute/crossentropy/crossentropy_internal.h +++ /dev/null @@ -1,29 +0,0 @@ -/** - * @file crossentropy_internal.h - * @author Daniel Nichols - * @version 0.1 - * @date 2019-05-30 - * - * @copyright Copyright (c) 2019 - */ -#pragma once - -#include "tensor/tensor.h" - -#if defined(_HAS_CUDA_) -#include -#endif - -namespace magmadnn { -namespace internal { - -template -void crossentropy_full(Tensor *x, Tensor *y, Tensor *softmax, Tensor *out); - -#if defined(_HAS_CUDA_) -template -void crossentropy_full_device(Tensor *x, Tensor *y, Tensor *softmax, Tensor *out); -#endif - -} // namespace internal -} // namespace magmadnn \ No newline at end of file diff --git a/include/compute/crossentropy/crossentropyop.h b/include/compute/crossentropy/crossentropyop.h index bb794c2..8a78100 100644 --- a/include/compute/crossentropy/crossentropyop.h +++ b/include/compute/crossentropy/crossentropyop.h @@ -1,7 +1,6 @@ #pragma once -#include "compute/crossentropy/crossentropy_internal.h" #include "compute/log/logop.h" #include "compute/negative/negativeop.h" #include "compute/operation.h" @@ -14,24 +13,6 @@ namespace magmadnn { namespace op { -template -class CrossEntropyOp : public Operation { - public: - CrossEntropyOp(Operation *x, Operation *y, bool copy = true, bool needs_grad = true); - ~CrossEntropyOp(); - - std::string to_string() { return "CrossEntropy(Softmax(" + x->to_string() + "), " + y->to_string() + ")"; } - - protected: - Tensor *_eval(bool recompute = true); - Tensor &_grad(Operation *consumer, Operation *var, const Tensor &grad); - - Operation *x, *y; - Tensor *x_tensor, *y_tensor, *softmax; /* scratch is used in the interal calc */ - - bool copy; -}; - /** Returns an operation, which computes the crossentropy between ground_truth and predicted. This must be passed * one-hot encoded data. If not one-hot encoded the return values will not be correct or an error may occur. This * operation is equivalent to `negative(reducesum(reducesum(product(ground_truth, log(predicted)), axis=1), axis=0)) @@ -42,9 +23,7 @@ class CrossEntropyOp : public Operation { * @param needs_grad if this operation needs a gradient or not * @return Operation* an operation with output size = {1}. This scalar represents the crossentropy. */ -template -Operation *crossentropy(Operation *ground_truth, Operation *predicted, bool copy = true, - bool needs_grad = true); +Operation *crossentropy(Operation *ground_truth, Operation *predicted, bool needs_grad = true); } // namespace op } // namespace magmadnn \ No newline at end of file diff --git a/include/compute/div/divop.h b/include/compute/div/divop.h index 6fc1d60..5e5da47 100644 --- a/include/compute/div/divop.h +++ b/include/compute/div/divop.h @@ -28,7 +28,7 @@ class DivOp : public Operation { std::string to_string() { return "( " + a->to_string() + " / " + b->to_string() + " )"; } protected: - Tensor *_eval(bool recompute = true); + Tensor &_eval(bool recompute = true); Tensor &_grad(Operation *consumer, Operation *var, const Tensor &grad); Operation *a, *b; diff --git a/include/compute/dropout/dropoutop.h b/include/compute/dropout/dropoutop.h index efa9b33..905a618 100644 --- a/include/compute/dropout/dropoutop.h +++ b/include/compute/dropout/dropoutop.h @@ -23,7 +23,7 @@ class DropoutOp : public Operation { std::string to_string() { return "Dropout(" + input->to_string() + ")"; } protected: - Tensor *_eval(bool recompute); + Tensor &_eval(bool recompute); Tensor &_grad(Operation *consumer, Operation *var, const Tensor &grad); Operation *input; diff --git a/include/compute/flatten/flattenop.h b/include/compute/flatten/flattenop.h index 8170213..4d7abb3 100644 --- a/include/compute/flatten/flattenop.h +++ b/include/compute/flatten/flattenop.h @@ -15,7 +15,7 @@ class FlattenOp : public Operation { std::string to_string() { return "Flatten(" + input->to_string() + ")"; } protected: - Tensor *_eval(bool recompute); + Tensor &_eval(bool recompute); Tensor &_grad(Operation *consumer, Operation *var, const Tensor &grad); Operation *input; diff --git a/include/compute/gradients.h b/include/compute/gradients.h index f44f64a..c95104e 100644 --- a/include/compute/gradients.h +++ b/include/compute/gradients.h @@ -10,7 +10,7 @@ #pragma once #include -#include "compute/add/addop.h" +#include "compute/binaryop/binaryop.h" #include "compute/gradtable.h" #include "compute/operation.h" #include "compute/sum/sumop.h" @@ -27,8 +27,7 @@ namespace op { * @param table GradTable to be filled in * @return magmadnn_error_t non-zero on error */ -template -magmadnn_error_t get_grad_table(const std::vector *> &vars, Operation *graph, GradTable &table); +magmadnn_error_t get_grad_table(const std::vector &vars, Operation *graph, GradTable &table); } // namespace op @@ -42,8 +41,7 @@ namespace internal { * @param table GradTable to put gradients in * @return magmadnn_error_t non-zero on error */ -template -magmadnn_error_t build_grad(op::Operation *var, op::Operation *graph, op::GradTable &table, Tensor **grad); +magmadnn_error_t build_grad(op::Operation *var, op::Operation *graph, op::GradTable &table, Tensor &grad); } // namespace internal } // namespace magmadnn \ No newline at end of file diff --git a/include/compute/gradtable.h b/include/compute/gradtable.h index bddbc9f..9bb56a5 100644 --- a/include/compute/gradtable.h +++ b/include/compute/gradtable.h @@ -20,7 +20,6 @@ namespace op { /** GradTable class. * @tparam T Numeric */ -template class GradTable { public: /** Constructs a new grad table. @@ -36,21 +35,21 @@ class GradTable { * @param var * @return Operation* */ - Tensor* get(Operation* var); + std::pair> get(Operation* var); /** Sets var's gradient to grad. * @param var * @param grad */ - void set(Operation* var, Tensor* grad); + void set(Operation* var, Tensor& grad); /** Removes all entries. */ void clear(); protected: - std::map*, Tensor*> _table; // the underlying table to store data - typename std::map*, Tensor*>::iterator tmp_map_iterator; + std::map> _table; // the underlying table to store data + typename std::map>::iterator tmp_map_iterator; }; } // namespace op diff --git a/include/compute/linearforward/linearforwardop.h b/include/compute/linearforward/linearforwardop.h index 17a2cea..a3a8c3f 100644 --- a/include/compute/linearforward/linearforwardop.h +++ b/include/compute/linearforward/linearforwardop.h @@ -21,7 +21,7 @@ class LinearForwardOp : public Operation { std::string to_string() { return "LinearForward(" + input->to_string() + ", " + weights->to_string() + ")"; } protected: - Tensor *_eval(bool recompute); + Tensor &_eval(bool recompute); Tensor &_grad(Operation *consumer, Operation *var, const Tensor &grad); void init_bias_settings(); /* init ones and bias_reduce_settings */ diff --git a/include/compute/log/logop.h b/include/compute/log/logop.h index 43cf9e1..91a0f74 100644 --- a/include/compute/log/logop.h +++ b/include/compute/log/logop.h @@ -19,7 +19,7 @@ class LogOp : public Operation { std::string to_string() { return "log( " + x->to_string() + " )"; } protected: - Tensor *_eval(bool recompute = true); + Tensor &_eval(bool recompute = true); Tensor &_grad(Operation *consumer, Operation *var, Tensor *grad); Operation *x; diff --git a/include/compute/matmul/matmulop.h b/include/compute/matmul/matmulop.h index 6f2a978..a252584 100644 --- a/include/compute/matmul/matmulop.h +++ b/include/compute/matmul/matmulop.h @@ -29,7 +29,7 @@ class MatmulOp : public Operation { std::string to_string() { return "(" + a->to_string() + " x " + b->to_string() + ")"; } protected: - Tensor *_eval(bool recompute = true); + Tensor &_eval(bool recompute = true); Tensor &_grad(Operation *consumer, Operation *var, const Tensor &grad); Operation *a; diff --git a/include/compute/negative/negativeop.h b/include/compute/negative/negativeop.h index 01c54ed..34acc4b 100644 --- a/include/compute/negative/negativeop.h +++ b/include/compute/negative/negativeop.h @@ -16,7 +16,7 @@ class NegativeOp : public Operation { std::string to_string() { return "-" + x->to_string() + ""; } protected: - Tensor *_eval(bool recompute = true); + Tensor &_eval(bool recompute = true); Tensor &_grad(Operation *consumer, Operation *var, const Tensor &grad); Operation *x; diff --git a/include/compute/op_utilities.h b/include/compute/op_utilities.h index f156332..9df5e45 100644 --- a/include/compute/op_utilities.h +++ b/include/compute/op_utilities.h @@ -16,8 +16,7 @@ namespace magmadnn { namespace op { namespace utility { -template -magmadnn_error_t print_compute_graph(::magmadnn::op::Operation *_root, bool debug = true); +magmadnn_error_t print_compute_graph(::magmadnn::op::Operation *_root, bool debug = true); } } // namespace op } // namespace magmadnn \ No newline at end of file diff --git a/include/compute/operation.h b/include/compute/operation.h index 0113be4..4955c38 100644 --- a/include/compute/operation.h +++ b/include/compute/operation.h @@ -29,6 +29,8 @@ class Operation { } } + virtual ~Operation() {} + virtual Operation &operator=(const Operation &o) = delete; /** Returns the expected output shape of this operation. @@ -77,6 +79,8 @@ class Operation { */ virtual memory_t get_memory_type() const { return this->mem_type_; } + virtual DataType dtype() const { return this->dtype_; }; + /** Returns the operation's evaluated tensor. * @param recompute whether to use previous value or recalculate * @return Tensor* @@ -174,6 +178,17 @@ class Operation { } } + inline void use_operation_settings(const Operation *o, bool use_shape = true) { + if (o == nullptr) { + return; + } + this->mem_type_ = o->get_memory_type(); + this->dtype_ = o->dtype(); + if (use_shape) { + this->output_shape_ = o->get_output_shape(); + } + } + std::vector inputs_; /* children */ std::vector consumers_; /* parents */ @@ -182,8 +197,7 @@ class Operation { DataType dtype_; /* TODO -- get rid of _grad_cache */ - std::map> - _grad_cache; /* this will cache the tensors for the gradient computation */ + std::map _grad_cache; /* this will cache the tensors for the gradient computation */ std::string name_ = "DefaultOpName"; Tensor output_tensor_; /* the return tensor */ diff --git a/include/compute/pooling/poolingop.h b/include/compute/pooling/poolingop.h index 60f873c..5090fa1 100644 --- a/include/compute/pooling/poolingop.h +++ b/include/compute/pooling/poolingop.h @@ -19,7 +19,7 @@ class PoolingOp : public Operation { std::string to_string() { return "Pooling(" + input->to_string() + ")"; } protected: - Tensor *_eval(bool recompute); + Tensor &_eval(bool recompute); Tensor &_grad(Operation *consumer, Operation *var, const Tensor &grad); void init_settings(); diff --git a/include/compute/pow/powop.h b/include/compute/pow/powop.h index 7d0d25e..84a2e6a 100644 --- a/include/compute/pow/powop.h +++ b/include/compute/pow/powop.h @@ -17,7 +17,7 @@ class PowOp : public Operation { std::string to_string() { return "POW(" + input->to_string() + ",)"; } protected: - Tensor *_eval(bool recompute); + Tensor &_eval(bool recompute); Tensor &_grad(Operation *consumer, Operation *var, const Tensor &grad); Operation *input; diff --git a/include/compute/product/productop.h b/include/compute/product/productop.h index aa57922..31c1669 100644 --- a/include/compute/product/productop.h +++ b/include/compute/product/productop.h @@ -32,7 +32,7 @@ class ProductOp : public Operation { std::string to_string() { return "(" + a->to_string() + " * " + b->to_string() + ")"; } protected: - Tensor *_eval(bool recompute = true); + Tensor &_eval(bool recompute = true); Tensor &_grad(Operation *consumer, Operation *var, const Tensor &grad); T alpha; diff --git a/include/compute/reducesum/reducesumop.h b/include/compute/reducesum/reducesumop.h index b4fd24f..620eff7 100644 --- a/include/compute/reducesum/reducesumop.h +++ b/include/compute/reducesum/reducesumop.h @@ -20,7 +20,7 @@ class ReduceSumOp : public Operation { std::string to_string() { return "ReduceSum( " + x->to_string() + " )"; } protected: - Tensor *_eval(bool recompute = true); + Tensor &_eval(bool recompute = true); Tensor &_grad(Operation *consumer, Operation *var, const Tensor &grad); Operation *x; diff --git a/include/compute/relu/reluop.h b/include/compute/relu/reluop.h index 8bf8975..8754bea 100644 --- a/include/compute/relu/reluop.h +++ b/include/compute/relu/reluop.h @@ -24,7 +24,7 @@ class ReluOp : public Operation { std::string to_string() { return "RELU( " + x->to_string() + " )"; } protected: - Tensor *_eval(bool recompute = true); + Tensor &_eval(bool recompute = true); Tensor &_grad(Operation *consumer, Operation *var, const Tensor &grad); Operation *x; diff --git a/include/compute/scalarproduct/scalarproductop.h b/include/compute/scalarproduct/scalarproductop.h index 541829c..ecf1bea 100644 --- a/include/compute/scalarproduct/scalarproductop.h +++ b/include/compute/scalarproduct/scalarproductop.h @@ -20,7 +20,7 @@ class ScalarProductOp : public Operation { std::string to_string(); protected: - Tensor *_eval(bool recompute = true); + Tensor &_eval(bool recompute = true); Tensor &_grad(Operation *consumer, Operation *var, const Tensor &grad); T alpha; diff --git a/include/compute/sigmoid/sigmoidop.h b/include/compute/sigmoid/sigmoidop.h index eb9bc04..844cd55 100644 --- a/include/compute/sigmoid/sigmoidop.h +++ b/include/compute/sigmoid/sigmoidop.h @@ -30,7 +30,7 @@ class SigmoidOp : public Operation { std::string to_string() { return "SIGMOID( " + x->to_string() + " )"; } protected: - Tensor *_eval(bool recompute = true); + Tensor &_eval(bool recompute = true); Tensor &_grad(Operation *consumer, Operation *var, const Tensor &grad); Operation *x; diff --git a/include/compute/softmax/softmaxop.h b/include/compute/softmax/softmaxop.h index 9587fcd..ea22ad5 100644 --- a/include/compute/softmax/softmaxop.h +++ b/include/compute/softmax/softmaxop.h @@ -21,7 +21,7 @@ class SoftmaxOp : public Operation { std::string to_string() { return "Softmax(" + input->to_string() + ")"; } protected: - Tensor *_eval(bool recompute); + Tensor &_eval(bool recompute); Tensor &_grad(Operation *consumer, Operation *var, const Tensor &grad); Operation *input; diff --git a/include/compute/sum/sum_internal.h b/include/compute/sum/sum_internal.h index fc412ae..26de087 100644 --- a/include/compute/sum/sum_internal.h +++ b/include/compute/sum/sum_internal.h @@ -18,12 +18,11 @@ namespace internal { * @tparam T * @param vals */ -template -void sum_full(std::vector *> &vals, Tensor &out); +void sum_full(const std::vector> &vals, Tensor &out); #if defined(_HAS_CUDA_) template -void sum_full_device(std::vector *> &vals, Tensor &out); +void sum_full_device(const std::vector> &vals, Tensor &out); #endif } // namespace internal diff --git a/include/compute/sum/sumop.h b/include/compute/sum/sumop.h index 8c3f987..ce07100 100644 --- a/include/compute/sum/sumop.h +++ b/include/compute/sum/sumop.h @@ -10,6 +10,7 @@ #include #include +#include "compute/compute_graph.h" #include "compute/operation.h" #include "compute/sum/sum_internal.h" #include "tensor/tensor.h" @@ -18,23 +19,20 @@ namespace magmadnn { namespace op { -template -class SumOp : public Operation { +class SumOp : public Operation { public: - SumOp(std::vector *> ops, bool copy = true); + SumOp(std::vector ops); - std::string to_string(); + std::string to_string() const override; protected: - Tensor *_eval(bool recompute = true); - Tensor &_grad(Operation *consumer, Operation *var, const Tensor &grad); + Tensor &_eval(bool recompute = true) override; + Tensor &_grad(Operation *consumer, Operation *var, const Tensor &grad) override; - std::vector *> ops; - bool copy; + std::vector ops; }; -template -Operation *sum(std::vector *> ops, bool copy = true); +inline Operation *sum(std::vector ops) { return default_graph.add_operation(ops); } } // namespace op } // namespace magmadnn \ No newline at end of file diff --git a/include/compute/tanh/tanhop.h b/include/compute/tanh/tanhop.h index b6162fb..5cb9425 100644 --- a/include/compute/tanh/tanhop.h +++ b/include/compute/tanh/tanhop.h @@ -25,7 +25,7 @@ class TanhOp : public Operation { std::string to_string() { return "TANH( " + x->to_string() + " )"; } protected: - Tensor *_eval(bool recompute = true); + Tensor &_eval(bool recompute = true); Tensor &_grad(Operation *consumer, Operation *var, const Tensor &grad); Operation *x; diff --git a/include/compute/transpose/transposeop.h b/include/compute/transpose/transposeop.h index b110746..5a090ba 100644 --- a/include/compute/transpose/transposeop.h +++ b/include/compute/transpose/transposeop.h @@ -17,7 +17,7 @@ class TransposeOp : public Operation { std::string to_string() { return x->to_string() + ".T"; } protected: - Tensor *_eval(bool recompute = true); + Tensor &_eval(bool recompute = true); Tensor &_grad(Operation *consumer, Operation *var, const Tensor &grad); Operation *x; diff --git a/include/compute/unaryop/unaryop.h b/include/compute/unaryop/unaryop.h new file mode 100644 index 0000000..c7be2dd --- /dev/null +++ b/include/compute/unaryop/unaryop.h @@ -0,0 +1,41 @@ +/** + * @file unaryop.h + * @author Daniel Nichols + * @version 0.1 + * @date 2019-08-09 + * + * @copyright Copyright (c) 2019 + */ +#pragma once + +#include "compute/compute_graph.h" +#include "compute/operation.h" + +namespace magmadnn { +namespace op { + +template +class UnaryOp : public Operation { + public: + UnaryOp(Operation *x); + + std::string to_string() const override { return "UNARY_OP(" + x->to_string() + ")"; } + + protected: + Tensor &_eval(bool recompute = true) override; + Tensor &_grad(Operation *consumer, Operation *var, const Tensor &grad) override; + + Operation *x; +}; + +#define MAKE_UNARY(name) \ + inline Operation *name(Operation *a, Operation *out) { \ + return default_graph.add_operation>(a, out); \ + } + +MAKE_UNARY(log) + +#undef MAKE_UNARY + +} // namespace op +} // namespace magmadnn \ No newline at end of file diff --git a/include/compute/variable.h b/include/compute/variable.h index 6919564..109599a 100644 --- a/include/compute/variable.h +++ b/include/compute/variable.h @@ -68,8 +68,8 @@ inline Operation *var(std::string name, const std::vector &shape, DataT */ template inline Operation *scalar(std::string name, T val, memory_t mem_type) { - return default_graph.add_operation(name, 1, ::GetDataType::value, {CONSTANT, static_cast(val)}, - mem_type); + return default_graph.add_operation(name, 1, ::magmadnn::GetDataType::value, + {CONSTANT, static_cast(val)}, mem_type); } } // namespace op } // namespace magmadnn diff --git a/include/math/batchnorm.h b/include/math/batchnorm.h index 8432857..694f05e 100644 --- a/include/math/batchnorm.h +++ b/include/math/batchnorm.h @@ -19,10 +19,8 @@ namespace magmadnn { namespace math { -template void batchnorm(const Tensor &x, Tensor &out); -template void batchnorm_grad(const Tensor &grad, Tensor &out); #if defined(_HAS_CUDA_) @@ -33,11 +31,10 @@ struct cudnn_batchnorm_settings_t { cudnnTensorDescriptor_t bn_tensor_desc; }; -template void batchnorm_device(const Tensor &x, Tensor &out, Tensor &bn_scale, Tensor &bn_bias, Tensor &running_mean, Tensor &running_variance, Tensor &saved_mean, Tensor &saved_variance, unsigned int &num_calls, cudnn_batchnorm_settings_t settings); -template + void batchnorm_grad_device(const Tensor &x, const Tensor &grad, Tensor &out, Tensor &bn_scale, Tensor &bn_scale_diff, Tensor &bn_bias_diff, Tensor &saved_mean, Tensor &saved_variance, cudnn_batchnorm_settings_t settings); diff --git a/scripts/operation_templates/op_header_template.h b/scripts/operation_templates/op_header_template.h index 284b67c..ac39979 100644 --- a/scripts/operation_templates/op_header_template.h +++ b/scripts/operation_templates/op_header_template.h @@ -1,33 +1,32 @@ #pragma once +#include "compute/<#OPERATION_NAME_LOWER#>/<#OPERATION_NAME_LOWER#>_internal.h" #include "compute/operation.h" #include "tensor/tensor.h" -#include "compute/<#OPERATION_NAME_LOWER#>/<#OPERATION_NAME_LOWER#>_internal.h" namespace magmadnn { namespace op { template -class <#OPERATION_NAME#>Op : public Operation { -public: - <#OPERATION_NAME#>Op(Operation *input, bool copy=true, bool needs_grad=true); +class<#OPERATION_NAME #> Op : public Operation { + public: + <#OPERATION_NAME #> Op(Operation *input, bool copy = true, bool needs_grad = true); - - std::string to_string() { return ""; } -protected: - Tensor *_eval(bool recompute); - Tensor *_grad(Operation *consumer, Operation *var, Tensor *grad); + std::string to_string() { return ""; } - Operation *input; - Tensor *input_tensor; + protected: + Tensor &_eval(borecompute); + Tensor *_grad(Operation *consumer, Operation *var, Tensor *grad); - bool copy; + Operation *input; + Tensor *input_tensor; + bool copy; }; template -<#OPERATION_NAME#>Op* <#OPERATION_NAME_LOWER#>(Operation *input, bool copy=true, bool needs_grad=true); +<#OPERATION_NAME #> Op *<#OPERATION_NAME_LOWER #>(Operation *input, bool copy = true, bool needs_grad = true); -} // namespace op -} // namespace magmadnn \ No newline at end of file +} // namespace op +} // namespace magmadnn \ No newline at end of file diff --git a/src/compute/add/addop.cpp b/src/compute/add/addop.cpp deleted file mode 100644 index 9988e92..0000000 --- a/src/compute/add/addop.cpp +++ /dev/null @@ -1,68 +0,0 @@ -/** - * @file add_op.cpp - * @author Daniel Nichols - * @version 1.0 - * @date 2019-02-20 - * - * @copyright Copyright (c) 2019 - */ -#include "compute/add/addop.h" - -namespace magmadnn { -namespace op { - -template -AddOp::AddOp(Operation *a, Operation *b, bool copy, bool needs_grad) - : Operation::Operation({a, b}, needs_grad), a(a), b(b), copy(copy) { - assert(a->get_memory_type() == b->get_memory_type()); - assert(a->get_output_size() == b->get_output_size() || a->get_output_size() == 1 || b->get_output_size() == 1); - - /* if a is scalar then use b's size */ - if (a->get_output_size() == 1) { - this->output_shape = b->get_output_shape(); - } else { - /* other wise a's size is good */ - this->output_shape = a->get_output_shape(); - } - this->mem_type = a->get_memory_type(); - - /* Go ahead and create copy tensor if we can */ - this->output_tensor = new Tensor(this->output_shape, {NONE, {}}, this->mem_type); -} - -template -Tensor *AddOp::_eval(bool recompute) { - a_tensor = a->eval(recompute); - b_tensor = b->eval(recompute); - - if (a_tensor->get_size() == 1) { - a_tensor->get_memory_manager()->sync(true); - internal::tensor_scalar_add_full(a_tensor->get(0), b_tensor, this->output_tensor); - } else if (b_tensor->get_size() == 1) { - b_tensor->get_memory_manager()->sync(true); - internal::tensor_scalar_add_full(b_tensor->get(0), a_tensor, this->output_tensor); - } else { - internal::geadd_full((T) 1, a_tensor, (T) 1, b_tensor, this->output_tensor); - } - - return this->output_tensor; -} - -template -Tensor *AddOp::_grad(Operation *consumer, Operation *var, Tensor *grad) { - return grad; -} -template class AddOp; -template class AddOp; -template class AddOp; - -template -AddOp *add(Operation *a, Operation *b, bool copy, bool needs_grad) { - return new AddOp(a, b, copy, needs_grad); -} -template AddOp *add(Operation *a, Operation *b, bool copy, bool needs_grad); -template AddOp *add(Operation *a, Operation *b, bool copy, bool needs_grad); -template AddOp *add(Operation *a, Operation *b, bool copy, bool needs_grad); - -} // namespace op -} // namespace magmadnn \ No newline at end of file diff --git a/src/compute/add/geadd_internal.cpp b/src/compute/add/geadd_internal.cpp deleted file mode 100644 index 8b3224c..0000000 --- a/src/compute/add/geadd_internal.cpp +++ /dev/null @@ -1,71 +0,0 @@ -/** - * @file geadd_internal.cpp - * @author Daniel Nichols - * @version 1.0 - * @date 2019-02-22 - * - * @copyright Copyright (c) 2019 - */ -#include "compute/add/geadd_internal.h" - -namespace magmadnn { -namespace internal { - -template -bool geadd_check(Tensor *A, Tensor *B, Tensor *C) { - assert(A->get_shape().size() == 2); - assert(B->get_shape().size() == 2); - assert(C->get_shape().size() == 2); - - assert(A->get_shape(0) == B->get_shape(0)); - assert(A->get_shape(0) == C->get_shape(0)); - assert(A->get_shape(1) == B->get_shape(1)); - assert(A->get_shape(1) == C->get_shape(1)); - return true; -} - -template -void geadd_full(T alpha, Tensor *A, T beta, Tensor *B, Tensor *C) { - if (A->get_memory_type() == HOST) { - T *a_ptr = A->get_ptr(); - T *b_ptr = B->get_ptr(); - T *c_ptr = C->get_ptr(); - unsigned int size = A->get_size(); - - for (unsigned int i = 0; i < size; i++) { - c_ptr[i] = (alpha * a_ptr[i]) + (beta * b_ptr[i]); - } - } -#if defined(_HAS_CUDA_) - else { - geadd_full_device(alpha, A, beta, B, C); - } -#endif -} -template void geadd_full(int alpha, Tensor *A, int beta, Tensor *B, Tensor *C); -template void geadd_full(float alpha, Tensor *A, float beta, Tensor *B, Tensor *C); -template void geadd_full(double alpha, Tensor *A, double beta, Tensor *B, Tensor *C); - -template -void tensor_scalar_add_full(T alpha, Tensor *x, Tensor *out) { - if (out->get_memory_type() == HOST) { - T *x_ptr = x->get_ptr(); - T *out_ptr = out->get_ptr(); - unsigned int size = out->get_size(); - - for (unsigned int i = 0; i < size; i++) { - out_ptr[i] = alpha + x_ptr[i]; - } - } -#if defined(_HAS_CUDA_) - else { - tensor_scalar_add_full_device(alpha, x, out); - } -#endif -} -template void tensor_scalar_add_full(int alpha, Tensor *x, Tensor *out); -template void tensor_scalar_add_full(float alpha, Tensor *x, Tensor *out); -template void tensor_scalar_add_full(double alpha, Tensor *x, Tensor *out); - -} // namespace internal -} // namespace magmadnn diff --git a/src/compute/add/geadd_internal_device.cu b/src/compute/add/geadd_internal_device.cu deleted file mode 100644 index e30d748..0000000 --- a/src/compute/add/geadd_internal_device.cu +++ /dev/null @@ -1,59 +0,0 @@ -/** - * @file geadd_internal_device.cu - * @author Daniel Nichols - * @version 1.0 - * @date 2019-02-22 - * - * @copyright Copyright (c) 2019 - */ -#include "compute/add/geadd_internal.h" - -#define BLK_SIZE 1024 - -namespace magmadnn { -namespace internal { - -template -__global__ void kernel_geadd_full_device(T alpha, T *A, T beta, T *B, T *C, unsigned int size) { - unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; - unsigned int stride = blockDim.x * gridDim.x; - - for (unsigned int i = idx; i < size; i += stride) { - C[i] = alpha * A[i] + beta * B[i]; - } -} - -template -void geadd_full_device(T alpha, Tensor *A, T beta, Tensor *B, Tensor *C) { - unsigned int size = C->get_size(); - kernel_geadd_full_device<<<(size + BLK_SIZE - 1) / BLK_SIZE, BLK_SIZE>>>(alpha, A->get_ptr(), beta, B->get_ptr(), - C->get_ptr(), size); -} -template void geadd_full_device(int alpha, Tensor *A, int beta, Tensor *B, Tensor *C); -template void geadd_full_device(float alpha, Tensor *A, float beta, Tensor *B, Tensor *C); -template void geadd_full_device(double alpha, Tensor *A, double beta, Tensor *B, Tensor *C); - -template -__global__ void kernel_tensor_scalar_add_full_device(T alpha, T *x, T *out, unsigned int arr_size) { - unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; - unsigned int stride = blockDim.x * gridDim.x; - - for (unsigned int i = idx; i < arr_size; i += stride) { - out[i] = alpha + x[i]; - } -} - -template -void tensor_scalar_add_full_device(T alpha, Tensor *x, Tensor *out) { - unsigned int size = out->get_size(); - kernel_tensor_scalar_add_full_device<<<(size + BLK_SIZE - 1) / BLK_SIZE, BLK_SIZE>>>(alpha, x->get_ptr(), - out->get_ptr(), size); -} -template void tensor_scalar_add_full_device(int alpha, Tensor *x, Tensor *out); -template void tensor_scalar_add_full_device(float alpha, Tensor *x, Tensor *out); -template void tensor_scalar_add_full_device(double alpha, Tensor *x, Tensor *out); - -} // namespace internal -} // namespace magmadnn - -#undef BLK_SIZE diff --git a/src/compute/batchnorm/batchnormop.cpp b/src/compute/batchnorm/batchnormop.cpp index f27ac17..37a38d0 100644 --- a/src/compute/batchnorm/batchnormop.cpp +++ b/src/compute/batchnorm/batchnormop.cpp @@ -4,55 +4,47 @@ namespace magmadnn { namespace op { -template -BatchNormOp::BatchNormOp(Operation *input, bool needs_grad) - : Operation::Operation({input}, needs_grad), input(input), num_calls(0) { +BatchNormOp::BatchNormOp(Operation *input) : Operation::Operation({input}), input(input), num_calls(0) { /* setup code in here */ - this->output_shape = input->get_output_shape(); - this->mem_type = input->get_memory_type(); - this->name = "BatchNorm"; + this->use_operation_settings(input, true); + this->name_ = "BatchNorm"; - this->input_tensor = input->get_output_tensor(); - this->output_tensor = new Tensor(this->output_shape, {NONE, {}}, this->mem_type); + this->output_tensor_ = Tensor(this->output_shape_, this->dtype_, {NONE}, this->mem_type_); #if defined(_HAS_CUDA_) init_settings(); #endif } -template -BatchNormOp::~BatchNormOp() {} - -template -Tensor *BatchNormOp::_eval(bool recompute) { +Tensor &BatchNormOp::_eval(bool recompute) { /* eval code in here ... */ - input_tensor = input->eval(recompute); + Tensor &input_tensor = input->eval(recompute); - if (this->mem_type == HOST) { - math::batchnorm(input_tensor, this->output_tensor); + if (this->mem_type_ == HOST) { + math::batchnorm(input_tensor, this->output_tensor_); } #if defined(_HAS_CUDA_) else { - math::batchnorm_device(input_tensor, this->output_tensor, bn_scale, bn_bias, running_mean, running_variance, + math::batchnorm_device(input_tensor, this->output_tensor_, bn_scale, bn_bias, running_mean, running_variance, saved_mean, saved_variance, num_calls, this->settings); } #endif - return this->output_tensor; + return this->output_tensor_; } -template -Tensor *BatchNormOp::_grad(Operation *consumer, Operation *var, Tensor *grad) { +Tensor &BatchNormOp::_grad(Operation *consumer, Operation *var, const Tensor &grad) { /* return gradient in here ... */ - Tensor *out = this->_grad_cache[(uintptr_t) var]; + auto res = this->_grad_cache.find(var); + Tensor out; - if (out == NULL) { - out = new Tensor(this->output_shape, {NONE, {}}, this->mem_type); - this->_grad_cache[(uintptr_t) var] = out; + if (!res->first) { + out = Tensor(this->output_shape_, this->dtype_, {NONE, {}}, this->mem_type_); + this->_grad_cache.insert(std::make_pair(var, std::ref(out))); } - if (this->mem_type == HOST) { + if (this->mem_type_ == HOST) { math::batchnorm_grad(grad, out); } #if defined(_HAS_CUDA_) @@ -62,12 +54,11 @@ Tensor *BatchNormOp::_grad(Operation *consumer, Operation *var, Tens } #endif - return out; + return this->_grad_cache[var]; } #if defined(_HAS_CUDA_) -template -void BatchNormOp::init_settings() { +void BatchNormOp::init_settings() { settings.handle = ::magmadnn::internal::MAGMADNN_SETTINGS->cudnn_handle; /* Use spatial if 4D (conv layer), and use per activation if 2D (fully connected layer) */ @@ -84,29 +75,17 @@ void BatchNormOp::init_settings() { bn_tensor_shape[3] = 1; } - bn_scale = new Tensor(bn_tensor_shape, {ONE, {}}, this->mem_type); - bn_bias = new Tensor(bn_tensor_shape, {ONE, {}}, this->mem_type); - bn_scale_diff = new Tensor(bn_tensor_shape, {ONE, {}}, this->mem_type); - bn_bias_diff = new Tensor(bn_tensor_shape, {ONE, {}}, this->mem_type); - running_mean = new Tensor(bn_tensor_shape, {ZERO, {}}, this->mem_type); - running_variance = new Tensor(bn_tensor_shape, {ZERO, {}}, this->mem_type); - saved_mean = new Tensor(bn_tensor_shape, {ZERO, {}}, this->mem_type); - saved_variance = new Tensor(bn_tensor_shape, {ZERO, {}}, this->mem_type); + bn_scale = Tensor(bn_tensor_shape, {ONE, {}}, this->mem_type); + bn_bias = Tensor(bn_tensor_shape, {ONE, {}}, this->mem_type); + bn_scale_diff = Tensor(bn_tensor_shape, {ONE, {}}, this->mem_type); + bn_bias_diff = Tensor(bn_tensor_shape, {ONE, {}}, this->mem_type); + running_mean = Tensor(bn_tensor_shape, {ZERO, {}}, this->mem_type); + running_variance = Tensor(bn_tensor_shape, {ZERO, {}}, this->mem_type); + saved_mean = Tensor(bn_tensor_shape, {ZERO, {}}, this->mem_type); + saved_variance = Tensor(bn_tensor_shape, {ZERO, {}}, this->mem_type); } #endif -template class BatchNormOp; -template class BatchNormOp; -template class BatchNormOp; - -template -BatchNormOp *batchnorm(Operation *input, bool needs_grad) { - return new BatchNormOp(input, needs_grad); -} -template BatchNormOp *batchnorm(Operation *input, bool needs_grad); -template BatchNormOp *batchnorm(Operation *input, bool needs_grad); -template BatchNormOp *batchnorm(Operation *input, bool needs_grad); - } // namespace op } // namespace magmadnn \ No newline at end of file diff --git a/src/compute/conv2dforward/conv2dforward_internal.cpp b/src/compute/conv2dforward/conv2dforward_internal.cpp deleted file mode 100644 index 9edaffd..0000000 --- a/src/compute/conv2dforward/conv2dforward_internal.cpp +++ /dev/null @@ -1,18 +0,0 @@ - -#include "compute/conv2dforward/conv2dforward_internal.h" - -namespace magmadnn { -namespace internal { - -template -void conv2dforward_full(Tensor *in, Tensor *out) { - /* - - */ -} -template void conv2dforward_full(Tensor *in, Tensor *out); -template void conv2dforward_full(Tensor *in, Tensor *out); -template void conv2dforward_full(Tensor *in, Tensor *out); - -} // namespace internal -} // namespace magmadnn \ No newline at end of file diff --git a/src/compute/conv2dforward/conv2dforwardop.cpp b/src/compute/conv2dforward/conv2dforwardop.cpp index 73db0a3..265bf8b 100644 --- a/src/compute/conv2dforward/conv2dforwardop.cpp +++ b/src/compute/conv2dforward/conv2dforwardop.cpp @@ -4,11 +4,10 @@ namespace magmadnn { namespace op { -template -Conv2DForwardOp::Conv2DForwardOp(Operation *input, Operation *filter, int pad_h, int pad_w, - int vertical_stride, int horizontal_stride, int dilation_h, int dilation_w, - bool use_cross_correlation, bool needs_grad) - : Operation::Operation({input, filter}, needs_grad), +Conv2DForwardOp::Conv2DForwardOp(Operation *input, Operation *filter, int pad_h, int pad_w, int vertical_stride, + int horizontal_stride, int dilation_h, int dilation_w, bool use_cross_correlation, + bool needs_grad) + : Operation::Operation({input, filter}, needs_grad), input(input), filter(filter), pad_h(pad_h), @@ -19,16 +18,16 @@ Conv2DForwardOp::Conv2DForwardOp(Operation *input, Operation *filter, i dilation_w(dilation_w), use_cross_correlation(use_cross_correlation) { /* setup code in here */ - this->mem_type = input->get_memory_type(); + this->mem_type_ = input->get_memory_type(); /* initialize all the conv settings */ - this->input_tensor = this->input->get_output_tensor(); + this->dtype_ = this->input->dtype(); this->init_settings(); } -template -Conv2DForwardOp::~Conv2DForwardOp() { - if (this->mem_type == HOST) { +Conv2DForwardOp::~Conv2DForwardOp() { + if (this->mem_type_ == HOST) { + /* delete CPU workspace here */ } #if defined(_HAS_CUDA_) else { @@ -43,73 +42,77 @@ Conv2DForwardOp::~Conv2DForwardOp() { #endif } -template -Tensor *Conv2DForwardOp::_eval(bool recompute) { - input_tensor = input->eval(recompute); - filter_tensor = filter->eval(recompute); +Tensor &Conv2DForwardOp::_eval(bool recompute) { + Tensor &input_tensor = input->eval(recompute); + Tensor &filter_tensor = filter->eval(recompute); - if (this->mem_type == HOST) { - std::fprintf(stderr, "Error: Conv2dForward::_eval requires GPU\n"); - } + switch (getDeviceType(this->mem_type_)) { + case CPU: + ::magmadnn::math::conv2d(input_tensor, filter_tensor, this->output_tensor_); #if defined(_HAS_CUDA_) - else { - ::magmadnn::math::conv2d_device(this->input_tensor, this->filter_tensor, this->output_tensor, - this->cudnn_settings); - } + case GPU: /* only wrap this in a macro, because cudnn_settings won't be defined on host */ + ::magmadnn::math::conv2d(input_tensor, filter_tensor, this->output_tensor_, this->cudnn_settings); #endif + default: + ::magmadnn::math::conv2d(input_tensor, filter_tensor, this->output_tensor_); + } - return this->output_tensor; + return this->output_tensor_; } -template -Tensor *Conv2DForwardOp::_grad(Operation *consumer, Operation *var, Tensor *grad) { +Tensor &Conv2DForwardOp::_grad(Operation *consumer, Operation *var, const Tensor &grad) { /* return gradient in here ... */ - Tensor *out = this->_grad_cache[(uintptr_t) var]; + + auto ret = this->_grad_cache.find(var); + Tensor out; if (var == this->input) { - if (out == NULL) { - out = new Tensor(this->input->get_output_shape(), {NONE, {}}, this->mem_type); - this->_grad_cache[(uintptr_t) var] = out; + if (!ret->first) { + out = Tensor(this->input->get_output_shape(), this->dtype_, {NONE, {}}, this->mem_type_); + this->_grad_cache.insert(std::make_pair(var, out)); } - this->filter_tensor = this->filter->eval(false); + Tensor &filter_tensor = this->filter->eval(false); - if (this->mem_type == HOST) { - ::magmadnn::math::conv2d_grad_data(this->filter_tensor, grad, out); - } + switch (getDeviceType(this->mem_type_)) { + case CPU: + ::magmadnn::math::conv2d_grad_data(filter_tensor, grad, out); #if defined(_HAS_CUDA_) - else { - ::magmadnn::math::conv2d_grad_data_device(this->filter_tensor, grad, out, this->cudnn_settings); - } + case GPU: + ::magmadnn::math::conv2d_grad_data(filter_tensor, grad, out, this->cudnn_settings); #endif + default: + LOG(ERROR) << "Unsupported conv2D type.\n"; + } } else if (var == this->filter) { - if (out == NULL) { - out = new Tensor(this->filter->get_output_shape(), {NONE, {}}, this->mem_type); - this->_grad_cache[(uintptr_t) var] = out; + if (!ret->first) { + out = Tensor(this->filter->get_output_shape(), this->dtype_, {NONE, {}}, this->mem_type_); + this->_grad_cache.insert(std::make_pair(var, out)); } - this->input_tensor = this->input->eval(false); + Tensor &input_tensor = this->input->eval(false); - if (this->mem_type == HOST) { - ::magmadnn::math::conv2d_grad_filter(this->input_tensor, grad, out); - } + switch (getDeviceType(this->mem_type_)) { + case CPU: + ::magmadnn::math::conv2d_grad_filter(input_tensor, grad, out); #if defined(_HAS_CUDA_) - else { - ::magmadnn::math::conv2d_grad_filter_device(this->input_tensor, grad, out, this->cudnn_settings); - } + case GPU: + ::magmadnn::math::conv2d_grad_filter(input_tensor, grad, out, this->cudnn_settings); #endif + default: + LOG(ERROR) << "Unsupported conv2D type.\n"; + } } else { std::fprintf(stderr, "Error: bad conv2d grad\n"); } - return out; + return this->_grad_cache[var]; } -template -void Conv2DForwardOp::init_settings() { - if (this->mem_type == HOST) { +void Conv2DForwardOp::init_settings() { + if (this->mem_type_ == HOST) { std::fprintf(stderr, "Error: Conv2DForward::init_settings requires GPU.\n"); } #if defined(_HAS_CUDA_) @@ -187,12 +190,11 @@ void Conv2DForwardOp::init_settings() { #endif } -template -void Conv2DForwardOp::calculate_and_set_output_shape() { +void Conv2DForwardOp::calculate_and_set_output_shape() { /* calculate the correct output shape here */ - if (this->mem_type == HOST) { + if (this->mem_type_ == HOST) { std::fprintf(stderr, "Error: Conv2dForward::output_shape requires GPU.\n"); - this->output_shape = this->input->get_output_shape(); + this->output_shape_ = this->input->get_output_shape(); } #if defined(_HAS_CUDA_) else { @@ -201,34 +203,13 @@ void Conv2DForwardOp::calculate_and_set_output_shape() { cudnnErrchk(cudnnGetConvolution2dForwardOutputDim(this->cudnn_settings.conv_desc, this->input_tensor->get_cudnn_tensor_descriptor(), this->cudnn_settings.filter_desc, &n, &c, &h, &w)); - this->output_shape = {static_cast(n), static_cast(c), static_cast(h), - static_cast(w)}; + this->output_shape_ = {static_cast(n), static_cast(c), static_cast(h), + static_cast(w)}; } #endif - this->output_tensor = new Tensor(this->output_shape, {NONE, {}}, this->mem_type); -} - -template class Conv2DForwardOp; -template class Conv2DForwardOp; -template class Conv2DForwardOp; - -template -Conv2DForwardOp *conv2dforward(Operation *input, Operation *filter, int pad_h, int pad_w, int vertical_stride, - int horizontal_stride, int dilation_h, int dilation_w, bool use_cross_correlation, - bool needs_grad) { - return new Conv2DForwardOp(input, filter, pad_h, pad_w, vertical_stride, horizontal_stride, dilation_h, - dilation_w, use_cross_correlation, needs_grad); + this->output_tensor_ = Tensor(this->output_shape_, this->dtype_, {NONE, {}}, this->mem_type_); } -template Conv2DForwardOp *conv2dforward(Operation *input, Operation *filter, int pad_h, int pad_w, - int vertical_stride, int horizontal_stride, int dilation_h, int dilation_w, - bool use_cross_correlation, bool needs_grad); -template Conv2DForwardOp *conv2dforward(Operation *input, Operation *filter, int pad_h, int pad_w, - int vertical_stride, int horizontal_stride, int dilation_h, - int dilation_w, bool use_cross_correlation, bool needs_grad); -template Conv2DForwardOp *conv2dforward(Operation *input, Operation *filter, int pad_h, - int pad_w, int vertical_stride, int horizontal_stride, int dilation_h, - int dilation_w, bool use_cross_correlation, bool needs_grad); } // namespace op } // namespace magmadnn \ No newline at end of file diff --git a/src/compute/crossentropy/crossentropy_internal.cpp b/src/compute/crossentropy/crossentropy_internal.cpp deleted file mode 100644 index d342e47..0000000 --- a/src/compute/crossentropy/crossentropy_internal.cpp +++ /dev/null @@ -1,53 +0,0 @@ - -#include "compute/crossentropy/crossentropy_internal.h" - -namespace magmadnn { -namespace internal { - -template -void crossentropy_full(Tensor *x, Tensor *y, Tensor *softmax, Tensor *out) { - if (out->get_memory_type() == HOST) { - T *x_ptr = x->get_ptr(); - T *y_ptr = y->get_ptr(); - T *softmax_ptr = softmax->get_ptr(); - T *out_ptr = out->get_ptr(); - unsigned int x_size = x->get_size(); - T x_max = x_ptr[0]; - T exps_sum = (T) 0; - unsigned int n_rows = x->get_shape(0); - - /* compute max of x */ - for (unsigned int i = 1; i < x_size; i++) { - if (x_ptr[i] > x_max) { - x_max = x_ptr[i]; - } - } - - /* softmax = exp(x- max(x)). also sum exp elements */ - for (unsigned int i = 0; i < x_size; i++) { - softmax_ptr[i] = exp(x_ptr[i] - x_max); - exps_sum += softmax_ptr[i]; - } - - /* divide each exp by exps_sum */ - for (unsigned int i = 0; i < x_size; i++) { - softmax_ptr[i] /= exps_sum; - } - - for (unsigned int i = 0; i < x_size; i++) { - out_ptr[0] += y_ptr[i] * log(softmax_ptr[i]); - } - out_ptr[0] /= -((T) n_rows); - } -#if defined(_HAS_CUDA_) - else { - crossentropy_full_device(x, y, softmax, out); - } -#endif -} -template void crossentropy_full(Tensor *x, Tensor *y, Tensor *softmax, Tensor *out); -template void crossentropy_full(Tensor *x, Tensor *y, Tensor *softmax, Tensor *out); -template void crossentropy_full(Tensor *x, Tensor *y, Tensor *softmax, Tensor *out); - -} // namespace internal -} // namespace magmadnn \ No newline at end of file diff --git a/src/compute/crossentropy/crossentropy_internal_device.cu b/src/compute/crossentropy/crossentropy_internal_device.cu deleted file mode 100644 index 3c46dba..0000000 --- a/src/compute/crossentropy/crossentropy_internal_device.cu +++ /dev/null @@ -1,18 +0,0 @@ - -#include "compute/crossentropy/crossentropy_internal.h" - -namespace magmadnn { -namespace internal { - -template -__global__ void kernel_crossentropy_full_device(T *x, T *y, T *softmax, T *out) {} - -template -void crossentropy_full_device(Tensor *x, Tensor *y, Tensor *softmax, Tensor *out) {} -template void crossentropy_full_device(Tensor *x, Tensor *y, Tensor *softmax, Tensor *out); -template void crossentropy_full_device(Tensor *x, Tensor *y, Tensor *softmax, Tensor *out); -template void crossentropy_full_device(Tensor *x, Tensor *y, Tensor *softmax, - Tensor *out); - -} // namespace internal -} // namespace magmadnn \ No newline at end of file diff --git a/src/compute/crossentropy/crossentropyop.cpp b/src/compute/crossentropy/crossentropyop.cpp index 8f16d80..9689111 100644 --- a/src/compute/crossentropy/crossentropyop.cpp +++ b/src/compute/crossentropy/crossentropyop.cpp @@ -4,58 +4,9 @@ namespace magmadnn { namespace op { -template -CrossEntropyOp::CrossEntropyOp(Operation *x, Operation *y, bool copy, bool needs_grad) - : Operation::Operation({x, y}, needs_grad), x(x), y(y), copy(copy) { - /* x should be (n_samples x n_classes) - y should be (n_samples x n_classes) - */ - assert(OP_IS_MATRIX(x)); - assert(OP_IS_MATRIX(y)); - assert(x->get_output_shape(0) == y->get_output_shape(0)); - assert(x->get_output_shape(1) == y->get_output_shape(1)); - - this->output_shape = {1}; - this->mem_type = x->get_memory_type(); - - if (copy) { - this->output_tensor = new Tensor(this->output_shape, {NONE, {}}, this->mem_type); - } else { - std::fprintf(stderr, "no copy cross entropy not supported yet.\n"); - } -} - -template -CrossEntropyOp::~CrossEntropyOp() {} - -template -Tensor *CrossEntropyOp::_eval(bool recompute) { - x_tensor = x->eval(recompute); - y_tensor = y->eval(recompute); - - // internal::crossentropy_full(x_tensor, y_tensor, this->softmax, this->output_tensor); - math::crossentropy(x_tensor, y_tensor, this->output_tensor); - - return this->output_tensor; -} - -template -Tensor *CrossEntropyOp::_grad(Operation *consumer, Operation *var, Tensor *grad) { - this->_grad_cache[(uintptr_t) var] = grad; - return grad; -} - -template class CrossEntropyOp; -template class CrossEntropyOp; -template class CrossEntropyOp; - -template -Operation *crossentropy(Operation *ground_truth, Operation *predicted, bool copy, bool needs_grad) { +Operation *crossentropy(Operation *ground_truth, Operation *predicted, bool needs_grad) { return negative(reducesum(reducesum(product(ground_truth, log(predicted, true)), 1), 0)); } -template Operation *crossentropy(Operation *, Operation *, bool, bool); -template Operation *crossentropy(Operation *, Operation *, bool, bool); -template Operation *crossentropy(Operation *, Operation *, bool, bool); } // namespace op } // namespace magmadnn \ No newline at end of file diff --git a/src/compute/gradients.cpp b/src/compute/gradients.cpp index 849249e..09dd209 100644 --- a/src/compute/gradients.cpp +++ b/src/compute/gradients.cpp @@ -8,13 +8,14 @@ */ #include "compute/gradients.h" +#include "utilities_internal.h" + namespace magmadnn { namespace op { -template -magmadnn_error_t get_grad_table(const std::vector *> &vars, Operation *graph, GradTable &table) { +magmadnn_error_t get_grad_table(const std::vector &vars, Operation *graph, GradTable &table) { magmadnn_error_t err; - Tensor *tmp; + Tensor tmp; /* prune compute graph: construct a new graph G' that only contains nodes that are ancestors of @@ -22,13 +23,14 @@ magmadnn_error_t get_grad_table(const std::vector *> &vars, Operati /* TODO */ /* init Loss in grad table to one */ - Tensor *grad_loss = new Tensor({1}, {ONE, {}}, graph->get_memory_type()); + Tensor grad_loss({1}, graph->dtype(), {ONE}, graph->get_memory_type()); + table.set(graph, grad_loss); /* compute the gradients for each variable */ - for (typename std::vector *>::const_iterator vit = vars.begin(); vit != vars.end(); vit++) { + for (std::vector::const_iterator vit = vars.begin(); vit != vars.end(); vit++) { if (*vit != NULL) { - err = internal::build_grad(*vit, graph, table, &tmp); + err = internal::build_grad(*vit, graph, table, tmp); } else { return (magmadnn_error_t) 1; } @@ -40,52 +42,47 @@ magmadnn_error_t get_grad_table(const std::vector *> &vars, Operati return (magmadnn_error_t) 0; } -template magmadnn_error_t get_grad_table(const std::vector *> &vars, Operation *graph, - GradTable &table); -template magmadnn_error_t get_grad_table(const std::vector *> &vars, Operation *graph, - GradTable &table); -template magmadnn_error_t get_grad_table(const std::vector *> &vars, Operation *graph, - GradTable &table); } // namespace op // build_grad should only be used internally namespace internal { -template -magmadnn_error_t build_grad(op::Operation *var, op::Operation *graph, op::GradTable &table, Tensor **grad) { - Tensor *tmp_grad, *bprop, *result; - op::Operation *consumer; - std::vector *> bprops; +magmadnn_error_t build_grad(op::Operation *var, op::Operation *graph, op::GradTable &table, Tensor &grad) { + // Tensor *tmp_grad, *bprop, *result; + op::Operation *consumer; + std::vector consumers; + std::vector> bprops; magmadnn_error_t err; - std::vector *> consumers; /* error on null values */ - if (var == NULL || graph == NULL || grad == NULL) return (magmadnn_error_t) 1; + if (var == nullptr || graph == nullptr) return (magmadnn_error_t) 1; /* get this entry in the grad table */ - tmp_grad = table.get(var); + auto const &res = table.get(var); - /* if not null then we have already calculated this gradient */ - if (tmp_grad != NULL) { - *grad = tmp_grad; + /* we've already calculated this gradient */ + if (res.first) { + grad = res.second; return (magmadnn_error_t) 0; } + Tensor &tmp_grad = res.second; + /* build gradients for each consumer to this operation in order to properly * calculate ours */ consumers = var->get_consumers(); - for (typename std::vector *>::iterator vit = consumers.begin(); vit != consumers.end(); vit++) { + for (std::vector::iterator vit = consumers.begin(); vit != consumers.end(); vit++) { consumer = (*vit); /* if this is null for some reason stop here */ - if (consumer == NULL) continue; + if (consumer == nullptr) continue; /* build the gradient for consumer and keep track of it in bprops */ - err = build_grad(consumer, graph, table, &tmp_grad); + err = build_grad(consumer, graph, table, tmp_grad); if (err != 0) return err; - bprop = consumer->grad(consumer, var, tmp_grad); + Tensor &bprop = consumer->grad(consumer, var, tmp_grad); bprops.push_back(bprop); } @@ -94,38 +91,27 @@ magmadnn_error_t build_grad(op::Operation *var, op::Operation *graph, op:: if (bprops.size() == 0) { return (magmadnn_error_t) 2; } else if (bprops.size() == 1) { - result = bprops.at(0); + grad = bprops.at(0); } else if (bprops.size() == 2) { - /* - result = op::add(bprops.at(0), bprops.at(1), true, false); - */ /* TODO : Add and sum tensors */ - result = NULL; - fprintf(stderr, "Implement add in gradients\n"); + // result = NULL; + + LOG(ERROR) << "Implement add in gradients\n"; } else { - /* currently sum cannot handle scalar values, so just tetrate adds for - * now */ - // result = op::sum(bprops); /* result = bprops.at(0); for (unsigned int i = 1; i < bprops.size(); i++) { result = op::add(result, bprops.at(i)); }*/ - result = NULL; - fprintf(stderr, "Implement sum in gradients\n"); + // result = NULL; + + LOG(ERROR) << "Implement sum in gradients\n"; } - table.set(var, result); - *grad = result; + table.set(var, grad); return (magmadnn_error_t) 0; } -template magmadnn_error_t build_grad(op::Operation *var, op::Operation *graph, op::GradTable &table, - Tensor **grad); -template magmadnn_error_t build_grad(op::Operation *var, op::Operation *graph, - op::GradTable &table, Tensor **grad); -template magmadnn_error_t build_grad(op::Operation *var, op::Operation *graph, - op::GradTable &table, Tensor **grad); } // namespace internal } // namespace magmadnn \ No newline at end of file diff --git a/src/compute/gradtable.cpp b/src/compute/gradtable.cpp index 8d7b567..2ee352f 100644 --- a/src/compute/gradtable.cpp +++ b/src/compute/gradtable.cpp @@ -11,43 +11,22 @@ namespace magmadnn { namespace op { -template -GradTable::GradTable() { +GradTable::GradTable() { // init } -template -unsigned int GradTable::get_size() { - return _table.size(); -} - -template -Tensor *GradTable::get(Operation *var) { - tmp_map_iterator = _table.find(var); - - // return NULL if not found - if (tmp_map_iterator == _table.end()) { - return (Tensor *) NULL; - } +unsigned int GradTable::get_size() { return _table.size(); } - return tmp_map_iterator->second; -} - -template -void GradTable::set(Operation *var, Tensor *grad) { - if (var == NULL) return; +std::pair> GradTable::get(Operation *var) { return *_table.find(var); } - _table[var] = grad; -} +void GradTable::set(Operation *var, Tensor &grad) { + if (var == nullptr) return; -template -void GradTable::clear() { - this->_table.clear(); + /* add this gradient into the table */ + _table.insert(std::make_pair(var, std::ref(grad))); } -template class GradTable; -template class GradTable; -template class GradTable; +void GradTable::clear() { this->_table.clear(); } } // namespace op } // namespace magmadnn \ No newline at end of file diff --git a/src/compute/op_utilities.cpp b/src/compute/op_utilities.cpp index 7cafd2d..3556a58 100644 --- a/src/compute/op_utilities.cpp +++ b/src/compute/op_utilities.cpp @@ -12,14 +12,13 @@ namespace magmadnn { namespace op { namespace utility { -template -magmadnn_error_t print_compute_graph(::magmadnn::op::Operation *_root, bool debug) { - std::set<::magmadnn::op::Operation *> visited; - std::deque<::magmadnn::op::Operation *> to_visit; - ::magmadnn::op::Operation *cur; +magmadnn_error_t print_compute_graph(::magmadnn::op::Operation *_root, bool debug) { + std::set<::magmadnn::op::Operation *> visited; + std::deque<::magmadnn::op::Operation *> to_visit; + ::magmadnn::op::Operation *cur; int (*print)(const char *, ...); - typename std::vector<::magmadnn::op::Operation *>::const_iterator vit; - std::vector::const_iterator vui_it; + typename std::vector<::magmadnn::op::Operation *>::const_iterator vit; + std::vector::const_iterator vui_it; print = (debug) ? ::magmadnn::internal::debugf : std::printf; @@ -33,13 +32,13 @@ magmadnn_error_t print_compute_graph(::magmadnn::op::Operation *_root, bool d print("Operation [%s]:\n", cur->to_string().c_str()); print("\tShape: {"); - const std::vector &out_shape = cur->get_output_shape(); + const std::vector &out_shape = cur->get_output_shape(); for (vui_it = out_shape.begin(); vui_it != out_shape.end(); vui_it++) { print(" %lu%s", (*vui_it), (vui_it == out_shape.end() - 1) ? " }" : ","); } print("\n\tConsumers:"); - std::vector *> const &consumers = cur->get_consumers(); + std::vector const &consumers = cur->get_consumers(); for (vit = consumers.begin(); vit != consumers.end(); vit++) { print(" [%s]", (*vit)->to_string().c_str()); @@ -50,7 +49,7 @@ magmadnn_error_t print_compute_graph(::magmadnn::op::Operation *_root, bool d } print("\n\tInputs:"); - std::vector *> const &inputs = cur->get_inputs(); + std::vector const &inputs = cur->get_inputs(); for (vit = inputs.begin(); vit != inputs.end(); vit++) { print(" [%s]", (*vit)->to_string().c_str()); @@ -64,9 +63,6 @@ magmadnn_error_t print_compute_graph(::magmadnn::op::Operation *_root, bool d return (magmadnn_error_t) 0; } -template magmadnn_error_t print_compute_graph(::magmadnn::op::Operation *_root, bool debug); -template magmadnn_error_t print_compute_graph(::magmadnn::op::Operation *_root, bool debug); -template magmadnn_error_t print_compute_graph(::magmadnn::op::Operation *_root, bool debug); } // namespace utility } // namespace op diff --git a/src/compute/sum/sumop.cpp b/src/compute/sum/sumop.cpp index b944779..3cdb525 100644 --- a/src/compute/sum/sumop.cpp +++ b/src/compute/sum/sumop.cpp @@ -11,13 +11,12 @@ namespace magmadnn { namespace op { -template -SumOp::SumOp(std::vector *> ops, bool copy) : Operation::Operation(ops), ops(ops), copy(copy) { +SumOp::SumOp(std::vector ops) : Operation::Operation(ops), ops(ops) { if (ops.empty()) { return; } - typename std::vector *>::const_iterator it = ops.begin(); + std::vector::const_iterator it = ops.begin(); unsigned int first_size = (*it)->get_output_size(); for (it++; it != ops.end(); it++) { assert((*it)->get_output_size() == first_size); @@ -26,37 +25,27 @@ SumOp::SumOp(std::vector *> ops, bool copy) : Operation::Oper this->output_shape = ops.at(0)->get_output_shape(); this->mem_type = ops.at(0)->get_memory_type(); - if (copy) { - this->output_tensor = new Tensor(ops.at(0)->get_output_shape(), {ZERO, {}}, ops.at(0)->get_memory_type()); - } else { - std::fprintf(stderr, "no_copy sum not supported yet.\n"); - } + this->output_tensor = + Tensor(ops.at(0)->get_output_shape(), ops.at(0).dtype(), {ZERO, {}}, ops.at(0)->get_memory_type()); } -template -Tensor *SumOp::_eval(bool recompute) { - std::vector *> vals(ops.size()); +Tensor &SumOp::_eval(bool recompute) { + std::vector vals(ops.size()); for (unsigned int i = 0; i < ops.size(); i++) { - vals[i] = ops[i]->eval(); + vals[i] = ops[i]->eval(recompute); } - /* TODO sum into first OR last element for non-copy */ - assert(this->output_tensor != NULL); - internal::sum_full(vals, *this->output_tensor); + internal::sum_full(vals, this->output_tensor); return this->output_tensor; } -template -Tensor *SumOp::_grad(Operation *consumer, Operation *var, Tensor *grad) { - return grad; -} +Tensor &SumOp::_grad(Operation *consumer, Operation *var, Tensor &grad) { return grad; } -template -std::string SumOp::to_string() { +std::string SumOp::to_string() override { std::string ret = "("; - for (typename std::vector *>::iterator vit = this->ops.begin(); vit != this->ops.end(); vit++) { + for (typename std::vector::iterator vit = this->ops.begin(); vit != this->ops.end(); vit++) { if (vit != ops.begin()) { ret += "+"; } @@ -65,17 +54,5 @@ std::string SumOp::to_string() { return ret + ")"; } -template class SumOp; -template class SumOp; -template class SumOp; - -template -Operation *sum(std::vector *> ops, bool copy) { - return new SumOp(ops, copy); -} -template Operation *sum(std::vector *> ops, bool copy); -template Operation *sum(std::vector *> ops, bool copy); -template Operation *sum(std::vector *> ops, bool copy); - } // namespace op } // namespace magmadnn \ No newline at end of file diff --git a/src/math/batchnorm.cpp b/src/math/batchnorm.cpp index 1a1c595..225d236 100644 --- a/src/math/batchnorm.cpp +++ b/src/math/batchnorm.cpp @@ -11,7 +11,6 @@ namespace magmadnn { namespace math { -template void batchnorm(const Tensor &x, Tensor &out) { // assert(T_IS_SAME_MEMORY_TYPE(x, out)); @@ -24,11 +23,7 @@ void batchnorm(const Tensor &x, Tensor &out) { } #endif } -#define comp(type) template void batchnorm(const Tensor &, Tensor &); -CALL_FOR_ALL_TYPES(comp) -#undef comp -template void batchnorm_grad(const Tensor &grad, Tensor &out) { // assert(T_IS_SAME_MEMORY_TYPE(grad, out)); @@ -41,51 +36,40 @@ void batchnorm_grad(const Tensor &grad, Tensor &out) { } #endif } -#define comp(type) template void batchnorm_grad(const Tensor &, Tensor &); -CALL_FOR_ALL_TYPES(comp) -#undef comp #if defined(_HAS_CUDA_) -template void batchnorm_device(const Tensor &x, Tensor &out, Tensor &bn_scale, Tensor &bn_bias, Tensor &running_mean, Tensor &running_variance, Tensor &saved_mean, Tensor &saved_variance, unsigned int &num_calls, cudnn_batchnorm_settings_t settings) { - T alpha = static_cast(1), beta = static_cast(0); - double epsilon = 1E-8; - num_calls++; + FOR_ALL_DTYPES(x.dtype(), T, { + T alpha = static_cast(1), beta = static_cast(0); + double epsilon = 1E-8; + num_calls++; - cudnnErrchk(cudnnBatchNormalizationForwardTraining( - settings.handle, settings.mode, &alpha, &beta, x.get_cudnn_tensor_descriptor(), x.get_ptr(), - out.get_cudnn_tensor_descriptor(), out.get_ptr(), settings.bn_tensor_desc, bn_scale.get_ptr(), - bn_bias.get_ptr(), ((double) (1) / (double) (1 + num_calls)), running_mean.get_ptr(), - running_variance.get_ptr(), epsilon, saved_mean.get_ptr(), saved_variance.get_ptr())); + cudnnErrchk(cudnnBatchNormalizationForwardTraining( + settings.handle, settings.mode, &alpha, &beta, x.get_cudnn_tensor_descriptor(), x.get_ptr(), + out.get_cudnn_tensor_descriptor(), out.get_ptr(), settings.bn_tensor_desc, bn_scale.get_ptr(), + bn_bias.get_ptr(), ((double) (1) / (double) (1 + num_calls)), running_mean.get_ptr(), + running_variance.get_ptr(), epsilon, saved_mean.get_ptr(), saved_variance.get_ptr())); + }); } -#define comp(type) \ - template void batchnorm_device(const Tensor &, Tensor &, Tensor &, Tensor &, Tensor &, Tensor &, Tensor &, \ - Tensor &, unsigned int &num_calls, cudnn_batchnorm_settings_t settings); -CALL_FOR_ALL_TYPES(comp) -#undef comp -template void batchnorm_grad_device(const Tensor &x, const Tensor &grad, Tensor &out, Tensor &bn_scale, Tensor &bn_scale_diff, Tensor &bn_bias_diff, Tensor &saved_mean, Tensor &saved_variance, cudnn_batchnorm_settings_t settings) { - T alpha_data = static_cast(1), alpha_params = static_cast(1); - T beta_data = static_cast(0), beta_params = static_cast(0); - double epsilon = 1E-8; + FOR_ALL_DTYPES(x.dtype(), T, { + T alpha_data = static_cast(1), alpha_params = static_cast(1); + T beta_data = static_cast(0), beta_params = static_cast(0); + double epsilon = 1E-8; - cudnnErrchk(cudnnBatchNormalizationBackward( - settings.handle, settings.mode, &alpha_data, &beta_data, &alpha_params, &beta_params, - x.get_cudnn_tensor_descriptor(), x.get_ptr(), grad.get_cudnn_tensor_descriptor(), grad.get_ptr(), - out.get_cudnn_tensor_descriptor(), out.get_ptr(), settings.bn_tensor_desc, bn_scale.get_ptr(), - bn_scale_diff.get_ptr(), bn_bias_diff.get_ptr(), epsilon, saved_mean.get_ptr(), - saved_variance.get_ptr())); + cudnnErrchk(cudnnBatchNormalizationBackward( + settings.handle, settings.mode, &alpha_data, &beta_data, &alpha_params, &beta_params, + x.get_cudnn_tensor_descriptor(), x.get_ptr(), grad.get_cudnn_tensor_descriptor(), grad.get_ptr(), + out.get_cudnn_tensor_descriptor(), out.get_ptr(), settings.bn_tensor_desc, bn_scale.get_ptr(), + bn_scale_diff.get_ptr(), bn_bias_diff.get_ptr(), epsilon, saved_mean.get_ptr(), + saved_variance.get_ptr())); + }); } -#define comp(type) \ - template void batchnorm_grad_device(const Tensor &, const Tensor &, Tensor &, Tensor &, Tensor &, Tensor &, \ - Tensor &, Tensor &, cudnn_batchnorm_settings_t settings); -CALL_FOR_ALL_TYPES(comp) -#undef comp #endif