From 4969495f31414ff24e92ba69e2d7343559d2fa60 Mon Sep 17 00:00:00 2001
From: Dheeraj Peri
Date: Thu, 8 Jun 2023 14:21:02 -0700
Subject: [PATCH 01/17] feat: Implement dynamic shape support for floordiv,
 NumToTensor, layernorm

Signed-off-by: Dheeraj Peri
---
 .../conversion/converters/impl/layer_norm.cpp |  6 +++---
 core/conversion/evaluators/aten.cpp           | 20 +++++++++++++++++++
 core/conversion/evaluators/prim.cpp           |  4 ++++
 tests/core/conversion/converters/BUILD        | 12 +++++------
 4 files changed, 33 insertions(+), 9 deletions(-)

diff --git a/core/conversion/converters/impl/layer_norm.cpp b/core/conversion/converters/impl/layer_norm.cpp
index 781e061a7f..0c00ee2c4d 100644
--- a/core/conversion/converters/impl/layer_norm.cpp
+++ b/core/conversion/converters/impl/layer_norm.cpp
@@ -20,8 +20,8 @@ auto layer_norm_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns()

           /* Layer_Norm normalizes over last N dimensions.
              normalized_shape could be (C,H,W), (H,W), or (W). */
-          auto normalized_shape = args[1].unwrapToIntList();
-          auto normalized_shape_vec = util::toVec(util::toDims(normalized_shape));
+          // This could be an IntList or ITensorList. We only need the size of this list.
+          auto normalized_shape = args[1].IValue()->toList();

           // Unwrap eps.
           auto eps = args[4].unwrapToDouble();
@@ -30,7 +30,7 @@ auto layer_norm_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns()

           // Set up axis_mask for E[x].
           uint32_t axis_mask = 0;
-          for (size_t i = 0; i < normalized_shape_vec.size(); i++) {
+          for (size_t i = 0; i < normalized_shape.size(); i++) {
             axis_mask |= 1 << (shape.size() - i - 1);
           }
           LOG_DEBUG("Axis Mask for E[x]" << std::bitset<32>(axis_mask));
diff --git a/core/conversion/evaluators/aten.cpp b/core/conversion/evaluators/aten.cpp
index 838175461e..c797113905 100644
--- a/core/conversion/evaluators/aten.cpp
+++ b/core/conversion/evaluators/aten.cpp
@@ -9,6 +9,7 @@
 #include "torch/csrc/jit/ir/ir.h"
 #include "torch/torch.h"

+#include "core/conversion/converters/converter_util.h"
 #include "core/conversion/evaluators/eval_macros.h"
 #include "core/conversion/evaluators/eval_util.h"
 #include "core/conversion/evaluators/evaluators.h"
@@ -677,6 +678,25 @@ auto aten_registrations TORCHTRT_UNUSED =
       .evaluator(
           {c10::Symbol::fromQualString("aten::floordiv"),
            [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+             // Dynamic version of aten::floordiv
+             if (args.at(n->input(0)).isITensor()) {
+               if (args.at(n->input(1)).IValue()->isInt()) {
+                 auto int_tensor = scalar_to_tensor(args.at(n->input(1)).IValue()->toInt());
+                 auto int_itensor = converters::tensor_to_const(ctx, int_tensor, util::node_info(n) + "_constant");
+                 auto elementwise_layer = converters::add_elementwise(
+                     ctx,
+                     nvinfer1::ElementWiseOperation::kFLOOR_DIV,
+                     args.at(n->input(0)).ITensor(),
+                     int_itensor,
+                     util::node_info(n));
+                 auto output_tensor = elementwise_layer->getOutput(0);
+                 auto tensor_holder = TensorContainer();
+                 tensor_holder.hold_tensor(output_tensor);
+                 auto output_ivalue = c10::IValue(std::move(c10::make_intrusive<TensorContainer>(tensor_holder)));
+                 return output_ivalue;
+               }
+             }
+             // Static version
              if (args.at(n->input(0)).IValue()->isInt()) {
                auto a = args.at(n->input(0)).unwrapToInt();
                auto b = args.at(n->input(1)).unwrapToInt();
diff --git a/core/conversion/evaluators/prim.cpp b/core/conversion/evaluators/prim.cpp
index cbbc109982..456c47fa77 100644
--- a/core/conversion/evaluators/prim.cpp
+++ b/core/conversion/evaluators/prim.cpp
@@ -32,6 +32,10 @@ auto prim_registrations =
       .evaluator(
           {torch::jit::prim::NumToTensor,
            [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+             // Dynamic version receives an ITensor here so pass that as output directly.
+             if (args.at(n->input(0)).isITensor()) {
+               return args.at(n->input(0)).ITensor();
+             }
              return evaluators::scalar_to_tensor(args.at(n->input(0)).IValue()->toScalar());
            }})
       .evaluator(
diff --git a/tests/core/conversion/converters/BUILD b/tests/core/conversion/converters/BUILD
index 1973c112fd..477774248d 100644
--- a/tests/core/conversion/converters/BUILD
+++ b/tests/core/conversion/converters/BUILD
@@ -224,33 +224,33 @@ test_suite(
         ":test_div",
         ":test_einsum",
         ":test_expand",
+        ":test_index",
         ":test_instance_norm",
         ":test_interpolate",
-        ":test_index",
         ":test_layer_norm",
         ":test_linear",
         ":test_lstm_cell",
-        ":test_matrix_multiply",
         ":test_masked_fill",
+        ":test_matrix_multiply",
         ":test_max",
         ":test_normalize",
         ":test_pooling",
         ":test_reduce",
-        ":test_roll",
         ":test_replication_pad",
+        ":test_roll",
         ":test_scatter",
         ":test_select",
         ":test_shuffle",
+        ":test_slice",
         ":test_softmax",
+        ":test_split",
         ":test_squeeze",
         ":test_stack",
-        ":test_split",
-        ":test_slice",
         ":test_topk",
         ":test_unary",
-        ":test_unsqueeze",
         ":test_unbind",
         ":test_unpack",
+        ":test_unsqueeze",
         ":test_where",
     ],
 )

From c62474a925a3f7dc4dfea9b39506de96be2676fd Mon Sep 17 00:00:00 2001
From: Dheeraj Peri
Date: Wed, 21 Jun 2023 20:21:16 -0700
Subject: [PATCH 02/17] fix: Return static size for desired dimension if it's
 available

Signed-off-by: Dheeraj Peri
---
 core/conversion/evaluators/aten.cpp      | 16 +++++++++-------
 core/conversion/evaluators/eval_util.cpp |  4 ++++
 2 files changed, 13 insertions(+), 7 deletions(-)

diff --git a/core/conversion/evaluators/aten.cpp b/core/conversion/evaluators/aten.cpp
index c797113905..8cb6cf2dd6 100644
--- a/core/conversion/evaluators/aten.cpp
+++ b/core/conversion/evaluators/aten.cpp
@@ -299,7 +299,14 @@ auto aten_registrations TORCHTRT_UNUSED =
              } else {
                auto dim = args.at(n->input(1)).unwrapToInt();
                if (tensor_var.isITensor()) {
-                 if (ctx->input_is_dynamic) {
+                 auto tensor = tensor_var.ITensor();
+                 auto dims = util::toVec(tensor->getDimensions());
+                 auto nbDims = tensor->getDimensions().nbDims;
+                 if (dim < 0) {
+                   dim += nbDims;
+                 }
+                 // Check if selected dimension size is -1 else return static size
+                 if (ctx->input_is_dynamic && dims[dim] == -1) {
                    if (ctx->settings.allow_shape_tensors) {
                      return dynamic_size_layer(ctx, n, args);
                    } else {
@@ -307,12 +314,7 @@ auto aten_registrations TORCHTRT_UNUSED =
                          "There may be undefined behavior using dynamic shape and aten::size without setting allow_shape_tensors");
                    }
                  }
-                 auto tensor = tensor_var.ITensor();
-                 auto dims = util::toVec(tensor->getDimensions());
-                 auto nbDims = tensor->getDimensions().nbDims;
-                 if (dim < 0) {
-                   dim += nbDims;
-                 }
+
                  return dims[dim];
                } else if (tensor_var.IValue()->isTensor()) {
                  auto tensor = tensor_var.unwrapToTensor();
diff --git a/core/conversion/evaluators/eval_util.cpp b/core/conversion/evaluators/eval_util.cpp
index 0a0b97cfe1..a40cd48523 100644
--- a/core/conversion/evaluators/eval_util.cpp
+++ b/core/conversion/evaluators/eval_util.cpp
@@ -45,6 +45,10 @@ c10::IValue dynamic_size_layer(ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) {
   // Handle negative axis by referring to nbDims of input Tensor
   dim = dim < 0 ? dim + maxDim : dim;
   LOG_DEBUG("Dimension to select: " << dim);
+  // Check if selected dimension size is -1 else return static size
+  if (input_dims.d[dim] != -1) {
+    return input_dims.d[dim];
+  }
   shape_1d_tensor = index_layer(ctx, n, shape_1d_tensor, dim);

   LOG_DEBUG("Output tensor shape: " << shape_1d_tensor->getDimensions());
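
Taken together, patches 01 and 02 let scripted shape arithmetic such as `x.size(1) // 2` evaluate against dynamic dimensions: `aten::size` now returns the static size when it is known and a shape tensor otherwise, while `aten::floordiv` and `prim::NumToTensor` handle the shape-tensor case. A minimal sketch of the user-facing pattern this targets, assuming the TorchScript frontend; the module, input ranges, and the Python-side spelling of the `allow_shape_tensors` option are illustrative, not taken from the patches:

import torch
import torch_tensorrt

class ReshapeHalf(torch.nn.Module):
    def forward(self, x):
        # With a dynamic dim 1, aten::size evaluates to a shape tensor and
        # aten::floordiv is lowered through the new elementwise path.
        return x.reshape(x.size(0), x.size(1) // 2, 2)

model = torch.jit.script(ReshapeHalf().eval().cuda())
trt_model = torch_tensorrt.compile(
    model,
    ir="ts",
    inputs=[torch_tensorrt.Input(min_shape=(1, 4), opt_shape=(4, 8), max_shape=(8, 16))],
    # Assumed Python exposure of ctx->settings.allow_shape_tensors above.
    allow_shape_tensors=True,
)
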
From c37eeecac7e1cc5e46017636cf7f3de57c061eda Mon Sep 17 00:00:00 2001
From: Dheeraj Peri
Date: Wed, 21 Jun 2023 20:24:30 -0700
Subject: [PATCH 03/17] chore: deleting the previous fix

Signed-off-by: Dheeraj Peri
---
 core/conversion/evaluators/eval_util.cpp | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/core/conversion/evaluators/eval_util.cpp b/core/conversion/evaluators/eval_util.cpp
index a40cd48523..0a0b97cfe1 100644
--- a/core/conversion/evaluators/eval_util.cpp
+++ b/core/conversion/evaluators/eval_util.cpp
@@ -45,10 +45,6 @@ c10::IValue dynamic_size_layer(ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) {
   // Handle negative axis by referring to nbDims of input Tensor
   dim = dim < 0 ? dim + maxDim : dim;
   LOG_DEBUG("Dimension to select: " << dim);
-  // Check if selected dimension size is -1 else return static size
-  if (input_dims.d[dim] != -1) {
-    return input_dims.d[dim];
-  }
   shape_1d_tensor = index_layer(ctx, n, shape_1d_tensor, dim);

   LOG_DEBUG("Output tensor shape: " << shape_1d_tensor->getDimensions());

From 4b2e2f987ce429a3e95064202ec268b0afacb38c Mon Sep 17 00:00:00 2001
From: Anurag Dixit
Date: Mon, 10 Jul 2023 13:36:53 -0700
Subject: [PATCH 04/17] chore: rebase with main branch

Signed-off-by: Anurag Dixit
---
 core/conversion/converters/impl/shuffle.cpp |  98 +++++++++++++
 .../conversion/converters/test_shuffle.cpp  |  52 +++++++
 tests/cpp/test_dynamic_size.cpp             | 135 +++++++++++++++++-
 3 files changed, 284 insertions(+), 1 deletion(-)

diff --git a/core/conversion/converters/impl/shuffle.cpp b/core/conversion/converters/impl/shuffle.cpp
index f758c0cc47..bc92964a69 100644
--- a/core/conversion/converters/impl/shuffle.cpp
+++ b/core/conversion/converters/impl/shuffle.cpp
@@ -64,6 +64,104 @@ static auto shuffle_registrations TORCHTRT_UNUSED =
           LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
           return true;
         }})
+    .pattern(
+        {"aten::unflatten.int(Tensor self, int dim, int[] sizes) -> (Tensor)",
+         [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+           auto in = args[0].ITensorOrFreeze(ctx);
+           auto dim = args[1].unwrapToInt();
+           auto in_shape = util::toVec(in->getDimensions());
+           std::vector<int64_t> new_shape;
+           nvinfer1::ITensor* shape_tensor;
+           if (ctx->input_is_dynamic) {
+             /*
+              * In case the dim is negative
+              * If the dim in negative range is larger than in_shape,
+              * then it should run into index out of bound error as expected
+              */
+             if (dim < 0) {
+               dim = in_shape.size() + dim;
+             }
+             std::cout << "Dynamic shape case" << std::endl;
+             LOG_DEBUG("Using dynamic version of reshape layer");
+             if (args[2].isITensorList()) {
+               std::cout << "isTensorList case" << std::endl;
+               LOG_DEBUG("Shape tensor is an ITensorList");
+               auto expand_shape = args[2].unwrapToITensorList();
+               auto shape_layer = ctx->net->addShape(*in);
+               TORCHTRT_CHECK(shape_layer, "Unable to create shape layer from node: " << *n);
+               auto shape_1d_tensor = shape_layer->getOutput(0);
+
+               std::vector<int64_t> before_dim_indices_vector(dim);
+               std::iota(before_dim_indices_vector.begin(), before_dim_indices_vector.end(), 0);
+
+               nvinfer1::ITensor* before_dim_gather_out = nullptr;
+               if(before_dim_indices_vector.size()){
+                 at::Tensor before_dim_indices = torch::tensor(before_dim_indices_vector).to(torch::kI32);
+                 auto before_dim_indices_out = converters::tensor_to_const(ctx, before_dim_indices);
+                 auto before_dim_gather_layer = ctx->net->addGather(*shape_1d_tensor, *before_dim_indices_out, 0);
+                 TORCHTRT_CHECK(before_dim_gather_layer, "Unable to create gather layer from node: " << *n);
+                 before_dim_gather_out = before_dim_gather_layer->getOutput(0);
+               }
+
+               std::vector<int64_t> after_dim_indices_vector(in_shape.size() - (dim + 1));
+               std::iota(after_dim_indices_vector.begin(), after_dim_indices_vector.end(), dim + 1);
+
+               nvinfer1::ITensor* after_dim_gather_out = nullptr;
+               if(after_dim_indices_vector.size()){
+                 at::Tensor after_dim_indices = torch::tensor(after_dim_indices_vector).to(torch::kI32);
+                 auto after_dim_indices_out = converters::tensor_to_const(ctx, after_dim_indices);
+                 auto after_dim_gather_layer = ctx->net->addGather(*shape_1d_tensor, *after_dim_indices_out, 0);
+                 TORCHTRT_CHECK(after_dim_gather_layer, "Unable to create gather layer from node: " << *n);
+                 after_dim_gather_out = after_dim_gather_layer->getOutput(0);
+               }
+
+               std::vector<nvinfer1::ITensor*> shape_tensors;
+               if(before_dim_gather_out){
+                 shape_tensors.push_back(before_dim_gather_out);
+               }
+               for(auto new_shape_tensor : expand_shape){
+                 shape_tensors.push_back(new_shape_tensor);
+               }
+               if(after_dim_gather_out){
+                 shape_tensors.push_back(after_dim_gather_out);
+               }
+
+               auto shape_cat_layer = ctx->net->addConcatenation(shape_tensors.data(), shape_tensors.size());
+               TORCHTRT_CHECK(shape_cat_layer, "Unable to create cat layer from node: " << *n);
+               shape_tensor = shape_cat_layer->getOutput(0);
+               LOG_DEBUG("Shape tensor shape: " << shape_tensor->getDimensions());
+             } else if (args[2].isIntList()) {
+               auto shape_vec = args[2].unwrapToIntList().vec();
+               // New shape
+               new_shape.insert(new_shape.end(), in_shape.begin(), in_shape.begin() + dim);
+               new_shape.insert(new_shape.end(), shape_vec.begin(), shape_vec.end());
+               new_shape.insert(new_shape.end(), in_shape.begin() + dim + 1, in_shape.end());
+
+               shape_tensor = tensor_to_const(ctx, torch::tensor(new_shape).to(torch::kI32));
+             } else {
+               LOG_ERROR(
+                   "Invalid IValue type of " << args[2].ivalue_type()
+                                             << " detected for shape tensor from node: " << *n);
+             }
+           }
+           else {
+             new_shape = torch::unflatten(torch::rand(in_shape), dim, args[2].unwrapToIntList().vec()).sizes().vec();
+           }
+           auto shuffle = ctx->net->addShuffle(*in);
+           shuffle->setName(util::node_info(n).c_str());
+           TORCHTRT_CHECK(shuffle, "Unable to create shuffle layer from node: " << *n);
+
+           if (ctx->input_is_dynamic) {
+             shuffle->setInput(1, *shape_tensor);
+           } else {
+             shuffle->setReshapeDimensions(util::toDims(new_shape));
+           }
+
+           auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], shuffle->getOutput(0));
+           LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
+
+           return true;
+         }})
     .pattern(
         {"aten::reshape(Tensor self, int[] shape) -> (Tensor)",
          [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
diff --git a/tests/core/conversion/converters/test_shuffle.cpp b/tests/core/conversion/converters/test_shuffle.cpp
index fad50c9340..9c972ba988 100644
--- a/tests/core/conversion/converters/test_shuffle.cpp
+++ b/tests/core/conversion/converters/test_shuffle.cpp
@@ -364,3 +364,55 @@ TEST(Converters, ATenPixelShuffle5DConvertsCorrectly) {

   ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
 }
+
+TEST(Converters, ATenUnflattenConvertsCorrectly) {
+  const auto graph = R"IR(
+    graph(%x.1 : Tensor):
+            %2 : int = prim::Constant[value=1]()
+            %3 : int = prim::Constant[value=512]()
+            %4 : int = prim::Constant[value=1]()
+            %5 : int = prim::Constant[value=1]()
+            %6 : int[] = prim::ListConstruct(%3, %4, %5)
+            %7 : Tensor = aten::unflatten(%x.1, %2, %6)
+            return (%7))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto in = at::randint(0, 5, {1, 512}, {at::kCUDA});
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in});
+
+  in = at::clone(in);
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
+}
+
+TEST(Converters, ATenUnflattenNegativeDimConvertsCorrectly) {
+  const auto graph = R"IR(
+    graph(%x.1 : Tensor):
+            %2 : int = prim::Constant[value=-1]()
+            %3 : int = prim::Constant[value=512]()
+            %4 : int = prim::Constant[value=1]()
+            %5 : int = prim::Constant[value=1]()
+            %6 : int[] = prim::ListConstruct(%3, %4, %5)
+            %7 : Tensor = aten::unflatten(%x.1, %2, %6)
+            return (%7))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto in = at::randint(0, 5, {1, 512}, {at::kCUDA});
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in});
+
+  in = at::clone(in);
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
+}
\ No newline at end of file
diff --git a/tests/cpp/test_dynamic_size.cpp b/tests/cpp/test_dynamic_size.cpp
index 9e46842d9c..e8765d6570 100644
--- a/tests/cpp/test_dynamic_size.cpp
+++ b/tests/cpp/test_dynamic_size.cpp
@@ -124,4 +124,137 @@ TEST(Converters, ATenResizeGetItemDynShapeMulCorrectly) {
   auto trt = trt_results[0].reshape(jit_results[0].sizes());

   ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
-}
\ No newline at end of file
+}
+
+TEST(Converters, ATenUnflattenDynShapeShapeCorrectly) {
+  const auto graph = R"IR(
+    graph(%x.1 : Tensor):
+            %2 : int = prim::Constant[value=1]()
+            %3 : int = prim::Constant[value=512]()
+            %4 : int = prim::Constant[value=1]()
+            %5 : int = prim::Constant[value=1]()
+            %6 : int[] = prim::ListConstruct(%3, %4, %5)
+            %7 : Tensor = aten::unflatten(%x.1, %2, %6)
+            return (%7))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+
+  torch::jit::parseIR(graph, g.get());
+
+  auto in = at::randint(0, 10, {1, 512}, {at::kCUDA});
+
+  auto jit_in = at::clone(in);
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
+
+  auto trt_in = at::clone(in);
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {in}, true);
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
+}
+
+TEST(Converters, ATenUnflattenDynShapeNegativeDimsShapeCorrectly) {
+  const auto graph = R"IR(
+    graph(%x.1 : Tensor):
+            %2 : int = prim::Constant[value=-2]()
+            %3 : int = prim::Constant[value=512]()
+            %4 : int = prim::Constant[value=1]()
+            %5 : int = prim::Constant[value=1]()
+            %6 : int[] = prim::ListConstruct(%3, %4, %5)
+            %7 : Tensor = aten::unflatten(%x.1, %2, %6)
+            return (%7))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+
+  torch::jit::parseIR(graph, g.get());
+
+  auto in = at::randint(0, 10, {1, 512, 2}, {at::kCUDA});
+
+  auto jit_in = at::clone(in);
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
+
+  auto trt_in = at::clone(in);
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {in}, true);
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
+}
+
+TEST(Converters, ATenUnflattenDynShapeITensorShapeCorrectly) {
+  const auto graph = R"IR(
+      graph(%x.1 : Tensor):
+            %2 : int = prim::Constant[value=1]()
+            %3 : int = aten::size(%x.1, %2)
+            %4 : int = prim::Constant[value=256]()
+            %5 : int = prim::Constant[value=2]()
+            %6 : int[] = prim::ListConstruct(%4, %5)
+            %7 : Tensor = aten::unflatten(%x.1, %2, %6)
+            return (%7))IR";
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto in = at::randint(0, 10, {1, 512, 1}, {at::kCUDA});
+
+  auto jit_in = at::clone(in);
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
+
+  auto trt_in = at::clone(in);
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {in}, true);
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
+}
+
+TEST(Converters, ATenUnflattenDynShapeITensorShapeCorrectlyFirstDim) {
+  const auto graph = R"IR(
+      graph(%x.1 : Tensor):
+            %1 : int = prim::Constant[value=0]()
+            %2 : int = prim::Constant[value=1]()
+            %3 : int = aten::size(%x.1, %1)
+            %6 : int[] = prim::ListConstruct(%2, %2, %3, %2, %2)
+            %7 : Tensor = aten::unflatten(%x.1, %1, %6)
+            return (%7))IR";
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto in = at::randint(0, 10, {64, 512, 1}, {at::kCUDA});
+
+  auto jit_in = at::clone(in);
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
+
+  auto trt_in = at::clone(in);
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {in}, true);
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
+}
+
+TEST(Converters, ATenUnflattenDynShapeITensorShapeCorrectlyLastDim) {
+  const auto graph = R"IR(
+      graph(%x.1 : Tensor):
+            %1 : int = prim::Constant[value=2]()
+            %2 : int = prim::Constant[value=1]()
+            %3 : int = aten::size(%x.1, %1)
+            %5 : int = prim::Constant[value=2]()
+            %6 : int[] = prim::ListConstruct(%3, %2, %2)
+            %7 : Tensor = aten::unflatten(%x.1, %5, %6)
+            return (%7))IR";
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto in = at::randint(0, 10, {1, 512, 9}, {at::kCUDA});
+
+  auto jit_in = at::clone(in);
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
+
+  auto trt_in = at::clone(in);
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {in}, true);
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
+}
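
Patch 04's `aten::unflatten.int` converter handles both static `int[]` sizes and sizes computed from dynamic dimensions (the ITensorList path, assembled with the shape/gather/concatenation layers above). A minimal sketch of the pattern the dynamic tests exercise, assuming the TorchScript frontend; the module and shape ranges are illustrative:

import torch
import torch_tensorrt

class Unflatten(torch.nn.Module):
    def forward(self, x):
        # With a dynamic dim 1, x.size(1) // 2 reaches the converter as an
        # ITensor, so the target shape is built as an ITensorList.
        return x.unflatten(1, [x.size(1) // 2, 2])

model = torch.jit.script(Unflatten().eval().cuda())
trt_model = torch_tensorrt.compile(
    model,
    ir="ts",
    inputs=[torch_tensorrt.Input(min_shape=(1, 4, 1), opt_shape=(1, 512, 1), max_shape=(1, 1024, 1))],
)
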
From dad46205b271e6bd2d244c89e1a81f043ca129e0 Mon Sep 17 00:00:00 2001
From: Anurag Dixit
Date: Mon, 10 Jul 2023 13:58:28 -0700
Subject: [PATCH 05/17] chore: trigger lint

Signed-off-by: Anurag Dixit
---
 tests/cpp/test_dynamic_size.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/cpp/test_dynamic_size.cpp b/tests/cpp/test_dynamic_size.cpp
index e8765d6570..c79a4ca6c9 100644
--- a/tests/cpp/test_dynamic_size.cpp
+++ b/tests/cpp/test_dynamic_size.cpp
@@ -257,4 +257,4 @@ TEST(Converters, ATenUnflattenDynShapeITensorShapeCorrectlyLastDim) {
   auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {in}, true);

   ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
-}
+}

From b7e8d1ce328317a45df4054de352b2d0182a2514 Mon Sep 17 00:00:00 2001
From: Anurag Dixit
Date: Mon, 10 Jul 2023 17:03:15 -0700
Subject: [PATCH 06/17] chore: apply lint

Signed-off-by: Anurag Dixit
---
 core/conversion/converters/impl/shuffle.cpp | 192 ++++++++++----------
 tests/cpp/test_dynamic_size.cpp             |   2 +-
 2 files changed, 97 insertions(+), 97 deletions(-)

diff --git a/core/conversion/converters/impl/shuffle.cpp b/core/conversion/converters/impl/shuffle.cpp
index bc92964a69..352729c67e 100644
--- a/core/conversion/converters/impl/shuffle.cpp
+++ b/core/conversion/converters/impl/shuffle.cpp
@@ -66,102 +66,102 @@ static auto shuffle_registrations TORCHTRT_UNUSED =
         }})
     .pattern(
         {"aten::unflatten.int(Tensor self, int dim, int[] sizes) -> (Tensor)",
-         [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-           auto in = args[0].ITensorOrFreeze(ctx);
-           auto dim = args[1].unwrapToInt();
-           auto in_shape = util::toVec(in->getDimensions());
-           std::vector<int64_t> new_shape;
-           nvinfer1::ITensor* shape_tensor;
-           if (ctx->input_is_dynamic) {
-             /*
-              * In case the dim is negative
-              * If the dim in negative range is larger than in_shape,
-              * then it should run into index out of bound error as expected
-              */
-             if (dim < 0) {
-               dim = in_shape.size() + dim;
-             }
-             std::cout << "Dynamic shape case" << std::endl;
-             LOG_DEBUG("Using dynamic version of reshape layer");
-             if (args[2].isITensorList()) {
-               std::cout << "isTensorList case" << std::endl;
-               LOG_DEBUG("Shape tensor is an ITensorList");
-               auto expand_shape = args[2].unwrapToITensorList();
-               auto shape_layer = ctx->net->addShape(*in);
-               TORCHTRT_CHECK(shape_layer, "Unable to create shape layer from node: " << *n);
-               auto shape_1d_tensor = shape_layer->getOutput(0);
-
-               std::vector<int64_t> before_dim_indices_vector(dim);
-               std::iota(before_dim_indices_vector.begin(), before_dim_indices_vector.end(), 0);
-
-               nvinfer1::ITensor* before_dim_gather_out = nullptr;
-               if(before_dim_indices_vector.size()){
-                 at::Tensor before_dim_indices = torch::tensor(before_dim_indices_vector).to(torch::kI32);
-                 auto before_dim_indices_out = converters::tensor_to_const(ctx, before_dim_indices);
-                 auto before_dim_gather_layer = ctx->net->addGather(*shape_1d_tensor, *before_dim_indices_out, 0);
-                 TORCHTRT_CHECK(before_dim_gather_layer, "Unable to create gather layer from node: " << *n);
-                 before_dim_gather_out = before_dim_gather_layer->getOutput(0);
-               }
-
-               std::vector<int64_t> after_dim_indices_vector(in_shape.size() - (dim + 1));
-               std::iota(after_dim_indices_vector.begin(), after_dim_indices_vector.end(), dim + 1);
-
-               nvinfer1::ITensor* after_dim_gather_out = nullptr;
-               if(after_dim_indices_vector.size()){
-                 at::Tensor after_dim_indices = torch::tensor(after_dim_indices_vector).to(torch::kI32);
-                 auto after_dim_indices_out = converters::tensor_to_const(ctx, after_dim_indices);
-                 auto after_dim_gather_layer = ctx->net->addGather(*shape_1d_tensor, *after_dim_indices_out, 0);
-                 TORCHTRT_CHECK(after_dim_gather_layer, "Unable to create gather layer from node: " << *n);
-                 after_dim_gather_out = after_dim_gather_layer->getOutput(0);
-               }
-
-               std::vector<nvinfer1::ITensor*> shape_tensors;
-               if(before_dim_gather_out){
-                 shape_tensors.push_back(before_dim_gather_out);
-               }
-               for(auto new_shape_tensor : expand_shape){
-                 shape_tensors.push_back(new_shape_tensor);
-               }
-               if(after_dim_gather_out){
-                 shape_tensors.push_back(after_dim_gather_out);
-               }
-
-               auto shape_cat_layer = ctx->net->addConcatenation(shape_tensors.data(), shape_tensors.size());
-               TORCHTRT_CHECK(shape_cat_layer, "Unable to create cat layer from node: " << *n);
-               shape_tensor = shape_cat_layer->getOutput(0);
-               LOG_DEBUG("Shape tensor shape: " << shape_tensor->getDimensions());
-             } else if (args[2].isIntList()) {
-               auto shape_vec = args[2].unwrapToIntList().vec();
-               // New shape
-               new_shape.insert(new_shape.end(), in_shape.begin(), in_shape.begin() + dim);
-               new_shape.insert(new_shape.end(), shape_vec.begin(), shape_vec.end());
-               new_shape.insert(new_shape.end(), in_shape.begin() + dim + 1, in_shape.end());
-
-               shape_tensor = tensor_to_const(ctx, torch::tensor(new_shape).to(torch::kI32));
-             } else {
-               LOG_ERROR(
-                   "Invalid IValue type of " << args[2].ivalue_type()
-                                             << " detected for shape tensor from node: " << *n);
-             }
-           }
-           else {
-             new_shape = torch::unflatten(torch::rand(in_shape), dim, args[2].unwrapToIntList().vec()).sizes().vec();
-           }
-           auto shuffle = ctx->net->addShuffle(*in);
-           shuffle->setName(util::node_info(n).c_str());
-           TORCHTRT_CHECK(shuffle, "Unable to create shuffle layer from node: " << *n);
-
-           if (ctx->input_is_dynamic) {
-             shuffle->setInput(1, *shape_tensor);
-           } else {
-             shuffle->setReshapeDimensions(util::toDims(new_shape));
-           }
-
-           auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], shuffle->getOutput(0));
-           LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
-
-           return true;
-         }})
+         [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+           auto in = args[0].ITensorOrFreeze(ctx);
+           auto dim = args[1].unwrapToInt();
+           auto in_shape = util::toVec(in->getDimensions());
+           std::vector<int64_t> new_shape;
+           nvinfer1::ITensor* shape_tensor;
+           if (ctx->input_is_dynamic) {
+             /*
+              * In case the dim is negative
+              * If the dim in negative range is larger than in_shape,
+              * then it should run into index out of bound error as expected
+              */
+             if (dim < 0) {
+               dim = in_shape.size() + dim;
+             }
+             std::cout << "Dynamic shape case" << std::endl;
+             LOG_DEBUG("Using dynamic version of reshape layer");
+             if (args[2].isITensorList()) {
+               std::cout << "isTensorList case" << std::endl;
+               LOG_DEBUG("Shape tensor is an ITensorList");
+               auto expand_shape = args[2].unwrapToITensorList();
+               auto shape_layer = ctx->net->addShape(*in);
+               TORCHTRT_CHECK(shape_layer, "Unable to create shape layer from node: " << *n);
+               auto shape_1d_tensor = shape_layer->getOutput(0);
+
+               std::vector<int64_t> before_dim_indices_vector(dim);
+               std::iota(before_dim_indices_vector.begin(), before_dim_indices_vector.end(), 0);
+
+               nvinfer1::ITensor* before_dim_gather_out = nullptr;
+               if (before_dim_indices_vector.size()) {
+                 at::Tensor before_dim_indices = torch::tensor(before_dim_indices_vector).to(torch::kI32);
+                 auto before_dim_indices_out = converters::tensor_to_const(ctx, before_dim_indices);
+                 auto before_dim_gather_layer = ctx->net->addGather(*shape_1d_tensor, *before_dim_indices_out, 0);
+                 TORCHTRT_CHECK(before_dim_gather_layer, "Unable to create gather layer from node: " << *n);
+                 before_dim_gather_out = before_dim_gather_layer->getOutput(0);
+               }
+
+               std::vector<int64_t> after_dim_indices_vector(in_shape.size() - (dim + 1));
+               std::iota(after_dim_indices_vector.begin(), after_dim_indices_vector.end(), dim + 1);
+
+               nvinfer1::ITensor* after_dim_gather_out = nullptr;
+               if (after_dim_indices_vector.size()) {
+                 at::Tensor after_dim_indices = torch::tensor(after_dim_indices_vector).to(torch::kI32);
+                 auto after_dim_indices_out = converters::tensor_to_const(ctx, after_dim_indices);
+                 auto after_dim_gather_layer = ctx->net->addGather(*shape_1d_tensor, *after_dim_indices_out, 0);
+                 TORCHTRT_CHECK(after_dim_gather_layer, "Unable to create gather layer from node: " << *n);
+                 after_dim_gather_out = after_dim_gather_layer->getOutput(0);
+               }
+
+               std::vector<nvinfer1::ITensor*> shape_tensors;
+               if (before_dim_gather_out) {
+                 shape_tensors.push_back(before_dim_gather_out);
+               }
+               for (auto new_shape_tensor : expand_shape) {
+                 shape_tensors.push_back(new_shape_tensor);
+               }
+               if (after_dim_gather_out) {
+                 shape_tensors.push_back(after_dim_gather_out);
+               }
+
+               auto shape_cat_layer = ctx->net->addConcatenation(shape_tensors.data(), shape_tensors.size());
+               TORCHTRT_CHECK(shape_cat_layer, "Unable to create cat layer from node: " << *n);
+               shape_tensor = shape_cat_layer->getOutput(0);
+               LOG_DEBUG("Shape tensor shape: " << shape_tensor->getDimensions());
+             } else if (args[2].isIntList()) {
+               auto shape_vec = args[2].unwrapToIntList().vec();
+               // New shape
+               new_shape.insert(new_shape.end(), in_shape.begin(), in_shape.begin() + dim);
+               new_shape.insert(new_shape.end(), shape_vec.begin(), shape_vec.end());
+               new_shape.insert(new_shape.end(), in_shape.begin() + dim + 1, in_shape.end());
+
+               shape_tensor = tensor_to_const(ctx, torch::tensor(new_shape).to(torch::kI32));
+             } else {
+               LOG_ERROR(
+                   "Invalid IValue type of " << args[2].ivalue_type()
+                                             << " detected for shape tensor from node: " << *n);
+             }
+           } else {
+             new_shape =
+                 torch::unflatten(torch::rand(in_shape), dim, args[2].unwrapToIntList().vec()).sizes().vec();
+           }
+           auto shuffle = ctx->net->addShuffle(*in);
+           shuffle->setName(util::node_info(n).c_str());
+           TORCHTRT_CHECK(shuffle, "Unable to create shuffle layer from node: " << *n);
+
+           if (ctx->input_is_dynamic) {
+             shuffle->setInput(1, *shape_tensor);
+           } else {
+             shuffle->setReshapeDimensions(util::toDims(new_shape));
+           }
+
+           auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], shuffle->getOutput(0));
+           LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
+
+           return true;
+         }})
     .pattern(
         {"aten::reshape(Tensor self, int[] shape) -> (Tensor)",
          [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
diff --git a/tests/cpp/test_dynamic_size.cpp b/tests/cpp/test_dynamic_size.cpp
index c79a4ca6c9..e8765d6570 100644
--- a/tests/cpp/test_dynamic_size.cpp
+++ b/tests/cpp/test_dynamic_size.cpp
@@ -257,4 +257,4 @@ TEST(Converters, ATenUnflattenDynShapeITensorShapeCorrectlyLastDim) {
   auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {in}, true);

   ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
-}
+}

From 948dc5886b8925cf348c9875d1e73650f3732e81 Mon Sep 17 00:00:00 2001
From: Anurag Dixit
Date: Mon, 10 Jul 2023 17:27:11 -0700
Subject: [PATCH 07/17] chore: Adopting API change from main branch

Signed-off-by: Anurag Dixit
---
 core/conversion/converters/impl/shuffle.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/core/conversion/converters/impl/shuffle.cpp b/core/conversion/converters/impl/shuffle.cpp
index 352729c67e..5a8e992d90 100644
--- a/core/conversion/converters/impl/shuffle.cpp
+++ b/core/conversion/converters/impl/shuffle.cpp
@@ -140,7 +140,7 @@ static auto shuffle_registrations TORCHTRT_UNUSED =
               shape_tensor = tensor_to_const(ctx, torch::tensor(new_shape).to(torch::kI32));
             } else {
               LOG_ERROR(
-                  "Invalid IValue type of " << args[2].ivalue_type()
+                  "Invalid IValue type of " << args[2].IValue()->type()
                                             << " detected for shape tensor from node: " << *n);
             }
           } else {

From a47b5fe0ffe0c783b235bfc212da8a3f1387b307 Mon Sep 17 00:00:00 2001
From: Anurag Dixit
Date: Tue, 11 Jul 2023 15:29:50 -0700
Subject: [PATCH 08/17] chore: Removing redundant test cases

Signed-off-by: Anurag Dixit
---
 tests/cpp/test_dynamic_size.cpp | 56 ---------------------------------
 1 file changed, 56 deletions(-)

diff --git a/tests/cpp/test_dynamic_size.cpp b/tests/cpp/test_dynamic_size.cpp
index e8765d6570..c1edff849d 100644
--- a/tests/cpp/test_dynamic_size.cpp
+++ b/tests/cpp/test_dynamic_size.cpp
@@ -126,62 +126,6 @@ TEST(Converters, ATenResizeGetItemDynShapeMulCorrectly) {
   ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
 }

-TEST(Converters, ATenUnflattenDynShapeShapeCorrectly) {
-  const auto graph = R"IR(
-    graph(%x.1 : Tensor):
-            %2 : int = prim::Constant[value=1]()
-            %3 : int = prim::Constant[value=512]()
-            %4 : int = prim::Constant[value=1]()
-            %5 : int = prim::Constant[value=1]()
-            %6 : int[] = prim::ListConstruct(%3, %4, %5)
-            %7 : Tensor = aten::unflatten(%x.1, %2, %6)
-            return (%7))IR";
-
-  auto g = std::make_shared<torch::jit::Graph>();
-
-  torch::jit::parseIR(graph, g.get());
-
-  auto in = at::randint(0, 10, {1, 512}, {at::kCUDA});
-
-  auto jit_in = at::clone(in);
-  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
-  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
-
-  auto trt_in = at::clone(in);
-  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
-  auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {in}, true);
-
-  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
-}
-
-TEST(Converters, ATenUnflattenDynShapeNegativeDimsShapeCorrectly) {
-  const auto graph = R"IR(
-    graph(%x.1 : Tensor):
-            %2 : int = prim::Constant[value=-2]()
-            %3 : int = prim::Constant[value=512]()
-            %4 : int = prim::Constant[value=1]()
-            %5 : int = prim::Constant[value=1]()
-            %6 : int[] = prim::ListConstruct(%3, %4, %5)
-            %7 : Tensor = aten::unflatten(%x.1, %2, %6)
-            return (%7))IR";
-
-  auto g = std::make_shared<torch::jit::Graph>();
-
-  torch::jit::parseIR(graph, g.get());
-
-  auto in = at::randint(0, 10, {1, 512, 2}, {at::kCUDA});
-
-  auto jit_in = at::clone(in);
-  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
-  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
-
-  auto trt_in = at::clone(in);
-  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
-  auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {in}, true);
-
-  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
-}
-
 TEST(Converters, ATenUnflattenDynShapeITensorShapeCorrectly) {
   const auto graph = R"IR(
       graph(%x.1 : Tensor):

From 6e51901902d481219ce3732b08a0c8f0353c90a2 Mon Sep 17 00:00:00 2001
From: Anurag Dixit
Date: Wed, 12 Jul 2023 17:05:16 -0700
Subject: [PATCH 09/17] feat: Added a variant for aten::fake_quant_per_tensor

Signed-off-by: Anurag Dixit
---
 .../converters/impl/quantization.cpp          | 36 +++++++++++++------
 .../converters/test_quantization.cpp          | 34 ++++++++++++++++++
 2 files changed, 59 insertions(+), 11 deletions(-)

diff --git a/core/conversion/converters/impl/quantization.cpp b/core/conversion/converters/impl/quantization.cpp
index e8fdc69f84..addf629e6b 100644
--- a/core/conversion/converters/impl/quantization.cpp
+++ b/core/conversion/converters/impl/quantization.cpp
@@ -11,6 +11,22 @@ namespace {

 #if NV_TENSORRT_MAJOR > 7
 // clang-format off
+
+bool add_qdq(ConversionCtx *ctx, const torch::jit::Node* n, nvinfer1::ITensor* input, nvinfer1::ITensor* scale, std::string& opName) {
+  nvinfer1::IQuantizeLayer* quantize_layer = ctx->net->addQuantize(*input, *scale);
+  TORCHTRT_CHECK(quantize_layer, "Unable to create QuantizeLayer from node: " << *n);
+  quantize_layer->setAxis(0);
+
+  nvinfer1::IDequantizeLayer* dequantize_layer = ctx->net->addDequantize(*quantize_layer->getOutput(0), *scale);
+  TORCHTRT_CHECK(dequantize_layer, "Unable to create DequantizeLayer from node: " << *n);
+  dequantize_layer->setAxis(0);
+
+  auto qdq_out = ctx->AssociateValueAndTensor(n->outputs()[0], dequantize_layer->getOutput(0));
+  LOG_DEBUG("[" << opName << "]" << " Output tensor shape: " << qdq_out->getDimensions());
+
+  return true;
+}
+
 auto quantization_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns()
     .pattern({"aten::fake_quantize_per_tensor_affine(Tensor self, float scale, int zero_point, int quant_min, int quant_max) -> (Tensor)",
               [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
@@ -20,18 +36,16 @@ auto quantization_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns
                 auto scale = args[1].unwrapToScalar().to<float>();
                 auto scaleTensor = tensor_to_const(ctx, torch::tensor({scale}));
                 // Add and configure a QuantizeLayer.
-                nvinfer1::IQuantizeLayer* quantize_layer = ctx->net->addQuantize(*input, *scaleTensor);
-                quantize_layer->setAxis(0);
-
-                // Add and configure DequantizeLayer following a QuantizeLayer
-                nvinfer1::IDequantizeLayer* dequantize_layer = ctx->net->addDequantize(*quantize_layer->getOutput(0), *scaleTensor);
-                dequantize_layer->setAxis(0);
-
-                auto qdq_out = ctx->AssociateValueAndTensor(n->outputs()[0], dequantize_layer->getOutput(0));
-                LOG_DEBUG("[fake_quantize_per_tensor_affine] Output tensor shape: " << qdq_out->getDimensions());
-
-                return true;
+                std::string opName("aten::fake_quantize_per_tensor_affine");
+                return add_qdq(ctx, n, input, scaleTensor, opName);
               }})
+    .pattern({"aten::fake_quantize_per_tensor_affine.tensor_qparams(Tensor self, Tensor scale, Tensor zero_point, int quant_min, int quant_max) -> (Tensor)",
+              [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+                auto input = args[0].ITensorOrFreeze(ctx);
+                auto scale = args[1].ITensorOrFreeze(ctx);
+                std::string opName("aten::fake_quantize_per_tensor_affine.tensor_qparams");
+                return add_qdq(ctx, n, input, scale, opName);
+              }})
     .pattern({"aten::fake_quantize_per_channel_affine(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max) -> (Tensor)",
               [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
                 // This aten operator is generated from torch.fake_quantize_per_channel_affine op in Pytorch python API.
diff --git a/tests/core/conversion/converters/test_quantization.cpp b/tests/core/conversion/converters/test_quantization.cpp
index fcbef02e16..d6881bb37e 100644
--- a/tests/core/conversion/converters/test_quantization.cpp
+++ b/tests/core/conversion/converters/test_quantization.cpp
@@ -30,6 +30,40 @@ TEST(Converters, ATenFakeQuantizePerTensorConvertsCorrectly) {
       torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
 }

+TEST(Converters, ATenFakeQuantizePerTensorWithParamsConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%x.1 : Tensor):
+            %22 : int = prim::Constant[value=-128]()
+            %14 : int = prim::Constant[value=4]()
+            %9 : None = prim::Constant()
+            %35 : Device = prim::Constant[value="cuda:0"]()
+            %6 : int = prim::Constant[value=6]()
+            %7 : int = prim::Constant[value=3]()
+            %3 : int = prim::Constant[value=1]()
+            %5 : float = prim::Constant[value=3.5]()
+            %13 : int = prim::Constant[value=1]()
+            %23 : int = prim::Constant[value=127]()
+            %4 : int[] = prim::ListConstruct(%3)
+            %11 : Tensor = aten::full(%4, %5, %6, %9, %35, %9)
+            %12 : int[] = prim::ListConstruct(%3)
+            %19 : Tensor = aten::full(%12, %13, %7, %9, %35, %9)
+            %quant_input.1 : Tensor = aten::fake_quantize_per_tensor_affine(%x.1, %11, %19, %22, %23)
+            return (%quant_input.1))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto in = at::randint(1, 10, {1, 5, 5, 5}, {at::kCUDA}).to(at::kFloat);
+
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in});
+
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in}, nvinfer1::DataType::kINT8);
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
+}
+
 TEST(Converters, ATenFakeQuantizePerChannelConvertsCorrectly) {
   const auto graph = R"IR(
     graph(%x.1 : Tensor):
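
The new `.tensor_qparams` variant matches graphs where the scale and zero point are materialized as tensors rather than Python scalars, which is what `torch.fake_quantize_per_tensor_affine` produces when handed tensor arguments. A minimal eager-mode sketch of code that scripts to this schema; the values mirror the test above, and the int32 zero point matches the dtype the test graph constructs:

import torch

x = torch.randn(1, 5, 5, 5, device="cuda")
scale = torch.full((1,), 3.5, device="cuda")
zero_point = torch.full((1,), 1, dtype=torch.int32, device="cuda")

# Tensor qparams dispatch to aten::fake_quantize_per_tensor_affine.tensor_qparams
y = torch.fake_quantize_per_tensor_affine(x, scale, zero_point, -128, 127)
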
From 165981c1310a307d64dca9fcc579e2153b9fa408 Mon Sep 17 00:00:00 2001
From: Michael Feliz
Date: Fri, 14 Jul 2023 11:52:55 -0700
Subject: [PATCH 10/17] Add support for dynamic select and masked_fill

---
 core/conversion/converters/impl/select.cpp    | 10 ++----
 .../converters/test_masked_fill.cpp           | 34 +++++++++++++++++++
 .../conversion/converters/test_select.cpp     | 26 ++++++++++++++
 3 files changed, 63 insertions(+), 7 deletions(-)

diff --git a/core/conversion/converters/impl/select.cpp b/core/conversion/converters/impl/select.cpp
index 7942688db8..8334205879 100644
--- a/core/conversion/converters/impl/select.cpp
+++ b/core/conversion/converters/impl/select.cpp
@@ -165,10 +165,7 @@ auto select_registrations TORCHTRT_UNUSED =
             }

             shuffle_layer->setReshapeDimensions(util::squeezeDims(
-                out->getDimensions(),
-                dim,
-                ctx->input_is_dynamic,
-                ctx->input_is_dynamic && (num_zero_dimensions > 0)));
+                out->getDimensions(), dim, false, ctx->input_is_dynamic && (num_zero_dimensions > 0)));
             shuffle_layer->setName(util::node_info(n).c_str());
             out = shuffle_layer->getOutput(0);
           }
@@ -710,9 +707,8 @@ auto select_registrations TORCHTRT_UNUSED =
               auto val_t_dtype = util::TRTDataTypeToScalarType(self->getType());

               // Initialize constant tensor for fill with the inherited data type
-              auto val_t = tensor_to_const(
-                  ctx, torch::full(util::toVec(self->getDimensions()), val, {torch::dtype(val_t_dtype)}));
-
+              std::vector<int64_t> singleton_dims(self->getDimensions().nbDims, 1);
+              auto val_t = tensor_to_const(ctx, torch::full(singleton_dims, val, {torch::dtype(val_t_dtype)}));
               TORCHTRT_CHECK(
                   util::broadcastable(self->getDimensions(), mask->getDimensions(), /*multidirectional=*/false),
                   "Self and mask tensors are not broadcastable");
diff --git a/tests/core/conversion/converters/test_masked_fill.cpp b/tests/core/conversion/converters/test_masked_fill.cpp
index 2c375463e5..0931112a5e 100644
--- a/tests/core/conversion/converters/test_masked_fill.cpp
+++ b/tests/core/conversion/converters/test_masked_fill.cpp
@@ -43,6 +43,40 @@ TEST(Converters, ATenMaskedFillZerosConvertsCorrectly) {
       torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
 }

+TEST(Converters, ATenMaskedFillZerosDynamicConvertsCorrectly) {
+  const auto graph = R"IR(
+    graph(%x.1 : Tensor):
+      %44 : Device = prim::Constant[value="cuda"]()
+      %8 : bool = prim::Constant[value=0]()
+      %7 : None = prim::Constant()
+      %f32_dtype: int = prim::Constant[value=11]()
+      %1 : int = prim::Constant[value=0]() # bert.py:5:26
+      %2 : int = prim::Constant[value=1]() # bert.py:5:32
+      %33 : int = prim::Constant[value=2]() # bert.py:6:31
+      %3 : int[] = prim::ListConstruct(%1, %1, %2)
+      %9 : Tensor = aten::tensor(%3, %f32_dtype, %7, %8) # bert.py:5:11
+      %mask.1 : Tensor = aten::to(%9, %44, %7, %8, %8) # bert.py:5:11
+      %mask.2 : Tensor = trt::const(%mask.1)
+      %34 : Tensor = aten::masked_fill(%x.1, %mask.1, %33) # bert.py:6:11
+      return (%34, %mask.2))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+
+  torch::jit::parseIR(graph, &*g);
+
+  auto in = at::zeros({1, 2, 3}, {at::kCUDA});
+
+  auto jit_in = at::clone(in);
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
+
+  auto trt_in = at::clone(in);
+  torch_tensorrt::core::lowering::passes::RemoveNOPs(g);
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {trt_in});
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0]));
+}
+
 TEST(Converters, ATenMaskedFillMixedTypesFloatIntConvertsCorrectly) {
   const auto graph = R"IR(
     graph(%x.1 : Tensor, %x.2 : Tensor):
diff --git a/tests/core/conversion/converters/test_select.cpp b/tests/core/conversion/converters/test_select.cpp
index d2af33f099..4de4dc10a1 100644
--- a/tests/core/conversion/converters/test_select.cpp
+++ b/tests/core/conversion/converters/test_select.cpp
@@ -32,6 +32,32 @@ TEST(Converters, ATenSelectIntConvertsCorrectly) {
   ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
 }

+TEST(Converters, ATenSelectIntDynamicConvertsCorrectly) {
+  const auto graph = R"IR(
+    graph(%0 : Tensor):
+      %2 : int = prim::Constant[value=0]()
+      %3 : Tensor = aten::select(%0, %2, %2)
+      return (%3))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+
+  torch::jit::parseIR(graph, g.get());
+
+  auto in = at::randint(1, 10, {5, 7, 9}, {at::kCUDA});
+
+  auto jit_in = at::clone(in);
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
+
+  auto trt_in = at::clone(in);
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {trt_in});
+
+  auto trt = trt_results[0].reshape(jit_results[0].sizes());
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
+}
+
 TEST(Converters, ATenSelectIntDimIsOneConvertsCorrectly) {
   const auto graph = R"IR(
       graph(%0 : Tensor):
From 03a75b148ea5ad7f6ca398916fb1dd85af7a971f Mon Sep 17 00:00:00 2001
From: Naren Dasan
Date: Mon, 31 Jul 2023 19:34:02 -0700
Subject: [PATCH 11/17] chore: fix the docgen job

Signed-off-by: Naren Dasan
Signed-off-by: Naren Dasan
---
 .github/workflows/docgen.yml                  |  60 ++--------
 packaging/pre_build_script.sh                 |  13 +++
 .../WORKSPACE.x86_64.cu121.release.rhel       | 103 ++++++++++++++++++
 3 files changed, 125 insertions(+), 51 deletions(-)
 create mode 100644 packaging/pre_build_script.sh
 create mode 100644 toolchains/ci_workspaces/WORKSPACE.x86_64.cu121.release.rhel

diff --git a/.github/workflows/docgen.yml b/.github/workflows/docgen.yml
index 36b7ec9bab..bdb5f3d775 100644
--- a/.github/workflows/docgen.yml
+++ b/.github/workflows/docgen.yml
@@ -12,75 +12,33 @@ jobs:
   build-docs:
     runs-on: ubuntu-20.04
     container:
-      image: ghcr.io/pytorch/tensorrt/docgen:latest
+      image: docker.io/pytorch/manylinux-builder:cuda12.1
       credentials:
         username: ${{ github.actor }}
         password: ${{ secrets.GITHUB_TOKEN }}
     steps:
-      - name: Reclaim space
-        run: |
-          rm -rf /usr/share/dotnet
-          rm -rf /opt/ghc
-          rm -rf "/usr/local/share/boost"
-          rm -rf /usr/local/cuda/cuda-*
-      - name: Install base deps
-        run: |
-          apt update
-          DEBIAN_FRONTEND=noninteractive apt install -y software-properties-common gcc git curl wget make zlib1g-dev bzip2 libbz2-dev lzma lzma-dev libreadline-dev libsqlite3-dev libssl-dev libffi-dev doxygen pandoc
-          git config --global --add safe.directory '*'
-      - name: Set up Python 3.10.12
-        uses: actions/setup-python@v4
-        with:
-          python-version: 3.10.12
       - uses: actions/checkout@v3
         with:
          ref: ${{github.head_ref}}
+      - name: Install base deps
+        run: |
+          ./packaging/pre_build_script.sh
       - name: Get HEAD SHA
         id: vars
         run: echo "sha=$(git rev-parse --short HEAD)" >> $GITHUB_OUTPUT
       - name: Get Bazel version
         id: bazel_info
         run: echo "version=$(cat .bazelversion)" >> $GITHUB_OUTPUT
-      - name: Install Bazel
-        run: |
-          wget -q https://github.com/bazelbuild/bazel/releases/download/${{ steps.bazel_info.outputs.version }}/bazel-${{ steps.bazel_info.outputs.version }}-linux-x86_64 -O /usr/bin/bazel
-          chmod a+x /usr/bin/bazel
-      - name: Install cudnn + tensorrt
-        run: |
-          apt-get update
-          wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/cuda-ubuntu2004.pin
-          mv cuda-ubuntu2004.pin /etc/apt/preferences.d/cuda-repository-pin-600
-          apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/7fa2af80.pub
-          apt-key adv --keyserver keyserver.ubuntu.com --recv-keys 536F8F1DE80F6A35
-          apt-key adv --keyserver keyserver.ubuntu.com --recv-keys A4B469963BF863CC
-          add-apt-repository "deb https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/ /"
-          apt-get update
-          apt-get install -y libcudnn8 libcudnn8-dev
-
-          apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/3bf863cc.pub
-          add-apt-repository "deb https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/ /"
-          apt-get update
-
-          apt-get install -y libnvinfer8 libnvinfer-plugin8 libnvinfer-dev libnvinfer-plugin-dev
-      - name: Install Torch
-        run: |
-          python3 -m pip install -r py/requirements.txt --user
       - name: Build Python Package
         run: |
-          cp toolchains/ci_workspaces/WORKSPACE.x86_64 WORKSPACE
-          cd py
-          python3 -m pip install pip==21.3.1
-          echo $(which python3)
-          echo $(python3 -c 'import site; print(site.getsitepackages()[0])')
-          mkdir -p /opt/circleci/.pyenv/versions/3.9.4/lib/python3.9/
-          ln -s $(python3 -c 'import site;
print(site.getsitepackages()[0])') /opt/circleci/.pyenv/versions/3.9.4/lib/python3.9/site-packages - python3 setup.py install - cd .. + cp toolchains/ci_workspaces/WORKSPACE.x86_64.cu121.release.rhel WORKSPACE + python -m pip install pip<=23 + python -m pip install --pre -e . --extra-index-url https://download.pytorch.org/whl/nightly/cu121 - name: Generate New Docs run: | cd docsrc - python3 -m pip install -r requirements.txt - python3 -c "import torch_tensorrt; print(torch_tensorrt.__version__)" + python -m pip install -r requirements.txt + python -c "import torch_tensorrt; print(torch_tensorrt.__version__)" make html cd .. - uses: stefanzweifel/git-auto-commit-action@v4 diff --git a/packaging/pre_build_script.sh b/packaging/pre_build_script.sh new file mode 100644 index 0000000000..49a3a187d5 --- /dev/null +++ b/packaging/pre_build_script.sh @@ -0,0 +1,13 @@ +#!/bin/bash + +# Install dependencies +TRT_VERSION=$(python3 -c "import versions; versions.tensorrt_version()") +yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel7/x86_64/cuda-rhel7.repo +yum check-update +yum install -y ninja-build tensorrt-${TRT_VERSION}.* +wget https://github.com/bazelbuild/bazelisk/releases/download/v1.17.0/bazelisk-linux-amd64 \ + && mv bazelisk-linux-amd64 /usr/bin/bazel \ + && chmod +x /usr/bin/bazel + +cp toolchains/ci_workspaces/WORKSPACE.x86_64.${VERSION_SUFFIX#*+}.release.rhel WORKSPACE +export CI_BUILD=1 diff --git a/toolchains/ci_workspaces/WORKSPACE.x86_64.cu121.release.rhel b/toolchains/ci_workspaces/WORKSPACE.x86_64.cu121.release.rhel new file mode 100644 index 0000000000..2fc09e8219 --- /dev/null +++ b/toolchains/ci_workspaces/WORKSPACE.x86_64.cu121.release.rhel @@ -0,0 +1,103 @@ +workspace(name = "Torch-TensorRT") + +load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") + +http_archive( + name = "rules_python", + sha256 = "863ba0fa944319f7e3d695711427d9ad80ba92c6edd0b7c7443b84e904689539", + strip_prefix = "rules_python-0.22.0", + url = "https://github.com/bazelbuild/rules_python/releases/download/0.22.0/rules_python-0.22.0.tar.gz", +) + +load("@rules_python//python:repositories.bzl", "py_repositories") + +py_repositories() + +http_archive( + name = "rules_pkg", + sha256 = "8f9ee2dc10c1ae514ee599a8b42ed99fa262b757058f65ad3c384289ff70c4b8", + urls = [ + "https://mirror.bazel.build/github.com/bazelbuild/rules_pkg/releases/download/0.9.1/rules_pkg-0.9.1.tar.gz", + "https://github.com/bazelbuild/rules_pkg/releases/download/0.9.1/rules_pkg-0.9.1.tar.gz", + ], +) + +load("@rules_pkg//:deps.bzl", "rules_pkg_dependencies") + +rules_pkg_dependencies() + +http_archive( + name = "googletest", + sha256 = "755f9a39bc7205f5a0c428e920ddad092c33c8a1b46997def3f1d4a82aded6e1", + strip_prefix = "googletest-5ab508a01f9eb089207ee87fd547d290da39d015", + urls = ["https://github.com/google/googletest/archive/5ab508a01f9eb089207ee87fd547d290da39d015.zip"], +) + +# External dependency for torch_tensorrt if you already have precompiled binaries. 
+local_repository( + name = "torch_tensorrt", + path = "/opt/circleci/.pyenv/versions/3.10.9/lib/python3.10/site-packages/torch_tensorrt" +) + +# CUDA should be installed on the system locally +new_local_repository( + name = "cuda", + build_file = "@//third_party/cuda:BUILD", + path = "/usr/local/cuda-12.1", +) + +new_local_repository( + name = "cublas", + build_file = "@//third_party/cublas:BUILD", + path = "/usr", +) +############################################################################################################# +# Tarballs and fetched dependencies (default - use in cases when building from precompiled bin and tarballs) +############################################################################################################# + +http_archive( + name = "libtorch", + build_file = "@//third_party/libtorch:BUILD", + sha256 = "174579a7ee2a506d063714160c5fc57da428f7935311ef511c8f19820eb14c86", + strip_prefix = "libtorch", + urls = ["https://download.pytorch.org/libtorch/nightly/cu121/libtorch-cxx11-abi-shared-with-deps-2.1.0.dev20230731%2Bcu121.zip"], +) + +http_archive( + name = "libtorch_pre_cxx11_abi", + build_file = "@//third_party/libtorch:BUILD", + sha256 = "532217063c65354d5534211badadc9c370d889cb1c3fdb295c9b3d0f181bc0ba", + strip_prefix = "libtorch", + urls = ["https://download.pytorch.org/libtorch/nightly/cu121/libtorch-shared-with-deps-2.1.0.dev20230731%2Bcu121.zip"], +) + +#################################################################################### +# Locally installed dependencies (use in cases of custom dependencies or aarch64) +#################################################################################### + +new_local_repository( + name = "cudnn", + path = "/usr/", + build_file = "@//third_party/cudnn/local:BUILD" +) + +new_local_repository( + name = "tensorrt", + path = "/usr/", + build_file = "@//third_party/tensorrt/local:BUILD" +) + +# ######################################################################### +# # Testing Dependencies (optional - comment out on aarch64) +# ######################################################################### + +load("@rules_python//python:pip.bzl", "pip_parse") + +pip_parse( + name = "devtools_deps", + requirements = "//:requirements-dev.txt", +) + +load("@devtools_deps//:requirements.bzl", "install_deps") + +install_deps() From 15a6e576c82d5cbb489bf6ca525e111673f7a3e8 Mon Sep 17 00:00:00 2001 From: gs-olive <113141689+gs-olive@users.noreply.github.com> Date: Fri, 5 May 2023 10:04:05 -0700 Subject: [PATCH 12/17] feat: Add example usage scripts for dynamo path - Add sample scripts covering resnet18, transformers, and custom examples showcasing the `torch_tensorrt.dynamo.compile` path, which can compile models with data-dependent control flow and other such restrictions which can make other compilation methods more difficult - Cover different customizeable features allowed in the new backend - Make scripts Sphinx-Gallery compatible Python files fix: Update `index.rst` - Show individual links in sidebar chore: Add note about Cuda Driver Error - Update arguments to Dynamo compile call in line with new schema updates fix: Update function calls to address API changes fix: Update file and reference naming for new API --- .gitignore | 3 +- docsrc/conf.py | 7 ++ docsrc/index.rst | 31 +++-- docsrc/requirements.txt | 1 + examples/dynamo/README.rst | 10 ++ .../dynamo/torch_compile_advanced_usage.py | 103 +++++++++++++++++ .../dynamo/torch_compile_resnet_example.py | 93 +++++++++++++++ 
.../torch_compile_transformers_example.py | 108 ++++++++++++++++++ 8 files changed, 346 insertions(+), 10 deletions(-) create mode 100644 examples/dynamo/README.rst create mode 100644 examples/dynamo/torch_compile_advanced_usage.py create mode 100644 examples/dynamo/torch_compile_resnet_example.py create mode 100644 examples/dynamo/torch_compile_transformers_example.py diff --git a/.gitignore b/.gitignore index 3a7a3b462d..918b69b27c 100644 --- a/.gitignore +++ b/.gitignore @@ -32,6 +32,7 @@ docsrc/_build docsrc/_notebooks docsrc/_cpp_api docsrc/_tmp +docsrc/tutorials/_rendered_examples *.so __pycache__ *.egg-info @@ -67,4 +68,4 @@ bazel-tensorrt *cifar-10-batches-py* bazel-project build/ -wheelhouse/ \ No newline at end of file +wheelhouse/ diff --git a/docsrc/conf.py b/docsrc/conf.py index 4794450399..cbe068d7ea 100644 --- a/docsrc/conf.py +++ b/docsrc/conf.py @@ -47,6 +47,7 @@ "sphinx.ext.coverage", "sphinx.ext.mathjax", "sphinx.ext.viewcode", + "sphinx_gallery.gen_gallery", ] napoleon_use_ivar = True @@ -79,6 +80,12 @@ # so a file named "default.css" will overwrite the builtin "default.css". html_static_path = ["_static"] +# sphinx-gallery configuration +sphinx_gallery_conf = { + "examples_dirs": "../examples/dynamo", + "gallery_dirs": "tutorials/_rendered_examples/", +} + # Setup the breathe extension breathe_projects = {"Torch-TensorRT": "./_tmp/xml"} breathe_default_project = "Torch-TensorRT" diff --git a/docsrc/index.rst b/docsrc/index.rst index e5da81d2a5..b851ca3481 100644 --- a/docsrc/index.rst +++ b/docsrc/index.rst @@ -36,30 +36,43 @@ Getting Started getting_started/getting_started_with_windows -Tutorials +User Guide ------------ * :ref:`creating_a_ts_mod` * :ref:`getting_started_with_fx` * :ref:`ptq` * :ref:`runtime` -* :ref:`serving_torch_tensorrt_with_triton` * :ref:`use_from_pytorch` * :ref:`using_dla` + +.. toctree:: + :caption: User Guide + :maxdepth: 1 + :hidden: + + user_guide/creating_torchscript_module_in_python + user_guide/getting_started_with_fx_path + user_guide/ptq + user_guide/runtime + user_guide/use_from_pytorch + user_guide/using_dla + +Tutorials +------------ +* :ref:`serving_torch_tensorrt_with_triton` * :ref:`notebooks` +* :ref:`dynamo_compile` .. toctree:: :caption: Tutorials - :maxdepth: 1 + :maxdepth: 3 :hidden: - tutorials/creating_torchscript_module_in_python - tutorials/getting_started_with_fx_path - tutorials/ptq - tutorials/runtime tutorials/serving_torch_tensorrt_with_triton - tutorials/use_from_pytorch - tutorials/using_dla tutorials/notebooks + tutorials/_rendered_examples/dynamo/torch_compile_resnet_example + tutorials/_rendered_examples/dynamo/torch_compile_transformers_example + tutorials/_rendered_examples/dynamo/torch_compile_advanced_usage Python API Documenation ------------------------ diff --git a/docsrc/requirements.txt b/docsrc/requirements.txt index ccbe311f0f..ac75bf5632 100644 --- a/docsrc/requirements.txt +++ b/docsrc/requirements.txt @@ -1,4 +1,5 @@ sphinx==4.5.0 +sphinx-gallery==0.13.0 breathe==4.33.1 exhale==0.3.1 -e git+https://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme diff --git a/examples/dynamo/README.rst b/examples/dynamo/README.rst new file mode 100644 index 0000000000..69e7f404e2 --- /dev/null +++ b/examples/dynamo/README.rst @@ -0,0 +1,10 @@ +.. 
_torch_compile: + +Dynamo Compile Examples +================ + +This document contains examples of usage of the `torch_tensorrt.dynamo.compile` API which integrates with `torch.compile` functionality + +* :ref:`torch_compile_resnet`: Compiling a ResNet model using the Torch Compile Frontend for ``torch_tensorrt.compile`` +* :ref:`torch_compile_transformer`: Compiling a Transformer model using ``torch.compile`` +* :ref:`torch_compile_advanced_usage`: Advanced usage including making a custom backend to use directly with the ``torch.compile`` API diff --git a/examples/dynamo/torch_compile_advanced_usage.py b/examples/dynamo/torch_compile_advanced_usage.py new file mode 100644 index 0000000000..1d301a16a8 --- /dev/null +++ b/examples/dynamo/torch_compile_advanced_usage.py @@ -0,0 +1,103 @@ +""" +.. _torch_compile_advanced_usage: + +Torch Compile Advanced Usage +====================================================== + +This interactive script is intended as an overview of the process by which `torch_tensorrt.compile(..., ir="torch_compile", ...)` works, and how it integrates with the `torch.compile` API.""" + +# %% +# Imports and Model Definition +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +import torch +import torch_tensorrt + +# %% + +# We begin by defining a model +class Model(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.relu = torch.nn.ReLU() + + def forward(self, x: torch.Tensor, y: torch.Tensor): + x_out = self.relu(x) + y_out = self.relu(y) + x_y_out = x_out + y_out + return torch.mean(x_y_out) + + +# %% +# Compilation with `torch.compile` Using Default Settings +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +# Define sample float inputs and initialize model +sample_inputs = [torch.rand((5, 7)).cuda(), torch.rand((5, 7)).cuda()] +model = Model().eval().cuda() + +# %% + +# Next, we compile the model using torch.compile +# For the default settings, we can simply call torch.compile +# with the backend "torch_tensorrt", and run the model on an +# input to cause compilation, as so: +optimized_model = torch.compile(model, backend="torch_tensorrt") +optimized_model(*sample_inputs) + +# %% +# Compilation with `torch.compile` Using Custom Settings +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +# First, we use Torch utilities to clean up the workspace +# after the previous compile invocation +torch._dynamo.reset() + +# Define sample half inputs and initialize model +sample_inputs_half = [ + torch.rand((5, 7)).half().cuda(), + torch.rand((5, 7)).half().cuda(), +] +model_half = Model().eval().cuda() + +# %% + +# If we want to customize certain options in the backend, +# but still use the torch.compile call directly, we can provide +# custom options to the backend via the "options" keyword +# which takes in a dictionary mapping options to values. 
+# +# For accepted backend options, see the CompilationSettings dataclass: +# py/torch_tensorrt/dynamo/_settings.py +backend_kwargs = { + "enabled_precisions": {torch.half}, + "debug": True, + "min_block_size": 2, + "torch_executed_ops": {"torch.ops.aten.sub.Tensor"}, + "optimization_level": 4, + "use_python_runtime": False, +} + +# Run the model on an input to cause compilation, as so: +optimized_model_custom = torch.compile( + model_half, backend="torch_tensorrt", options=backend_kwargs +) +optimized_model_custom(*sample_inputs_half) + +# %% +# Cleanup +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +# Finally, we use Torch utilities to clean up the workspace +torch._dynamo.reset() + +# %% +# Cuda Driver Error Note +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# +# Occasionally, upon exiting the Python runtime after Dynamo compilation with `torch_tensorrt`, +# one may encounter a Cuda Driver Error. This issue is related to https://github.com/NVIDIA/TensorRT/issues/2052 +# and can be resolved by wrapping the compilation/inference in a function and using a scoped call, as in:: +# +# if __name__ == '__main__': +# compile_engine_and_infer() diff --git a/examples/dynamo/torch_compile_resnet_example.py b/examples/dynamo/torch_compile_resnet_example.py new file mode 100644 index 0000000000..9015538fec --- /dev/null +++ b/examples/dynamo/torch_compile_resnet_example.py @@ -0,0 +1,93 @@ +""" +.. _torch_compile_resnet: + +Compiling ResNet using the Torch-TensorRT `torch.compile` Backend +========================================================== + +This interactive script is intended as a sample of the Torch-TensorRT workflow with `torch.compile` on a ResNet model.""" + +# %% +# Imports and Model Definition +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +import torch +import torch_tensorrt +import torchvision.models as models + +# %% + +# Initialize model with half precision and sample inputs +model = models.resnet18(pretrained=True).half().eval().to("cuda") +inputs = [torch.randn((1, 3, 224, 224)).to("cuda").half()] + +# %% +# Optional Input Arguments to `torch_tensorrt.compile` +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +# Enabled precision for TensorRT optimization +enabled_precisions = {torch.half} + +# Whether to print verbose logs +debug = True + +# Workspace size for TensorRT +workspace_size = 20 << 30 + +# Maximum number of TRT Engines +# (Lower value allows more graph segmentation) +min_block_size = 7 + +# Operations to Run in Torch, regardless of converter support +torch_executed_ops = {} + +# %% +# Compilation with `torch_tensorrt.compile` +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +# Build and compile the model with torch.compile, using Torch-TensorRT backend +optimized_model = torch_tensorrt.compile( + model, + ir="torch_compile", + inputs=inputs, + enabled_precisions=enabled_precisions, + debug=debug, + workspace_size=workspace_size, + min_block_size=min_block_size, + torch_executed_ops=torch_executed_ops, +) + +# %% +# Equivalently, we could have run the above via the torch.compile frontend, as so: +# `optimized_model = torch.compile(model, backend="torch_tensorrt", options={"enabled_precisions": enabled_precisions, ...}); optimized_model(*inputs)` + +# %% +# Inference +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +# Does not cause recompilation (same batch size as input) +new_inputs = [torch.randn((1, 3, 224, 224)).half().to("cuda")] +new_outputs = optimized_model(*new_inputs) + +# %% + +# Does cause recompilation (new batch size) +new_batch_size_inputs = [torch.randn((8, 3, 224, 224)).half().to("cuda")] +new_batch_size_outputs = 
optimized_model(*new_batch_size_inputs) + +# %% +# Cleanup +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +# Finally, we use Torch utilities to clean up the workspace +torch._dynamo.reset() + +# %% +# Cuda Driver Error Note +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# +# Occasionally, upon exiting the Python runtime after Dynamo compilation with `torch_tensorrt`, +# one may encounter a Cuda Driver Error. This issue is related to https://github.com/NVIDIA/TensorRT/issues/2052 +# and can be resolved by wrapping the compilation/inference in a function and using a scoped call, as in:: +# +# if __name__ == '__main__': +# compile_engine_and_infer() diff --git a/examples/dynamo/torch_compile_transformers_example.py b/examples/dynamo/torch_compile_transformers_example.py new file mode 100644 index 0000000000..5422f9cc1d --- /dev/null +++ b/examples/dynamo/torch_compile_transformers_example.py @@ -0,0 +1,108 @@ +""" +.. _torch_compile_transformer: + +Compiling a Transformer using torch.compile and TensorRT +============================================================== + +This interactive script is intended as a sample of the Torch-TensorRT workflow with `torch.compile` on a transformer-based model.""" + +# %% +# Imports and Model Definition +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +import torch +import torch_tensorrt +from transformers import BertModel + +# %% + +# Initialize model with float precision and sample inputs +model = BertModel.from_pretrained("bert-base-uncased").eval().to("cuda") +inputs = [ + torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda"), + torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda"), +] + + +# %% +# Optional Input Arguments to `torch_tensorrt.compile` +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +# Enabled precision for TensorRT optimization +enabled_precisions = {torch.float} + +# Whether to print verbose logs +debug = True + +# Workspace size for TensorRT +workspace_size = 20 << 30 + +# Maximum number of TRT Engines +# (Lower value allows more graph segmentation) +min_block_size = 7 + +# Operations to Run in Torch, regardless of converter support +torch_executed_ops = {} + +# %% +# Compilation with `torch.compile` +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +# Define backend compilation keyword arguments +compilation_kwargs = { + "enabled_precisions": enabled_precisions, + "debug": debug, + "workspace_size": workspace_size, + "min_block_size": min_block_size, + "torch_executed_ops": torch_executed_ops, +} + +# Build and compile the model with torch.compile, using Torch-TensorRT backend +optimized_model = torch.compile( + model, + backend="torch_tensorrt", + options=compilation_kwargs, +) +optimized_model(*inputs) + +# %% +# Equivalently, we could have run the above via the convenience frontend, as so: +# `torch_tensorrt.compile(model, ir="torch_compile", inputs=inputs, **compilation_kwargs)` + +# %% +# Inference +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +# Does not cause recompilation (same batch size as input) +new_inputs = [ + torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda"), + torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda"), +] +new_outputs = optimized_model(*new_inputs) + +# %% + +# Does cause recompilation (new batch size) +new_inputs = [ + torch.randint(0, 2, (4, 14), dtype=torch.int32).to("cuda"), + torch.randint(0, 2, (4, 14), dtype=torch.int32).to("cuda"), +] +new_outputs = optimized_model(*new_inputs) + +# %% +# Cleanup +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +# Finally, we use Torch utilities to clean up the workspace +torch._dynamo.reset() + +# %% +# Cuda Driver Error Note +# 
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# +# Occasionally, upon exiting the Python runtime after Dynamo compilation with `torch_tensorrt`, +# one may encounter a Cuda Driver Error. This issue is related to https://github.com/NVIDIA/TensorRT/issues/2052 +# and can be resolved by wrapping the compilation/inference in a function and using a scoped call, as in:: +# +# if __name__ == '__main__': +# compile_engine_and_infer() From 968aca415d3a3276ace7332d66a4206b24f510bd Mon Sep 17 00:00:00 2001 From: Naren Dasan <1790613+narendasan@users.noreply.github.com> Date: Tue, 25 Jul 2023 15:20:39 -0600 Subject: [PATCH 13/17] Reorg docs and fix css for sphix-gallery examples (#1967) Signed-off-by: Naren Dasan Signed-off-by: Naren Dasan --- docsrc/Makefile | 1 + docsrc/_static/css/custom.css | 8 ++ docsrc/_static/css/pytorch_theme.css | 127 ++++++++++++++++++ docsrc/conf.py | 11 +- docsrc/index.rst | 2 +- .../creating_torchscript_module_in_python.rst | 0 .../getting_started_with_fx_path.rst | 0 docsrc/{tutorials => user_guide}/ptq.rst | 0 docsrc/{tutorials => user_guide}/runtime.rst | 0 .../use_from_pytorch.rst | 0 .../{tutorials => user_guide}/using_dla.rst | 0 examples/README.rst | 7 + examples/dynamo/README.rst | 7 +- 13 files changed, 158 insertions(+), 5 deletions(-) create mode 100644 docsrc/_static/css/custom.css create mode 100644 docsrc/_static/css/pytorch_theme.css rename docsrc/{tutorials => user_guide}/creating_torchscript_module_in_python.rst (100%) rename docsrc/{tutorials => user_guide}/getting_started_with_fx_path.rst (100%) rename docsrc/{tutorials => user_guide}/ptq.rst (100%) rename docsrc/{tutorials => user_guide}/runtime.rst (100%) rename docsrc/{tutorials => user_guide}/use_from_pytorch.rst (100%) rename docsrc/{tutorials => user_guide}/using_dla.rst (100%) create mode 100644 examples/README.rst diff --git a/docsrc/Makefile b/docsrc/Makefile index 0ea6796ed8..f30a9ae76e 100644 --- a/docsrc/Makefile +++ b/docsrc/Makefile @@ -35,6 +35,7 @@ endif rm -rf $(SOURCEDIR)/_py_api rm -rf $(SOURCEDIR)/_build rm -rf $(SOURCEDIR)/_tmp + rm -rf $(SOURCEDIR)/tutorials/_rendered_examples html: # mkdir -p $(SOURCEDIR)/_notebooks diff --git a/docsrc/_static/css/custom.css b/docsrc/_static/css/custom.css new file mode 100644 index 0000000000..2523d4e541 --- /dev/null +++ b/docsrc/_static/css/custom.css @@ -0,0 +1,8 @@ +/* sphinx-design styles for cards/tabs +*/ + +.sphx-glr-thumbcontainer { + padding: 50%; + display: flex; + align-content: center; +} \ No newline at end of file diff --git a/docsrc/_static/css/pytorch_theme.css b/docsrc/_static/css/pytorch_theme.css new file mode 100644 index 0000000000..153f4889c0 --- /dev/null +++ b/docsrc/_static/css/pytorch_theme.css @@ -0,0 +1,127 @@ +body { + font-family: "Lato","proxima-nova","Helvetica Neue",Arial,sans-serif; +} + +/* Default header fonts are ugly */ +h1, h2, .rst-content .toctree-wrapper p.caption, h3, h4, h5, h6, legend, p.caption { + font-family: "Lato","proxima-nova","Helvetica Neue",Arial,sans-serif; +} + +/* Use white for docs background */ +.wy-side-nav-search { + background-color: #fff; +} + +.wy-nav-content-wrap, .wy-menu li.current > a { + background-color: #fff; +} + +@media screen and (min-width: 1400px) { + .wy-nav-content-wrap { + background-color: rgba(0, 0, 0, 0.0470588); + } + + .wy-nav-content { + background-color: #fff; + } +} + +/* Fixes for mobile */ +.wy-nav-top { + background-color: #fff; + background-image: url('../img/pytorch-logo-dark.svg'); + background-repeat: no-repeat; + background-position: center; + padding: 0; + margin: 
0.4045em 0.809em; + color: #333; +} + +.wy-nav-top > a { + display: none; +} + +@media screen and (max-width: 768px) { + .wy-side-nav-search>a img.logo { + height: 60px; + } +} + +/* This is needed to ensure that logo above search scales properly */ +.wy-side-nav-search a { + display: block; +} + +/* This ensures that multiple constructors will remain in separate lines. */ +.rst-content dl:not(.docutils) dt { + display: table; +} + +/* Use our red for literals (it's very similar to the original color) */ +.rst-content tt.literal, .rst-content tt.literal, .rst-content code.literal { + color: #F05732; +} + +.rst-content tt.xref, a .rst-content tt, .rst-content tt.xref, +.rst-content code.xref, a .rst-content tt, a .rst-content code { + color: #404040; +} + +/* Change link colors (except for the menu) */ + +a { + color: #F05732; +} + +a:hover { + color: #F05732; +} + + +a:visited { + color: #D44D2C; +} + +.wy-menu a { + color: #b3b3b3; +} + +.wy-menu a:hover { + color: #b3b3b3; +} + +a.icon.icon-home { + color: #D44D2C; +} + +.version{ + color: #D44D2C !important; +} + +/* Default footer text is quite big */ +footer { + font-size: 80%; +} + +footer .rst-footer-buttons { + font-size: 125%; /* revert footer settings - 1/80% = 125% */ +} + +footer p { + font-size: 100%; +} + +/* For hidden headers that appear in TOC tree */ +/* see https://stackoverflow.com/a/32363545/3343043 */ +.rst-content .hidden-section { + display: none; +} + +nav .hidden-section { + display: inherit; +} + +/* Make code blocks have a background */ +.codeblock,pre.literal-block,.rst-content .literal-block,.rst-content pre.literal-block,div[class^='highlight'] { + background: rgba(0, 0, 0, 0.0470588); +} diff --git a/docsrc/conf.py b/docsrc/conf.py index cbe068d7ea..0fd4acc2e0 100644 --- a/docsrc/conf.py +++ b/docsrc/conf.py @@ -18,6 +18,9 @@ import torch import pytorch_sphinx_theme import torch_tensorrt +from docutils.parsers.rst import Directive, directives +from docutils.statemachine import StringList +from docutils import nodes # -- Project information ----------------------------------------------------- @@ -79,10 +82,16 @@ # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". html_static_path = ["_static"] +# Custom CSS paths should either relative to html_static_path +# or fully qualified paths (eg. https://...) +html_css_files = [ + "https://cdn.jsdelivr.net/npm/katex@0.10.0-beta/dist/katex.min.css", + "css/custom.css", +] # sphinx-gallery configuration sphinx_gallery_conf = { - "examples_dirs": "../examples/dynamo", + "examples_dirs": "../examples", "gallery_dirs": "tutorials/_rendered_examples/", } diff --git a/docsrc/index.rst b/docsrc/index.rst index b851ca3481..eee62bc2f7 100644 --- a/docsrc/index.rst +++ b/docsrc/index.rst @@ -59,9 +59,9 @@ User Guide Tutorials ------------ +* :ref:`torch_tensorrt_tutorials` * :ref:`serving_torch_tensorrt_with_triton` * :ref:`notebooks` -* :ref:`dynamo_compile` .. 
toctree::
   :caption: Tutorials

diff --git a/docsrc/tutorials/creating_torchscript_module_in_python.rst b/docsrc/user_guide/creating_torchscript_module_in_python.rst
similarity index 100%
rename from docsrc/tutorials/creating_torchscript_module_in_python.rst
rename to docsrc/user_guide/creating_torchscript_module_in_python.rst
diff --git a/docsrc/tutorials/getting_started_with_fx_path.rst b/docsrc/user_guide/getting_started_with_fx_path.rst
similarity index 100%
rename from docsrc/tutorials/getting_started_with_fx_path.rst
rename to docsrc/user_guide/getting_started_with_fx_path.rst
diff --git a/docsrc/tutorials/ptq.rst b/docsrc/user_guide/ptq.rst
similarity index 100%
rename from docsrc/tutorials/ptq.rst
rename to docsrc/user_guide/ptq.rst
diff --git a/docsrc/tutorials/runtime.rst b/docsrc/user_guide/runtime.rst
similarity index 100%
rename from docsrc/tutorials/runtime.rst
rename to docsrc/user_guide/runtime.rst
diff --git a/docsrc/tutorials/use_from_pytorch.rst b/docsrc/user_guide/use_from_pytorch.rst
similarity index 100%
rename from docsrc/tutorials/use_from_pytorch.rst
rename to docsrc/user_guide/use_from_pytorch.rst
diff --git a/docsrc/tutorials/using_dla.rst b/docsrc/user_guide/using_dla.rst
similarity index 100%
rename from docsrc/tutorials/using_dla.rst
rename to docsrc/user_guide/using_dla.rst
diff --git a/examples/README.rst b/examples/README.rst
new file mode 100644
index 0000000000..7c21aad732
--- /dev/null
+++ b/examples/README.rst
@@ -0,0 +1,7 @@
+.. _torch_tensorrt_tutorials:
+
+Torch-TensorRT Tutorials
+===========================
+
+The user guide covers the basic concepts and usage of Torch-TensorRT.
+We also provide a number of tutorials to explore specific use cases and advanced concepts.
diff --git a/examples/dynamo/README.rst b/examples/dynamo/README.rst
index 69e7f404e2..fa863952e7 100644
--- a/examples/dynamo/README.rst
+++ b/examples/dynamo/README.rst
@@ -1,9 +1,10 @@
 .. _torch_compile:
 
-Dynamo Compile Examples
-================
+Dynamo / ``torch.compile``
+----------------------------
 
-This document contains examples of usage of the `torch_tensorrt.dynamo.compile` API which integrates with `torch.compile` functionality
+Torch-TensorRT provides a backend for the new ``torch.compile`` API released in PyTorch 2.0. In the following examples we describe
+a number of ways you can leverage this backend to accelerate inference.
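+
+A minimal sketch of the basic workflow (assuming an existing ``torch.nn.Module`` named ``model`` and a list of sample ``inputs``; these names are placeholders rather than references to the examples below)::
+
+    import torch
+    import torch_tensorrt  # importing registers the "torch_tensorrt" backend
+
+    # Compilation is deferred: the TensorRT engine is built on the
+    # first call of the optimized model with real inputs
+    optimized_model = torch.compile(model, backend="torch_tensorrt")
+    optimized_model(*inputs)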
* :ref:`torch_compile_resnet`: Compiling a ResNet model using the Torch Compile Frontend for ``torch_tensorrt.compile`` * :ref:`torch_compile_transformer`: Compiling a Transformer model using ``torch.compile`` From 8149db44d80a06f362630c706dee4d0a89a027cc Mon Sep 17 00:00:00 2001 From: Naren Dasan <1790613+narendasan@users.noreply.github.com> Date: Wed, 2 Aug 2023 10:35:25 -0600 Subject: [PATCH 14/17] Update pr-labels.yml --- .github/pr-labels.yml | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/.github/pr-labels.yml b/.github/pr-labels.yml index aa0f66931b..e07cb3f500 100644 --- a/.github/pr-labels.yml +++ b/.github/pr-labels.yml @@ -12,9 +12,11 @@ "component: conversion": - core/conversion/**/* - + - py/torch_tensorrt/dynamo/conversion/**/* + "component: converters": - core/conversion/converters/**/* + - py/torch_tensorrt/dynamo/conversion/impl/**/* "component: evaluators": - core/conversion/evaluators/**/* @@ -22,14 +24,22 @@ "component: fx": - py/torch_tensorrt/fx/**/* +"component: dynamo": + - py/torch_tensorrt/dynamo/**/* + +"component: torch_compile": + - py/torch_tensorrt/dynamo/backend/* + "component: partitioning": - core/partitioning/**/* "component: runtime": - core/runtime/**/* + - py/torch_tensorrt/dynamo/runtime/**/* "component: lowering": - core/lowering/**/* + - py/torch_tensorrt/dynamo/lowering/**/* "component: tests": - tests/**/* @@ -37,6 +47,8 @@ "component: build system": - WORKSPACE - BUILD + - pyproject.toml + - setup.py "documentation": - docs/**/* From 61e338e199c685bc3edb3a4400cf6f042bb6eabb Mon Sep 17 00:00:00 2001 From: Naren Dasan <1790613+narendasan@users.noreply.github.com> Date: Wed, 2 Aug 2023 10:40:24 -0600 Subject: [PATCH 15/17] Update code-owners.yml --- .github/code-owners.yml | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/.github/code-owners.yml b/.github/code-owners.yml index d83c4c31a0..a3799b6733 100644 --- a/.github/code-owners.yml +++ b/.github/code-owners.yml @@ -2,7 +2,8 @@ - "narendasan" "component: api [Python]": - - "narendasan" + - "gs-olive" + - "peri044" "component: api": - "narendasan" @@ -11,18 +12,23 @@ - "narendasan" "component: conversion": - - "narendasan" + - "apbose" - "peri044" "component: converters": - - "peri044" - - "bowang007" + - "apbose" + - "zewenli98" "component: core": - "narendasan" - "peri044" - "bowang007" +"component: dynamo": + - "narendasan" + - "gs-olive" + - "peri044" + "component: evaluators": - "narendasan" - "peri044" @@ -32,7 +38,7 @@ "component: lowering": - "peri044" - - "narendasan" + - "gs-olive" "component: partitioning": - "bowang007" @@ -43,6 +49,7 @@ "component: quantization": - "peri044" + - "bowang007" "component: runtime": - "narendasan" @@ -50,6 +57,10 @@ "component: tests": - "narendasan" +"component: torch_compile": + - "gs-olive" + - "narendasan" + "component: torchtrtc": - "narendasan" From 32701b587dfa50dee23e6a44afa540b4bac7114a Mon Sep 17 00:00:00 2001 From: George S <113141689+gs-olive@users.noreply.github.com> Date: Thu, 3 Aug 2023 10:39:35 -0700 Subject: [PATCH 16/17] ci: Add automatic GHA job to build + push Docker Container on `main` (#2129) --- .github/workflows/docker_builder.yml | 61 ++++++++++++++++++++++++++++ docker/Dockerfile | 1 + 2 files changed, 62 insertions(+) create mode 100644 .github/workflows/docker_builder.yml diff --git a/.github/workflows/docker_builder.yml b/.github/workflows/docker_builder.yml new file mode 100644 index 0000000000..71e7cb74d4 --- /dev/null +++ 
b/.github/workflows/docker_builder.yml @@ -0,0 +1,61 @@ +name: 'Torch-TensorRT Docker Build' + +# Apply workflow only to main branch +on: + push: + branches: + - main + - nightly + +# If pushes to main are made in rapid succession, +# cancel existing docker builds and use newer commits +concurrency: + group: ${{ github.workflow }}-${{ github.ref_name }} + cancel-in-progress: true + +jobs: + build: + runs-on: linux.2xlarge + + # Define key environment variables + # Container name is of the form torch_tensorrt: + env: + DOCKER_REGISTRY: ghcr.io/pytorch/tensorrt + CONTAINER_NAME: torch_tensorrt:${{ github.ref_name }} + + steps: + - name: Checkout repository + uses: actions/checkout@v3 + + - name: Log in to the Container registry + uses: docker/login-action@v2 + with: + registry: ${{ env.DOCKER_REGISTRY }} + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + # Automatically detect TensorRT and cuDNN default versions for Torch-TRT build + - name: Build Docker image + env: + DOCKER_TAG: ${{ env.DOCKER_REGISTRY }}/${{ env.CONTAINER_NAME }} + run: | + TRT_VERSION=$(python3 -c "import versions; versions.tensorrt_version()") + echo "TRT VERSION = ${TRT_VERSION}" + CUDNN_VERSION=$(python3 -c "import versions; versions.cudnn_version()") + echo "CUDNN VERSION = ${CUDNN_VERSION}" + + DOCKER_BUILDKIT=1 docker build --build-arg TENSORRT_VERSION=$TRT_VERSION --build-arg CUDNN_VERSION=$CUDNN_VERSION -f docker/Dockerfile --tag $DOCKER_TAG . + + - name: Push Docker image + env: + DOCKER_URL: ${{ env.DOCKER_REGISTRY }}/${{ env.CONTAINER_NAME }} + run: docker push $DOCKER_URL + + # Clean up all untagged containers in registry + - name: Container Registry Cleanup + uses: actions/delete-package-versions@v4 + with: + package-name: "tensorrt/torch_tensorrt" + package-type: container + min-versions-to-keep: 0 + delete-only-untagged-versions: True diff --git a/docker/Dockerfile b/docker/Dockerfile index aa3623f32b..ca01bf8e2d 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -61,6 +61,7 @@ FROM base as torch-tensorrt-builder-base ARG ARCH="x86_64" ARG TARGETARCH="amd64" +RUN apt-get update RUN apt-get install -y python3-setuptools RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/3bf863cc.pub From 0527edd9cbb138a367a40fd046b04919a96af222 Mon Sep 17 00:00:00 2001 From: George S <113141689+gs-olive@users.noreply.github.com> Date: Thu, 3 Aug 2023 10:47:15 -0700 Subject: [PATCH 17/17] chore: Add `pyyaml` import to GHA Docker job (#2170) --- .github/workflows/docker_builder.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/docker_builder.yml b/.github/workflows/docker_builder.yml index 71e7cb74d4..817bc87c82 100644 --- a/.github/workflows/docker_builder.yml +++ b/.github/workflows/docker_builder.yml @@ -39,6 +39,7 @@ jobs: env: DOCKER_TAG: ${{ env.DOCKER_REGISTRY }}/${{ env.CONTAINER_NAME }} run: | + python3 -m pip install pyyaml TRT_VERSION=$(python3 -c "import versions; versions.tensorrt_version()") echo "TRT VERSION = ${TRT_VERSION}" CUDNN_VERSION=$(python3 -c "import versions; versions.cudnn_version()")
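
A note on the fix above: the workflow imports ``versions.py`` inline via ``python3 -c``, so any module that file depends on must already be installed in the job environment — hence the ``pyyaml`` install. As a rough, illustrative sketch only (the real ``versions.py`` in the repository may read a differently named file with a different schema), a helper compatible with the ``$(python3 -c ...)`` capture pattern used in the job could look like::

    # versions_sketch.py -- hypothetical stand-in for the repository's versions.py
    import yaml  # supplied by the `pyyaml` package installed in the workflow step

    def _load_spec(path="version_spec.yml"):  # assumed file name, for illustration
        # Parse the pinned dependency versions from a YAML spec
        with open(path) as f:
            return yaml.safe_load(f)

    def tensorrt_version():
        # Print rather than return, since the workflow captures stdout via $(...)
        print(_load_spec()["tensorrt"])

    def cudnn_version():
        print(_load_spec()["cudnn"])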