[Relay] conv3d depthwise bug fix (#16151)
* fix conv3d depthwise weight shape

* apply cpplint

* apply clang-format

* add test case for conv3d depthwise

* apply lint
jonghewk authored Nov 23, 2023
1 parent bce8243 commit 1a2cc18
Showing 2 changed files with 51 additions and 2 deletions.
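
For context, and not part of the commit message: before this change, Conv3DRel always inferred the kernel shape as [channels, in_channels / groups, kd, kh, kw], which does not match the depthwise layout [in_channels, channel_multiplier, kd, kh, kw], so a depthwise conv3d failed type checking. A minimal sketch of the shape arithmetic, using the dimensions from the new test case (the variable names below are illustrative only):

    # Illustrative only: the weight shapes Conv3DRel would infer before and after
    # the fix, using the dimensions from test_compile_depthwise_conv3d below.
    in_channels = 16  # data layout NCDHW: [1, 16, 10, 10, 10]
    channels = 32     # requested output channels
    groups = 16       # groups == in_channels, i.e. depthwise conv3d

    before_fix = [channels, in_channels // groups, 1, 1, 1]      # [32, 1, 1, 1, 1]
    after_fix = [in_channels, channels // in_channels, 1, 1, 1]  # [16, 2, 1, 1, 1]
    # The test's weight really has shape [16, 2, 1, 1, 1], so only the second matches.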
29 changes: 27 additions & 2 deletions src/relay/op/nn/convolution.cc
@@ -460,8 +460,33 @@ bool Conv3DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   if (param->kernel_size.defined() && param->channels.defined()) {
     ICHECK_EQ(param->kernel_size.size(), 3);
     ICHECK_EQ(param->dilation.size(), 3);
-    Array<IndexExpr> wshape({param->channels, indexdiv(dshape_ncdhw[1], param->groups),
-                             param->kernel_size[0], param->kernel_size[1], param->kernel_size[2]});
+
+    bool is_depthwise = false;
+    if (param->groups > 1) {
+      if (!(weight && weight->shape.defined())) {
+        reporter->GetDiagCtx().Emit(
+            Diagnostic::Error(reporter->GetSpan())
+            << "Weight shape must be specified when groups is greater than 1.");
+        return false;
+      }
+
+      Array<IndexExpr> wshape_oidhw = trans_kernel_layout.ForwardShape(weight->shape);
+      if (tvm::tir::ExprDeepEqual()(param->groups, dshape_ncdhw[1]) &&
+          tvm::tir::ExprDeepEqual()(param->groups, wshape_oidhw[0])) {
+        is_depthwise = true;
+      }
+    }
+
+    Array<IndexExpr> wshape;
+    if (is_depthwise) {
+      auto channel_multiplier = indexdiv(param->channels, dshape_ncdhw[1]);
+      wshape = {dshape_ncdhw[1], channel_multiplier, param->kernel_size[0], param->kernel_size[1],
+                param->kernel_size[2]};
+    } else {
+      wshape = {param->channels, indexdiv(dshape_ncdhw[1], param->groups), param->kernel_size[0],
+                param->kernel_size[1], param->kernel_size[2]};
+    }
+
     wshape = trans_kernel_layout.BackwardShape(wshape);
     channels = param->channels;
     dilated_ksize_z = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
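
A hedged sketch, not part of the commit, of what the change means for type inference: with a weight variable declared in the depthwise layout [in_channels, channel_multiplier, kd, kh, kw], the module should now type-check on a TVM build that includes this fix.

    import tvm
    from tvm import relay

    data = relay.var("data", shape=(1, 16, 10, 10, 10), dtype="float32")
    weight = relay.var("weight", shape=(16, 2, 1, 1, 1), dtype="float32")  # depthwise layout
    out = relay.nn.conv3d(
        data, weight, kernel_size=[1, 1, 1], padding=[0] * 3,
        channels=32, groups=16, data_layout="NCDHW", kernel_layout="OIDHW",
    )
    mod = tvm.IRModule.from_expr(relay.Function([data, weight], out))
    mod = relay.transform.InferType()(mod)
    print(mod["main"].ret_type)  # expected: a (1, 32, 10, 10, 10) float32 tensor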
24 changes: 24 additions & 0 deletions tests/python/relay/test_op_level2.py
@@ -823,6 +823,30 @@ def test_conv3d_transpose_ncdhw_run():
     tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-5, atol=1e-5)
 
 
+def test_compile_depthwise_conv3d():
+    dshape = [1, 16, 10, 10, 10]
+    wshape = [16, 2, 1, 1, 1]
+    params = {}
+    data = relay.var("data", shape=dshape, dtype="float32")
+    kernel = relay.const(tvm.nd.array(np.ones(shape=wshape).astype(dtype="float32")))
+    mod = tvm.IRModule()
+    res = relay.nn.conv3d(
+        data,
+        kernel,
+        kernel_size=[1, 1, 1],
+        padding=[0] * 3,
+        channels=32,
+        groups=16,
+        data_layout="NCDHW",
+        kernel_layout="OIDHW",
+    )
+    func = relay.Function([data], res)
+    mod = tvm.IRModule.from_expr(func)
+
+    target = "llvm"
+    _ = relay.build(mod, tvm.target.Target(target, host=target))
+
+
 @tvm.testing.uses_gpu
 def test_conv2d_transpose_infer_type():
     # symbolic in batch dimension
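
As a follow-up usage sketch (assumed, not part of the commit): a module built the same way as in test_compile_depthwise_conv3d can be executed with the graph executor, and the output should have 32 channels.

    import numpy as np
    import tvm
    from tvm import relay
    from tvm.contrib import graph_executor

    data = relay.var("data", shape=[1, 16, 10, 10, 10], dtype="float32")
    kernel = relay.const(np.ones([16, 2, 1, 1, 1], dtype="float32"))
    out = relay.nn.conv3d(
        data, kernel, kernel_size=[1, 1, 1], padding=[0] * 3,
        channels=32, groups=16, data_layout="NCDHW", kernel_layout="OIDHW",
    )
    mod = tvm.IRModule.from_expr(relay.Function([data], out))

    lib = relay.build(mod, tvm.target.Target("llvm", host="llvm"))
    dev = tvm.cpu(0)
    module = graph_executor.GraphModule(lib["default"](dev))
    module.set_input("data", np.random.rand(1, 16, 10, 10, 10).astype("float32"))
    module.run()
    print(module.get_output(0).shape)  # expected: (1, 32, 10, 10, 10)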