Commit
Added C++ and CUDA bindings to `tile_matmul` for 1.1.2 Release (#66)
* Added C++ and CUDA bindings for `memtorch.bh.crossbar.Tile.tile_matmul`.
* Added `Eigen` integration with C++ and CUDA bindings.
* Modularized C++ and CUDA `quantize` bindings.
1 parent aa54b63
commit 055e036
Showing 42 changed files with 1,155 additions and 392 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
memtorch/cu/* |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,5 +11,6 @@ MemTorch_cpu.egg-info/ | |
memtorch/examples/reproduce/*.csv | ||
tmp/ | ||
**/.pytest_cache/ | ||
.vscode/ | ||
.eggs/ | ||
.vscode/ | ||
tmp.py |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,3 @@ | ||
[submodule "memtorch/submodules/pytorch-playground"] | ||
path = memtorch/submodules/pytorch-playground | ||
url = https://github.com/coreylammie/pytorch-playground | ||
[submodule "memtorch/submodules/memtorch/submodules/pytorch-playground"] | ||
path = memtorch/submodules/memtorch/submodules/pytorch-playground | ||
url = https://github.com/coreylammie/pytorch-playground | ||
[submodule "memtorch/submodules/eigen"] | ||
path = memtorch/submodules/eigen | ||
url = https://gitlab.com/libeigen/eigen.git |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,14 +1,14 @@ | ||
repos: | ||
- repo: https://github.com/psf/black | ||
rev: 20.8b1 | ||
rev: 21.6b0 | ||
hooks: | ||
- id: black | ||
language_version: python3 | ||
- repo: https://github.com/timothycrosley/isort | ||
rev: 5.8.0 | ||
rev: 5.9.1 | ||
hooks: | ||
- id: isort | ||
- repo: https://github.com/pocc/pre-commit-hooks | ||
rev: python | ||
rev: v1.1.1 | ||
hooks: | ||
- id: clang-format |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,23 @@ | ||
- Transitioned from TravisCI to GitHub Actions. | ||
## Added | ||
|
||
1. C++ and CUDA bindings for `memtorch.bh.crossbar.Tile.tile_matmul`. | ||
|
||
Using an NVIDIA GeForce GTX 1080, a tile shape of (25, 25), and two tensors of size (500, 500), the runtime of `tile_matmul` without quantization support is reduced by 2.45x and 5.48x for CPU-bound and GPU-bound operation, respectively. With an ADC resolution of 4 bits and an overflow rate of 0.0, the runtime of `tile_matmul` with quantization support is reduced by 2.31x and 105.27x for CPU-bound and GPU-bound operation, respectively. | ||
|
||
| Implementation | Runtime Without Quantization Support (s) | Runtime With Quantization Support (s) | | ||
| ---------------------- | ---------------------------------------- | ------------------------------------- | | ||
| Pure Python (Previous) | 6.917784 | 27.099764 | | ||
| C++ (CPU-bound) | 2.822265 | 11.736974 | | ||
| CUDA (GPU-bound) | 1.262861 | 0.2574267 | | ||
|
||
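The speedup factors quoted for `tile_matmul` can be recomputed from the runtime table as the ratio of the pure Python runtime to each binding's runtime. The snippet below is a small arithmetic check only; the dictionary layout and the `speedup` helper are illustrative names, not part of MemTorch.

```python
# Runtimes in seconds, copied from the changelog table:
# (without quantization support, with quantization support).
runtimes = {
    "Pure Python (Previous)": (6.917784, 27.099764),
    "C++ (CPU-bound)": (2.822265, 11.736974),
    "CUDA (GPU-bound)": (1.262861, 0.2574267),
}

baseline_plain, baseline_quant = runtimes["Pure Python (Previous)"]


def speedup(baseline, runtime):
    """Return how many times faster `runtime` is than `baseline`."""
    return baseline / runtime


# Speedup of every implementation relative to the pure Python baseline.
speedups = {
    name: (speedup(baseline_plain, plain), speedup(baseline_quant, quant))
    for name, (plain, quant) in runtimes.items()
}

for name, (plain, quant) in speedups.items():
    print(f"{name}: {plain:.2f}x without quantization, {quant:.2f}x with quantization")
```

The largest gain comes from the GPU-bound path with quantization enabled, which suggests that quantization was the dominant cost in the previous implementation.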
3. `Eigen` integration with C++ and CUDA bindings. | ||
4. Additional unit tests. | ||
|
||
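For readers unfamiliar with tiled matrix multiplication, the general idea behind an operation such as `tile_matmul` can be sketched in plain Python: one operand is partitioned into fixed-size tiles, each tile yields a partial product, and the partial products are accumulated into the output. This is a generic blocked-matmul sketch under that assumption. It does not model memristive devices or ADC quantization, and `tiled_matmul` and `naive_matmul` are hypothetical names, not MemTorch's API or its C++/CUDA implementation.

```python
def naive_matmul(a, b):
    """Reference product of a (n x k) and b (k x m), given as nested lists."""
    return [
        [sum(a[i][p] * b[p][j] for p in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def tiled_matmul(a, b, tile_shape):
    """Multiply a by b, visiting b one (tile_rows x tile_cols) tile at a time."""
    tile_rows, tile_cols = tile_shape
    n, k, m = len(a), len(b), len(b[0])
    out = [[0.0] * m for _ in range(n)]
    for row_start in range(0, k, tile_rows):
        row_end = min(row_start + tile_rows, k)  # ragged edge tiles are allowed
        for col_start in range(0, m, tile_cols):
            col_end = min(col_start + tile_cols, m)
            # Each tile contributes a partial sum to one slice of output columns.
            for i in range(n):
                for j in range(col_start, col_end):
                    partial = 0.0
                    for p in range(row_start, row_end):
                        partial += a[i][p] * b[p][j]
                    out[i][j] += partial
    return out


# Small integer-valued inputs, so both summation orders give identical floats.
a = [[float(i + j) for j in range(6)] for i in range(4)]
b = [[float((i * j) % 5) for j in range(5)] for i in range(6)]

tiled = tiled_matmul(a, b, tile_shape=(4, 2))
reference = naive_matmul(a, b)
print(tiled == reference)
```

Because each tile is processed independently before accumulation, the per-tile work is a natural candidate for parallel execution, which is consistent with the larger GPU-bound gains reported in the changelog.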
## Enhanced | ||
|
||
1. Modularized C++ and CUDA `quantize` bindings. | ||
2. Enhanced functionality of `naive_program` and added additional input arguments to dictate logic for stuck devices. | ||
|
||
## Fixed | ||
|
||
1. Removed debugging code from `naive_program`. |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,4 @@ | ||
include memtorch/cu/quantize/gpu.cuh | ||
graft memtorch/submodules/eigen | ||
include memtorch/cpp/*.h | ||
include memtorch/cu/*.h | ||
include memtorch/cu/*.cuh |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,57 +1,74 @@ | ||
# Wrapper for the pytorch-playground quant.py script | ||
import importlib | ||
import copy | ||
|
||
utee = importlib.import_module(".utee", "memtorch.submodules.pytorch-playground") | ||
import numpy as np | ||
import torch | ||
|
||
quant_methods = ["linear", "log", "tanh"] | ||
import memtorch | ||
import memtorch_bindings | ||
|
||
quant_methods = ["linear", "log"] | ||
|
||
def quantize(input, bits, overflow_rate, quant_method="linear", min=None, max=None): | ||
|
||
def quantize( | ||
tensor, | ||
quant, | ||
overflow_rate=0.0, | ||
quant_method=None, | ||
min=float("nan"), | ||
max=float("nan"), | ||
override_original=False, | ||
): | ||
"""Method to quantize a tensor. | ||
Parameters | ||
---------- | ||
input : tensor | ||
tensor : tensor | ||
Input tensor. | ||
bits : int | ||
Bit width. | ||
overflow_rate : float | ||
Overflow rate threshold for linear quanitzation. | ||
quant_method : str | ||
Quantization method. Must be in ['linear', 'log', 'tanh']. | ||
min : float | ||
Minimum value to clip values to. | ||
max : float | ||
Maximum value to clip values to. | ||
quant : int | ||
Bit width (if quant_method is not None) or the number of discrete quantization levels (if quant_method is None). | ||
overflow_rate : float, optional | ||
Overflow rate threshold for linear quantization. | ||
quant_method : str, optional | ||
Quantization method. Must be in quant_methods. | ||
min : float or tensor, optional | ||
Minimum value(s) to clip numbers to. | ||
max : float or tensor, optional | ||
Maximum value(s) to clip numbers to. | ||
override_original : bool, optional | ||
Whether to override the original tensor (True) or not (False). | ||
Returns | ||
------- | ||
tensor | ||
Quantized tensor. | ||
""" | ||
assert type(bits) == int and bits > 0, "bits must be an integer > 0." | ||
assert overflow_rate >= 0 and overflow_rate <= 1, "overflow_rate value invalid." | ||
assert quant_method in quant_methods, "quant_method is not valid." | ||
pass | ||
if min is not None: | ||
input = input.clip(min=min) | ||
|
||
if max is not None: | ||
input = input.clip(max=max) | ||
|
||
if torch.unique(input).numel() == 1: | ||
return input | ||
|
||
if quant_method == "linear": | ||
sf = bits - 1 - utee.compute_integral_part(input, overflow_rate) | ||
return utee.linear_quantize(input, sf, bits) | ||
elif quant_method == "log": | ||
log_abs_input = torch.log(torch.abs(input)) | ||
log_abs_input[log_abs_input == float("-inf")] = 1e-12 | ||
sf = bits - 1 - utee.compute_integral_part(log_abs_input, overflow_rate) | ||
return utee.log_linear_quantize(input, sf, bits) | ||
elif quant_method == "tanh": | ||
return utee.tanh_quantize(input, bits) | ||
device = torch.device("cpu" if "cpu" in memtorch.__version__ else "cuda") | ||
assert ( | ||
overflow_rate >= 0 and overflow_rate <= 1 | ||
), "overflow_rate must be >= 0 and <= 1." | ||
assert ( | ||
type(quant) == int and quant > 0 | ||
), "The bit width or number of discrete quantization levels must be a positive integer." | ||
if type(min) == int: | ||
min = float(min) | ||
if type(max) == int: | ||
max = float(max) | ||
if not override_original: | ||
tensor = copy.deepcopy(tensor) | ||
if quant_method is not None: | ||
assert quant_method in quant_methods, "quant_method is invalid." | ||
tensor = tensor.cpu() | ||
memtorch_bindings.quantize( | ||
tensor, | ||
bits=quant, | ||
overflow_rate=overflow_rate, | ||
quant_method=quant_methods.index(quant_method), | ||
min=min, | ||
max=max, | ||
) | ||
else: | ||
tensor = tensor.cpu() | ||
memtorch_bindings.quantize(tensor, n_quant_levels=quant, min=min, max=max) | ||
|
||
return tensor.to(device) |
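When `quant_method` is `None`, the new `quantize` wrapper passes `quant` to the binding as `n_quant_levels`. The binding itself is not shown in this diff, so the following is only a sketch of what quantizing to a fixed number of discrete levels commonly means: clip to `[min, max]`, then snap each value to the nearest of `n` evenly spaced levels. The function name `quantize_to_levels`, its list-based interface, and the requirement of at least two levels are assumptions made for illustration; they are not the behavior of `memtorch_bindings.quantize`.

```python
def quantize_to_levels(values, n_levels, lo, hi):
    """Clip values to [lo, hi] and snap each to the nearest of n_levels even levels.

    Illustrative only: this is not the memtorch_bindings implementation.
    """
    if n_levels < 2:
        raise ValueError("this sketch needs at least two levels")
    if hi <= lo:
        raise ValueError("hi must be greater than lo")
    step = (hi - lo) / (n_levels - 1)  # spacing between adjacent levels
    quantized = []
    for value in values:
        clipped = min(max(value, lo), hi)
        index = round((clipped - lo) / step)  # index of the nearest level
        quantized.append(lo + index * step)
    return quantized


values = [-1.3, -0.2, 0.1, 0.4, 2.0]
levels = quantize_to_levels(values, n_levels=5, lo=-1.0, hi=1.0)
print(levels)
```

Unlike this sketch, the wrapper in the diff copies the input tensor unless `override_original=True`, moves it to the CPU before calling the binding, and returns it on the device implied by `memtorch.__version__`.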