ParamUtils.cpp
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/sparse/ParamUtils.h>
#include <ATen/core/Tensor.h>
#include <ATen/TensorUtils.h>
#include <ATen/WrapDimUtils.h>
#include <tuple>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty_like_native.h>
#endif

namespace at {
namespace native {
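
// Shared preprocessing for the sparse softmax forward kernels: verifies the
// input is sparse, rejects half-to-float conversion, coalesces the input,
// allocates a matching sparse COO output tensor, and checks that dim_ is in
// range. Returns the coalesced input together with the empty output.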
std::pair<Tensor, Tensor> softmax_sparse_input_preprocessing(
    const Tensor& input_,
    const int64_t dim_,
    const bool half_to_float,
    CheckedFrom function_name) {
  TORCH_INTERNAL_ASSERT(input_.is_sparse());
  TORCH_CHECK(
      !half_to_float,
      std::string(function_name) +
          ": with half to float conversion is not supported on " +
          input_.device().str());
  auto input = input_.coalesce();
  Tensor output = at::native::empty_like_sparse_coo(input);
  TORCH_CHECK(
      dim_ >= 0 && dim_ < input.dim(),
      ": dim must be non-negative and less than input dimensions");
  return std::make_pair(input, output);
}
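
// Shared preprocessing for the sparse softmax backward kernels: checks that
// grad and output have the same sizes and sparse dimensions, wraps dim_,
// coalesces both tensors, and allocates an empty sparse COO grad_input.
// Returns (grad_input, coalesced grad, coalesced output).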
std::tuple<Tensor, Tensor, Tensor> softmax_backward_sparse_input_preprocessing(
    const Tensor& grad_,
    const Tensor& output_,
    int64_t dim_,
    const Tensor& input_,
    CheckedFrom function_name) {
  TensorArg grad_arg{grad_, "grad", 1}, output_arg{output_, "output", 2};
  checkSameSize(function_name, grad_arg, output_arg);

  int64_t dim = maybe_wrap_dim(dim_, grad_.dim());

  auto grad = grad_.coalesce();
  auto output = output_.coalesce();

  Tensor grad_input = at::native::empty_like_sparse_coo(output);
  TORCH_CHECK(
      dim >= 0 && dim < grad.dim(),
      ": dim must be non-negative and less than input dimensions");
  TORCH_CHECK(
      grad.sparse_dim() == output.sparse_dim(),
      ": grad and output sparse dimensions must be equal");
  return std::make_tuple(grad_input, grad, output);
}

} // namespace native
} // namespace at
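
// Illustrative usage (a minimal sketch, not part of this file): a sparse
// softmax kernel would typically call the forward helper along these lines.
// The caller name `softmax_sparse_example` and the surrounding control flow
// are assumptions added here for illustration only.
//
//   Tensor softmax_sparse_example(
//       const Tensor& input_,
//       const int64_t dim_,
//       const bool half_to_float) {
//     // Validate the input and obtain a coalesced copy plus an empty
//     // sparse COO output buffer.
//     auto [input, output] = at::native::softmax_sparse_input_preprocessing(
//         input_, dim_, half_to_float, "softmax");
//     if (input.numel() == 0) {
//       return output;
//     }
//     // ... compute the softmax over input.values() along dim_ and write
//     // the result into output ...
//     return output;
//   }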