Skip to content

Commit

Permalink
pnnx ncnn convert select to crop and squeeze (#5826)
Browse files Browse the repository at this point in the history
  • Loading branch information
nihui authored Dec 18, 2024
1 parent 9f67ff1 commit a12baae
Showing 1 changed file with 13 additions and 10 deletions.
23 changes: 13 additions & 10 deletions tools/pnnx/src/pass_ncnn/convert_Tensor_select.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ void convert_Tensor_select(Graph& graph)
if (axis > batch_index)
axis -= 1;

int dim = op->params.at("dim").i;
int index = op->params.at("index").i;

op->params["9"] = std::vector<int> {index};
Expand All @@ -63,24 +64,26 @@ void convert_Tensor_select(Graph& graph)
op->params.erase("dim");
op->params.erase("index");

// reshape for output, squeezing the select dim
// squeezing the select dim
{
Operand* out = op->outputs[0];

Operator* reshape = graph.new_operator_after("Tensor.reshape", op->name + "_ncnnreshape", op);
Operator* squeeze = graph.new_operator_after("torch.squeeze", op->name + "_ncnnsqueeze", op);

Operand* reshape_in = graph.new_operand(op->name + "_ncnnreshape_in");
Operand* squeeze_in = graph.new_operand(op->name + "_ncnnsqueeze_in");

reshape->inputs.push_back(reshape_in);
reshape->outputs.push_back(out);
squeeze->inputs.push_back(squeeze_in);
squeeze->outputs.push_back(out);

op->outputs[0] = reshape_in;
op->outputs[0] = squeeze_in;

out->producer = reshape;
reshape_in->producer = op;
reshape_in->consumers.push_back(reshape);
out->producer = squeeze;
squeeze_in->producer = op;
squeeze_in->consumers.push_back(squeeze);

reshape->params["shape"] = out->shape;
squeeze->params["dim"] = dim;

squeeze_in->params["__batch_index"] = batch_index;
}

break;
Expand Down

0 comments on commit a12baae

Please sign in to comment.