Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add TopK op #5839

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2,314 changes: 1,220 additions & 1,094 deletions docs/developer-guide/operators.md

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ ncnn_add_layer(Shrink)
ncnn_add_layer(RMSNorm)
ncnn_add_layer(Spectrogram)
ncnn_add_layer(InverseSpectrogram)
ncnn_add_layer(TopK)

if(NCNN_VULKAN)
ncnn_add_shader(${CMAKE_CURRENT_SOURCE_DIR}/convert_ycbcr.comp)
Expand Down
103 changes: 103 additions & 0 deletions src/layer/topk.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "topk.h"

namespace ncnn {

TopK::TopK()
{
one_blob_only = true; // 只需要一个输入 blob
support_inplace = false; // 是否支持原地运算
k = 1;
axis = 0;
largest = 1;
sorted = 1;
}

int TopK::load_param(const ParamDict& pd)
{
k = pd.get(0, 1);
axis = pd.get(1, 0);
largest = pd.get(2, 1);
sorted = pd.get(3, 1);
return 0;
}
int TopK::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
{
int size = (int)bottom_blob.total();
int k_ = k;
if (k_ > size) k_ = size;

const float* ptr = bottom_blob.row(0);

std::vector<std::pair<float, int> > vec;
vec.reserve(size);
for (int i = 0; i < size; i++)
{
vec.push_back(std::make_pair(ptr[i], i));
}

// [](const std::pair<float, int>& a, const std::pair<float, int>& b) {return a.first > b.first;}); // fix Lambda with lower version of C++
struct CompareGreater
{
bool operator()(const std::pair<float, int>& a, const std::pair<float, int>& b) const
{
return a.first > b.first;
}
};

struct CompareLess
{
bool operator()(const std::pair<float, int>& a, const std::pair<float, int>& b) const
{
return a.first < b.first;
}
};

if (largest == 1)
{
std::partial_sort(vec.begin(), vec.begin() + k_, vec.end(), CompareGreater());
}
else
{
std::partial_sort(vec.begin(), vec.begin() + k_, vec.end(), CompareLess());
}

if (sorted)
{
if (largest == 1)
{
std::sort(vec.begin(), vec.begin() + k_, CompareGreater());
}
else
{
std::sort(vec.begin(), vec.begin() + k_, CompareLess());
}
}

top_blob.create(k_, 1, 4u, 1, opt.blob_allocator);
if (top_blob.empty())
return -100;

float* outptr = top_blob;
for (int i = 0; i < k_; i++)
{
outptr[i] = vec[i].first;
}

return 0;
}

} // namespace ncnn
40 changes: 40 additions & 0 deletions src/layer/topk.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#ifndef LAYER_TOPK_H
#define LAYER_TOPK_H

#include "layer.h"

namespace ncnn {

class TopK : public Layer
{
public:
TopK();

virtual int load_param(const ParamDict& pd);

virtual int forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const;

public:
int k;
int axis;
int largest;
int sorted;
};

} // namespace ncnn

#endif // LAYER_TOPK_H
1 change: 1 addition & 0 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ ncnn_add_layer_test(Spectrogram)
ncnn_add_layer_test(Squeeze)
ncnn_add_layer_test(Swish)
ncnn_add_layer_test(TanH)
ncnn_add_layer_test(TopK)
ncnn_add_layer_test(Tile)
ncnn_add_layer_test(UnaryOp)
ncnn_add_layer_test(Unfold)
Expand Down
60 changes: 60 additions & 0 deletions tests/test_topk.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "testutil.h"

static int test_topk(const ncnn::Mat& a, int k, int axis, int largest, int sorted)
{
ncnn::ParamDict pd;
pd.set(0, k); // k
pd.set(1, axis); // axis
pd.set(2, largest); // largest
pd.set(3, sorted); // sorted

std::vector<ncnn::Mat> weights(0);

int ret = test_layer("TopK", pd, weights, a);
if (ret != 0)
{
fprintf(stderr, "test_topk failed a.dims=%d a=(%d %d %d) k=%d axis=%d largest=%d sorted=%d\n", a.dims, a.w, a.h, a.c, k, axis, largest, sorted);
}

return ret;
}

static int test_topk_0()
{
return 0
|| test_topk(RandomMat(8, 8, 3), 5, 0, 1, 1)
|| test_topk(RandomMat(7, 7, 2), 3, 1, 0, 1)
|| test_topk(RandomMat(6, 6, 4), 2, -1, 1, 0)
|| test_topk(RandomMat(5, 5, 5), 4, 2, 0, 0);
}

static int test_topk_1()
{
return 0
|| test_topk(RandomMat(16), 5, 0, 1, 1)
|| test_topk(RandomMat(32), 10, 0, 0, 1)
|| test_topk(RandomMat(64), 20, 0, 1, 0);
}

int main()
{
SRAND(7767517);

return 0
|| test_topk_0()
|| test_topk_1();
}
1 change: 1 addition & 0 deletions tools/pnnx/src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -590,6 +590,7 @@ set(pnnx_pass_ncnn_SRCS
pass_ncnn/torch_sum.cpp
pass_ncnn/torch_stft.cpp
pass_ncnn/torch_t.cpp
pass_ncnn/torch_topk.cpp
pass_ncnn/torch_transpose.cpp
pass_ncnn/torch_unsqueeze.cpp
pass_ncnn/torchaudio_F_inverse_spectrogram.cpp
Expand Down
66 changes: 66 additions & 0 deletions tools/pnnx/src/pass_ncnn/torch_topk.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.
#include "pass_ncnn.h"

namespace pnnx {

namespace ncnn {

class torch_topk : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
3 2
pnnx.Input input 0 1 input
torch.topk op_0 1 2 input out indices dim=%dim k=%k largest=%largest sorted=%sorted
pnnx.Output output 2 0 out indices
)PNNXIR";
}

const char* type_str() const
{
return "TopK";
}

const char* name_str() const
{
return "topk";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
{
int k = captured_params.at("k").i;
int dim = captured_params.at("dim").i;
int largest = captured_params.at("largest").b ? 1 : 0;
int sorted = captured_params.at("sorted").b ? 1 : 0;

// 设置参数
op->params["0"] = k;
op->params["1"] = dim;
op->params["2"] = largest;
op->params["3"] = sorted;

// 移除不需要的输入
op->inputs.resize(1);
op->outputs.resize(1);
}
};

REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_topk, 20)

} // namespace ncnn

} // namespace pnnx
1 change: 1 addition & 0 deletions tools/pnnx/tests/ncnn/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ pnnx_ncnn_add_test(torch_square)
pnnx_ncnn_add_test(torch_tan)
pnnx_ncnn_add_test(torch_tanh)
pnnx_ncnn_add_test(torch_trunc)
pnnx_ncnn_add_test(torch_topk)

pnnx_ncnn_add_test(convnext_tiny)
pnnx_ncnn_add_test(mobilenet_v2)
Expand Down
68 changes: 68 additions & 0 deletions tools/pnnx/tests/ncnn/test_torch_topk.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# Tencent is pleased to support the open source community by making ncnn available.
#
# Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
#
# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software distributed
# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
# CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()

def forward(self, x, y, z):
x, _ = torch.topk(x, 4)
y, _ = torch.topk(y, k=1, dim=2, largest=False)
z, indices = torch.topk(z, k=3, dim=-1, sorted=False)
return x, y, z, indices


def test():
net = Model()
net.eval()

torch.manual_seed(0)
x = torch.rand(1, 3, 16)
y = torch.rand(1, 5, 9, 11)
z = torch.rand(14, 8, 5, 9, 10)

a = net(x, y, z)

# export torchscript
mod = torch.jit.trace(net, (x, y, z))
mod.save("test_torch_topk.pt")

# torchscript to pnnx
import os

os.system(
"../src/pnnx test_torch_topk.pt inputshape=[1,3,16],[1,5,9,11],[14,8,5,9,10]"
)

# pnnx inference
import test_torch_topk_ncnn

b = test_torch_topk_ncnn.test_inference()

for a0, b0 in zip(a, b):
if not torch.equal(a0, b0):
return False
return True


if __name__ == "__main__":
if test():
exit(0)
else:
exit(1)
Loading