Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Wrap unsqueeze #1294

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions k2/python/csrc/torch/v2/any.cu
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,8 @@ void PybindRaggedAny(py::module &m) {
any.def("unique", &RaggedAny::Unique, py::arg("need_num_repeats") = false,
py::arg("need_new2old_indexes") = false, kRaggedAnyUniqueDoc);

any.def("unsqueeze", &RaggedAny::Unsqueeze, py::arg("axis"));

any.def("normalize", &RaggedAny::Normalize, py::arg("use_log"),
kRaggedAnyNormalizeDoc);

Expand Down
8 changes: 8 additions & 0 deletions k2/python/csrc/torch/v2/ragged_any.cu
Original file line number Diff line number Diff line change
Expand Up @@ -575,6 +575,14 @@ RaggedAny RaggedAny::Cat(const std::vector<RaggedAny> &srcs, int32_t axis) {
return {};
}

RaggedAny RaggedAny::Unsqueeze(int32_t axis) {
DeviceGuard guard(any.Context());
Dtype t = any.GetDtype();
FOR_REAL_AND_INT32_TYPES(t, T, {
return RaggedAny(k2::Unsqueeze(any.Specialize<T>(), axis).Generic());
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

k2:: can be removed since it is inside the k2 namespace.

});
}

std::tuple<RaggedAny, torch::optional<RaggedAny>,
torch::optional<torch::Tensor>>
RaggedAny::Unique(bool need_num_repeats /*= false*/,
Expand Down
2 changes: 2 additions & 0 deletions k2/python/csrc/torch/v2/ragged_any.h
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,8 @@ struct RaggedAny {
torch::optional<torch::Tensor>>
Unique(bool need_num_repeats = false, bool need_new2old_indexes = false);

RaggedAny Unsqueeze(int32_t axis);

/// Wrapper for k2::NormalizePerSublist
RaggedAny Normalize(bool use_log) /*const*/;

Expand Down
1 change: 1 addition & 0 deletions k2/python/k2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@

from .ragged import RaggedShape
from .ragged import RaggedTensor
from .ragged import create_ragged_shape2

from . import autograd
from . import autograd_utils
Expand Down
Loading