Commit: Merge branch 'pytorch:main' into main

zewenli98 authored Aug 3, 2023
2 parents 0976470 + 0527edd commit 8923872
Showing 36 changed files with 1,102 additions and 99 deletions.
21 changes: 16 additions & 5 deletions .github/code-owners.yml
@@ -2,7 +2,8 @@
- "narendasan"

"component: api [Python]":
- "narendasan"
- "gs-olive"
- "peri044"

"component: api":
- "narendasan"
@@ -11,18 +12,23 @@
- "narendasan"

"component: conversion":
- "narendasan"
- "apbose"
- "peri044"

"component: converters":
- "peri044"
- "bowang007"
- "apbose"
- "zewenli98"

"component: core":
- "narendasan"
- "peri044"
- "bowang007"

"component: dynamo":
- "narendasan"
- "gs-olive"
- "peri044"

"component: evaluators":
- "narendasan"
- "peri044"
@@ -32,7 +38,7 @@

"component: lowering":
- "peri044"
- "narendasan"
- "gs-olive"

"component: partitioning":
- "bowang007"
@@ -43,13 +49,18 @@

"component: quantization":
- "peri044"
- "bowang007"

"component: runtime":
- "narendasan"

"component: tests":
- "narendasan"

"component: torch_compile":
- "gs-olive"
- "narendasan"

"component: torchtrtc":
- "narendasan"

14 changes: 13 additions & 1 deletion .github/pr-labels.yml
@@ -12,31 +12,43 @@

"component: conversion":
- core/conversion/**/*
- py/torch_tensorrt/dynamo/conversion/**/*

"component: converters":
- core/conversion/converters/**/*
- py/torch_tensorrt/dynamo/conversion/impl/**/*

"component: evaluators":
- core/conversion/evaluators/**/*

"component: fx":
- py/torch_tensorrt/fx/**/*

"component: dynamo":
- py/torch_tensorrt/dynamo/**/*

"component: torch_compile":
- py/torch_tensorrt/dynamo/backend/*

"component: partitioning":
- core/partitioning/**/*

"component: runtime":
- core/runtime/**/*
- py/torch_tensorrt/dynamo/runtime/**/*

"component: lowering":
- core/lowering/**/*
- py/torch_tensorrt/dynamo/lowering/**/*

"component: tests":
- tests/**/*

"component: build system":
- WORKSPACE
- BUILD
- pyproject.toml
- setup.py

"documentation":
- docs/**/*
63 changes: 9 additions & 54 deletions .github/workflows/docgen.yml
@@ -12,75 +12,30 @@ jobs:
build-docs:
runs-on: ubuntu-20.04
container:
image: ghcr.io/pytorch/tensorrt/docgen:latest
credentials:
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
image: docker.io/pytorch/manylinux-builder:cuda12.1
steps:
- name: Reclaim space
run: |
rm -rf /usr/share/dotnet
rm -rf /opt/ghc
rm -rf "/usr/local/share/boost"
rm -rf /usr/local/cuda/cuda-*
- name: Install base deps
run: |
apt update
DEBIAN_FRONTEND=noninteractive apt install -y software-properties-common gcc git curl wget make zlib1g-dev bzip2 libbz2-dev lzma lzma-dev libreadline-dev libsqlite3-dev libssl-dev libffi-dev doxygen pandoc
git config --global --add safe.directory '*'
- name: Set up Python 3.10.12
uses: actions/setup-python@v4
with:
python-version: 3.10.12
- uses: actions/checkout@v3
with:
ref: ${{github.head_ref}}
- name: Install base deps
run: |
./packaging/pre_build_script.sh
- name: Get HEAD SHA
id: vars
run: echo "sha=$(git rev-parse --short HEAD)" >> $GITHUB_OUTPUT
- name: Get Bazel version
id: bazel_info
run: echo "version=$(cat .bazelversion)" >> $GITHUB_OUTPUT
- name: Install Bazel
run: |
wget -q https://github.com/bazelbuild/bazel/releases/download/${{ steps.bazel_info.outputs.version }}/bazel-${{ steps.bazel_info.outputs.version }}-linux-x86_64 -O /usr/bin/bazel
chmod a+x /usr/bin/bazel
- name: Install cudnn + tensorrt
run: |
apt-get update
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/cuda-ubuntu2004.pin
mv cuda-ubuntu2004.pin /etc/apt/preferences.d/cuda-repository-pin-600
apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/7fa2af80.pub
apt-key adv --keyserver keyserver.ubuntu.com --recv-keys 536F8F1DE80F6A35
apt-key adv --keyserver keyserver.ubuntu.com --recv-keys A4B469963BF863CC
add-apt-repository "deb https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/ /"
apt-get update
apt-get install -y libcudnn8 libcudnn8-dev
apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/3bf863cc.pub
add-apt-repository "deb https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/ /"
apt-get update
apt-get install -y libnvinfer8 libnvinfer-plugin8 libnvinfer-dev libnvinfer-plugin-dev
- name: Install Torch
run: |
python3 -m pip install -r py/requirements.txt --user
- name: Build Python Package
run: |
cp toolchains/ci_workspaces/WORKSPACE.x86_64 WORKSPACE
cd py
python3 -m pip install pip==21.3.1
echo $(which python3)
echo $(python3 -c 'import site; print(site.getsitepackages()[0])')
mkdir -p /opt/circleci/.pyenv/versions/3.9.4/lib/python3.9/
ln -s $(python3 -c 'import site; print(site.getsitepackages()[0])') /opt/circleci/.pyenv/versions/3.9.4/lib/python3.9/site-packages
python3 setup.py install
cd ..
cp toolchains/ci_workspaces/WORKSPACE.x86_64.cu121.release.rhel WORKSPACE
python -m pip install "pip<=23"
python -m pip install --pre -e . --extra-index-url https://download.pytorch.org/whl/nightly/cu121
- name: Generate New Docs
run: |
cd docsrc
python3 -m pip install -r requirements.txt
python3 -c "import torch_tensorrt; print(torch_tensorrt.__version__)"
python -m pip install -r requirements.txt
python -c "import torch_tensorrt; print(torch_tensorrt.__version__)"
make html
cd ..
- uses: stefanzweifel/git-auto-commit-action@v4
62 changes: 62 additions & 0 deletions .github/workflows/docker_builder.yml
@@ -0,0 +1,62 @@
name: 'Torch-TensorRT Docker Build'

# Apply workflow only to pushes on the main and nightly branches
on:
push:
branches:
- main
- nightly

# If pushes to main are made in rapid succession,
# cancel existing docker builds and use newer commits
concurrency:
group: ${{ github.workflow }}-${{ github.ref_name }}
cancel-in-progress: true

jobs:
build:
runs-on: linux.2xlarge

# Define key environment variables
# Container name is of the form torch_tensorrt:<branch_name>
env:
DOCKER_REGISTRY: ghcr.io/pytorch/tensorrt
CONTAINER_NAME: torch_tensorrt:${{ github.ref_name }}

steps:
- name: Checkout repository
uses: actions/checkout@v3

- name: Log in to the Container registry
uses: docker/login-action@v2
with:
registry: ${{ env.DOCKER_REGISTRY }}
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}

# Automatically detect TensorRT and cuDNN default versions for Torch-TRT build
- name: Build Docker image
env:
DOCKER_TAG: ${{ env.DOCKER_REGISTRY }}/${{ env.CONTAINER_NAME }}
run: |
python3 -m pip install pyyaml
TRT_VERSION=$(python3 -c "import versions; versions.tensorrt_version()")
echo "TRT VERSION = ${TRT_VERSION}"
CUDNN_VERSION=$(python3 -c "import versions; versions.cudnn_version()")
echo "CUDNN VERSION = ${CUDNN_VERSION}"
DOCKER_BUILDKIT=1 docker build --build-arg TENSORRT_VERSION=$TRT_VERSION --build-arg CUDNN_VERSION=$CUDNN_VERSION -f docker/Dockerfile --tag $DOCKER_TAG .
- name: Push Docker image
env:
DOCKER_URL: ${{ env.DOCKER_REGISTRY }}/${{ env.CONTAINER_NAME }}
run: docker push $DOCKER_URL

# Clean up all untagged containers in registry
- name: Container Registry Cleanup
uses: actions/delete-package-versions@v4
with:
package-name: "tensorrt/torch_tensorrt"
package-type: container
min-versions-to-keep: 0
delete-only-untagged-versions: True
3 changes: 2 additions & 1 deletion .gitignore
@@ -32,6 +32,7 @@ docsrc/_build
docsrc/_notebooks
docsrc/_cpp_api
docsrc/_tmp
docsrc/tutorials/_rendered_examples
*.so
__pycache__
*.egg-info
@@ -67,4 +68,4 @@ bazel-tensorrt
*cifar-10-batches-py*
bazel-project
build/
wheelhouse/
wheelhouse/
6 changes: 3 additions & 3 deletions core/conversion/converters/impl/layer_norm.cpp
@@ -20,8 +20,8 @@ auto layer_norm_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns()

/* Layer_Norm normalizes over last N dimensions.
normalized_shape could be (C,H,W), (H,W), or (W). */
auto normalized_shape = args[1].unwrapToIntList();
auto normalized_shape_vec = util::toVec(util::toDims(normalized_shape));
// This could be an IntList or ITensorList. We only need the size of this list.
auto normalized_shape = args[1].IValue()->toList();

// Unwrap eps.
auto eps = args[4].unwrapToDouble();
@@ -30,7 +30,7 @@ auto layer_norm_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns()

// Set up axis_mask for E[x].
uint32_t axis_mask = 0;
for (size_t i = 0; i < normalized_shape_vec.size(); i++) {
for (size_t i = 0; i < normalized_shape.size(); i++) {
axis_mask |= 1 << (shape.size() - i - 1);
}
LOG_DEBUG("Axis Mask for E[x]" << std::bitset<32>(axis_mask));
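As a sanity check on the axis-mask loop above, here is a minimal standalone sketch, assuming a rank-4 input (N, C, H, W) normalized over its last three dimensions; the ranks are illustrative values, not taken from the converter:

#include <bitset>
#include <cstdint>
#include <iostream>

int main() {
  const int input_rank = 4;      // plays the role of shape.size()
  const int normalized_rank = 3; // plays the role of normalized_shape.size(), e.g. (C, H, W)

  // Set one bit per normalized axis, counting from the innermost dimension.
  uint32_t axis_mask = 0;
  for (int i = 0; i < normalized_rank; i++) {
    axis_mask |= 1u << (input_rank - i - 1); // sets bits 3, 2, 1
  }

  // Prints 00000000000000000000000000001110: axes 1, 2, 3 contribute to E[x];
  // axis 0 (the batch dimension) does not.
  std::cout << std::bitset<32>(axis_mask) << std::endl;
}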
36 changes: 25 additions & 11 deletions core/conversion/converters/impl/quantization.cpp
@@ -11,6 +11,22 @@ namespace {

#if NV_TENSORRT_MAJOR > 7
// clang-format off

bool add_qdq(ConversionCtx *ctx, const torch::jit::Node* n, nvinfer1::ITensor* input, nvinfer1::ITensor* scale, std::string& opName) {
nvinfer1::IQuantizeLayer* quantize_layer = ctx->net->addQuantize(*input, *scale);
TORCHTRT_CHECK(quantize_layer, "Unable to create QuantizeLayer from node: " << *n);
quantize_layer->setAxis(0);

nvinfer1::IDequantizeLayer* dequantize_layer = ctx->net->addDequantize(*quantize_layer->getOutput(0), *scale);
TORCHTRT_CHECK(dequantize_layer, "Unable to create DequantizeLayer from node: " << *n);
dequantize_layer->setAxis(0);

auto qdq_out = ctx->AssociateValueAndTensor(n->outputs()[0], dequantize_layer->getOutput(0));
LOG_DEBUG("[" << opName << "]"<< " Output tensor shape: " << qdq_out->getDimensions());

return true;
}

auto quantization_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns()
.pattern({"aten::fake_quantize_per_tensor_affine(Tensor self, float scale, int zero_point, int quant_min, int quant_max) -> (Tensor)",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
@@ -20,18 +36,16 @@ auto quantization_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns
auto scale = args[1].unwrapToScalar().to<float>();
auto scaleTensor = tensor_to_const(ctx, torch::tensor({scale}));
// Add and configure a QuantizeLayer.
nvinfer1::IQuantizeLayer* quantize_layer = ctx->net->addQuantize(*input, *scaleTensor);
quantize_layer->setAxis(0);

// Add and configure DequantizeLayer following a QuantizeLayer
nvinfer1::IDequantizeLayer* dequantize_layer = ctx->net->addDequantize(*quantize_layer->getOutput(0), *scaleTensor);
dequantize_layer->setAxis(0);

auto qdq_out = ctx->AssociateValueAndTensor(n->outputs()[0], dequantize_layer->getOutput(0));
LOG_DEBUG("[fake_quantize_per_tensor_affine] Output tensor shape: " << qdq_out->getDimensions());

return true;
std::string opName("aten::fake_quantize_per_tensor_affine");
return add_qdq(ctx, n, input, scaleTensor, opName);
}})
.pattern({"aten::fake_quantize_per_tensor_affine.tensor_qparams(Tensor self, Tensor scale, Tensor zero_point, int quant_min, int quant_max) -> (Tensor)",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
auto input = args[0].ITensorOrFreeze(ctx);
auto scale = args[1].ITensorOrFreeze(ctx);
std::string opName("aten::fake_quantize_per_tensor_affine.tensor_qparams");
return add_qdq(ctx, n, input, scale, opName);
}})
.pattern({"aten::fake_quantize_per_channel_affine(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max) -> (Tensor)",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
// This aten operator is generated from torch.fake_quantize_per_channel_affine op in Pytorch python API.
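For reference, the Quantize->Dequantize pair that add_qdq builds models standard symmetric fake-quantization. Below is a minimal sketch of the per-tensor arithmetic, assuming zero_point = 0 (consistent with only a scale tensor being passed above) and an INT8 range; the function name and defaults are illustrative, not part of the converter:

#include <algorithm>
#include <cmath>

// Symmetric per-tensor fake-quantize: round to the nearest integer step,
// clamp to the quantized range, then scale back to floating point.
float fake_quantize(float x, float scale, int quant_min = -128, int quant_max = 127) {
  int q = static_cast<int>(std::nearbyint(x / scale));
  q = std::min(quant_max, std::max(quant_min, q));
  return static_cast<float>(q) * scale;
}

For example, fake_quantize(0.30f, 0.1f) rounds 3.0 to 3 and returns 0.3f, while fake_quantize(100.0f, 0.1f) clamps 1000 to 127 and returns 12.7f.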
10 changes: 3 additions & 7 deletions core/conversion/converters/impl/select.cpp
@@ -165,10 +165,7 @@ auto select_registrations TORCHTRT_UNUSED =
}

shuffle_layer->setReshapeDimensions(util::squeezeDims(
out->getDimensions(),
dim,
ctx->input_is_dynamic,
ctx->input_is_dynamic && (num_zero_dimensions > 0)));
out->getDimensions(), dim, false, ctx->input_is_dynamic && (num_zero_dimensions > 0)));
shuffle_layer->setName(util::node_info(n).c_str());
out = shuffle_layer->getOutput(0);
}
@@ -710,9 +707,8 @@ auto select_registrations TORCHTRT_UNUSED =
auto val_t_dtype = util::TRTDataTypeToScalarType(self->getType());

// Initialize constant tensor for fill with the inherited data type
auto val_t = tensor_to_const(
ctx, torch::full(util::toVec(self->getDimensions()), val, {torch::dtype(val_t_dtype)}));

std::vector<int64_t> singleton_dims(self->getDimensions().nbDims, 1);
auto val_t = tensor_to_const(ctx, torch::full(singleton_dims, val, {torch::dtype(val_t_dtype)}));
TORCHTRT_CHECK(
util::broadcastable(self->getDimensions(), mask->getDimensions(), /*multidirectional=*/false),
"Self and mask tensors are not broadcastable");
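The masked_fill change above swaps a constant materialized at the full shape of self for a rank-matched constant whose every dimension is 1, letting TensorRT broadcast the fill value instead (and sidestepping full shapes that may contain dynamic dimensions). A minimal sketch of the idea, with an illustrative helper name rather than the converter API:

#include <cstdint>
#include <vector>

// Build dims like {1, 1, ..., 1} with the same rank as the input tensor, so a
// scalar fill value broadcasts against `self` rather than being materialized
// at full size.
std::vector<int64_t> singleton_dims_like(int32_t nb_dims) {
  return std::vector<int64_t>(nb_dims, 1); // e.g. rank 4 -> {1, 1, 1, 1}
}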