support for LLM.int8 (#6)
- add support for summarization models with LLM.int8 (_per hf
implementation_)
- add tf32 support 
- update docs

---------

Signed-off-by: Peter <[email protected]>
pszemraj authored Jan 31, 2023
1 parent 419eb3b commit 9108f66
Showing 5 changed files with 144 additions and 33 deletions.
81 changes: 65 additions & 16 deletions README.md
@@ -21,12 +21,10 @@

> utility for using transformers summarization models on text docs
This package is to provides easy-to-use interfaces for using summarization models on text documents of arbitrary length. Currently implemented interfaces include a python API, CLI, and a shareable demo app.
This package provides easy-to-use interfaces for using summarization models on text documents of arbitrary length. Currently implemented interfaces include a python API, CLI, and a shareable demo app.

For details, explanations, and docs, see the [wiki](https://github.com/pszemraj/textsum/wiki)

⚠️ _This is a WIP, but general functionality is available_ ⚠️

---

- [textsum](#textsum)
@@ -37,6 +35,10 @@ For details, explanations, and docs, see the [wiki](https://github.com/pszemraj/
- [Python API](#python-api)
- [CLI](#cli)
- [Demo App](#demo-app)
- [Using Big Models](#using-big-models)
- [Reducing Memory Usage](#reducing-memory-usage)
- [Efficient Inference](#efficient-inference)
- [Parameters](#parameters)
- [Contributing](#contributing)
- [Roadmap](#roadmap)

@@ -51,7 +53,7 @@ Install using pip:
pip install textsum
```

The `textsum` package is now installed in your virtual environment. CLI commands/python API can be summarize text docs from anywhere. see the [Usage](#usage) section for more details.
The `textsum` package is now installed in your virtual environment. The CLI commands and python API can summarize text docs from anywhere. See the [Usage](#usage) section for more details.

### Full Installation

@@ -66,7 +68,7 @@ pip install -e .[all]

### Additional Details

This package uses the [clean-text](https://github.com/jfilter/clean-text) python package, and like the "base" version of the package **does not** include the GPL-licensed `unidecode` dependency. If you want to use the `unidecode` package, install the package as an extra with `pip`:
This package uses the [clean-text](https://github.com/jfilter/clean-text) python package, and like the "base" version of the package, **does not** include the GPL-licensed `unidecode` dependency. If you want to use the `unidecode` package, install the package as an extra with `pip`:

```bash
pip install textsum[unidecode]
@@ -86,7 +88,7 @@ There are three ways to use this package:

To use the python API, import the `Summarizer` class and instantiate it. This will load the default model and parameters.

You can then use the `summarize_string` method to summarize a long string of text.
You can then use the `summarize_string` method to summarize a long text string.

```python
from textsum.summarize import Summarizer
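# The remainder of this example is collapsed in the diff view; a minimal,
# hedged sketch of the usage described above might look like this
# (exact defaults and return handling may differ from the package):
summarizer = Summarizer()  # loads the default model and inference parameters

long_text = "An arbitrarily long document to be summarized ..."
summary = summarizer.summarize_string(long_text)
print(summary)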
@@ -115,15 +117,17 @@ textsum-dir /path/to/dir

The following options are available:

```
usage: textsum-dir [-h] [-o OUTPUT_DIR] [-m MODEL_NAME] [-batch BATCH_LENGTH] [-stride BATCH_STRIDE] [-nb NUM_BEAMS]
[-l2 LENGTH_PENALTY] [-r2 REPETITION_PENALTY] [--no_cuda] [-length_ratio MAX_LENGTH_RATIO] [-ml MIN_LENGTH]
[-enc_ngram ENCODER_NO_REPEAT_NGRAM_SIZE] [-dec_ngram NO_REPEAT_NGRAM_SIZE] [--no_early_stopping] [--shuffle]
[--lowercase] [-v] [-vv] [-lf LOGFILE]
```bash
usage: textsum-dir [-h] [-o OUTPUT_DIR] [-m MODEL_NAME] [--no_cuda] [--tf32] [-8bit]
[-batch BATCH_LENGTH] [-stride BATCH_STRIDE] [-nb NUM_BEAMS]
[-l2 LENGTH_PENALTY] [-r2 REPETITION_PENALTY]
[-length_ratio MAX_LENGTH_RATIO] [-ml MIN_LENGTH]
[-enc_ngram ENCODER_NO_REPEAT_NGRAM_SIZE] [-dec_ngram NO_REPEAT_NGRAM_SIZE]
[--no_early_stopping] [--shuffle] [--lowercase] [-v] [-vv] [-lf LOGFILE]
input_dir
```

For more information, run:
For more information, run the following:

```bash
textsum-dir --help
@@ -145,7 +149,51 @@ textsum-ui

This will start a local server that you can access in your browser, and a shareable link will be printed to the console.

[^1]: The demo is currently minimal, but will be expanded in the future to accept other arguments and options.
[^1]: The demo is minimal but will be expanded to accept other arguments and options.

## Using Big Models

Summarization is a memory-intensive task, and the [default model is relatively small and efficient](https://huggingface.co/pszemraj/long-t5-tglobal-base-16384-book-summary) for long-form text summarization. If you want to use a bigger model, you can specify the `model_name_or_path` argument when instantiating the `Summarizer` class.

```python
summarizer = Summarizer(model_name_or_path='pszemraj/long-t5-tglobal-xl-16384-book-summary')
```

You can also use the `-m` argument when using the CLI:

```bash
textsum-dir /path/to/dir -m pszemraj/long-t5-tglobal-xl-16384-book-summary
```

### Reducing Memory Usage

#### Efficient Inference

Some methods of reducing memory usage, _if you have compatible hardware_, include loading the model in 8-bit precision via [LLM.int8](https://arxiv.org/abs/2208.07339) and enabling TensorFloat32 precision with the `--tf32` flag. See the [transformers docs](https://huggingface.co/docs/transformers/perf_infer_gpu_one#efficient-inference-on-a-single-gpu) for more details on how this works. Using LLM.int8 requires the [bitsandbytes](https://github.com/TimDettmers/bitsandbytes) package, which can be installed directly or via the `textsum[8bit]` extra:

```bash
pip install textsum[8bit]
```

To enable these options from the CLI, pass the `-8bit` and `--tf32` flags:

```bash
textsum-dir /path/to/dir -8bit --tf32
```

Or in python, using the `load_in_8bit` argument:

```python
summarizer = Summarizer(load_in_8bit=True)
```

If using the python API, it's better to enable tf32 yourself before loading the model; see [here](https://huggingface.co/docs/transformers/perf_train_gpu_one#tf32) for how.
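
For reference, a minimal sketch of enabling tf32 yourself, mirroring what the CLI's `--tf32` flag does through `textsum.utils.enable_tf32()`:

```python
import torch

# allow TensorFloat32 matmuls (requires an Ampere-series GPU or newer);
# this is the same switch the CLI flips via textsum.utils.enable_tf32()
torch.backends.cuda.matmul.allow_tf32 = True

from textsum.summarize import Summarizer

summarizer = Summarizer()  # matmuls now run in tf32 where supported
```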

#### Parameters

Memory usage can also be reduced by adjusting the parameters for inference. This is discussed in detail in the [project wiki](https://github.com/pszemraj/textsum/wiki).

tl;dr for this README, you can use the `.set_inference_params()` and `.get_inference_params()` methods to adjust the parameters for inference.
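
A hedged sketch of that workflow (the dict-style get/update/set shown here is an assumption; check the wiki for the exact signatures and supported keys):

```python
from textsum.summarize import Summarizer

summarizer = Summarizer()

# assumed to return the current generation settings as a dict
params = summarizer.get_inference_params()
print(params)

# fewer beams generally means less memory during decoding
params["num_beams"] = 2
summarizer.set_inference_params(params)
```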

---

@@ -160,10 +208,11 @@ See the [CONTRIBUTING.md](CONTRIBUTING.md) file for details on how to contribute
- [x] add CLI for summarization of all text files in a directory
- [x] python API for summarization of text docs
- [ ] add argparse CLI for UI demo
- [x] put on pypi
- [ ] optimum inference integration, LLM.int8 inference
- [x] put on PyPI
- [x] LLM.int8 inference
- [ ] optimum inference integration
- [ ] better documentation [in the wiki](https://github.com/pszemraj/textsum/wiki), details on improving performance (speed, quality, memory usage, etc.)
- [ ] improvements to OCR helper module
- [ ] improvements to the PDF OCR helper module

_Other ideas? Open an issue or PR!_

2 changes: 1 addition & 1 deletion setup.cfg
@@ -54,7 +54,7 @@ install_requires =
nltk
torch
tqdm
transformers
transformers>=4.26.0
accelerate

[options.packages.find]
42 changes: 30 additions & 12 deletions src/textsum/cli.py
@@ -1,10 +1,12 @@
"""
cli.py - a module containing functions for the command line interface (to run the summarization on a directory of files)
usage: textsum-dir [-h] [-o OUTPUT_DIR] [-m MODEL_NAME] [-batch BATCH_LENGTH] [-stride BATCH_STRIDE] [-nb NUM_BEAMS]
[-l2 LENGTH_PENALTY] [-r2 REPETITION_PENALTY] [--no_cuda] [-length_ratio MAX_LENGTH_RATIO] [-ml MIN_LENGTH]
[-enc_ngram ENCODER_NO_REPEAT_NGRAM_SIZE] [-dec_ngram NO_REPEAT_NGRAM_SIZE] [--no_early_stopping] [--shuffle]
[--lowercase] [-v] [-vv] [-lf LOGFILE]
usage: textsum-dir [-h] [-o OUTPUT_DIR] [-m MODEL_NAME] [--no_cuda] [--tf32] [-8bit]
[-batch BATCH_LENGTH] [-stride BATCH_STRIDE] [-nb NUM_BEAMS]
[-l2 LENGTH_PENALTY] [-r2 REPETITION_PENALTY]
[-length_ratio MAX_LENGTH_RATIO] [-ml MIN_LENGTH]
[-enc_ngram ENCODER_NO_REPEAT_NGRAM_SIZE] [-dec_ngram NO_REPEAT_NGRAM_SIZE]
[--no_early_stopping] [--shuffle] [--lowercase] [-v] [-vv] [-lf LOGFILE]
input_dir
Summarize text files in a directory
@@ -23,7 +25,7 @@
from tqdm.auto import tqdm

from textsum.summarize import Summarizer
from textsum.utils import setup_logging
from textsum.utils import enable_tf32, setup_logging


def get_parser():
@@ -52,6 +54,24 @@ def get_parser():
default="pszemraj/long-t5-tglobal-base-16384-book-summary",
help="the name of the model to use for summarization",
)
parser.add_argument(
"--no_cuda",
action="store_true",
help="flag to not use cuda if available",
)
parser.add_argument(
"--tf32",
action="store_true",
dest="tf32",
help="enable tf32 data type for computation (requires ampere series GPU or newer)",
)
parser.add_argument(
"-8bit",
"--load_in_8bit",
action="store_true",
dest="load_in_8bit",
help="flag to load the model in 8 bit precision (requires bitsandbytes)",
)
parser.add_argument(
"-batch",
"--batch_length",
@@ -88,11 +108,6 @@ def get_parser():
default=2.5,
help="the repetition penalty to use for beam search",
)
parser.add_argument(
"--no_cuda",
action="store_true",
help="flag to not use cuda if available",
)
parser.add_argument(
"-length_ratio",
"--max_length_ratio",
@@ -170,9 +185,8 @@ def get_parser():
help="the directory containing the input files",
)

# if there are no args, print the help
if len(sys.argv) == 1:
parser.print_help(sys.stderr)
parser.print_help(sys.stderr) # no args, print help
sys.exit(1)

return parser
@@ -200,12 +214,16 @@ def main(args):
"do_sample": False,
}

if args.tf32:
enable_tf32() # enable tf32 for computation

summarizer = Summarizer(
model_name_or_path=args.model_name,
use_cuda=not args.no_cuda,
token_batch_length=args.batch_length,
batch_stride=args.batch_stride,
max_length_ratio=args.max_length_ratio,
load_in_8bit=args.load_in_8bit,
**params,
)

29 changes: 25 additions & 4 deletions src/textsum/summarize.py
@@ -11,7 +11,11 @@
from tqdm.auto import tqdm
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

from textsum.utils import get_timestamp, postprocess_booksummary
from textsum.utils import (
check_bitsandbytes_available,
get_timestamp,
postprocess_booksummary,
)


class Summarizer:
@@ -27,6 +31,7 @@ def __init__(
token_batch_length: int = 2048,
batch_stride: int = 16,
max_length_ratio: float = 0.25,
load_in_8bit=False,
**kwargs,
):
"""
@@ -38,16 +43,32 @@
:param int token_batch_length: the amount of tokens to process in a batch, defaults to 2048
:param int batch_stride: the amount of tokens to stride the batch by, defaults to 16
:param float max_length_ratio: the ratio of the token_batch_length to use as the max_length for the model, defaults to 0.25
:param bool load_in_8bit: whether to load the model in 8bit precision (LLM.int8), defaults to False
:param kwargs: additional keyword arguments to pass to the model as inference parameters
"""
self.logger = logging.getLogger(__name__)

self.model_name_or_path = model_name_or_path
self.device = "cuda" if torch.cuda.is_available() and use_cuda else "cpu"
self.logger.debug(f"loading model {model_name_or_path} to {self.device}")
self.model = AutoModelForSeq2SeqLM.from_pretrained(
self.model_name_or_path,
).to(self.device)

if load_in_8bit:
logging.info("Loading model in 8-bit precision")

if not check_bitsandbytes_available():
raise ImportError(
"You must install bitsandbytes to load the model in 8-bit precision. Please run `pip install bitsandbytes` or `pip install textsum[8bit]`"
)
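# note: with load_in_8bit, accelerate's device_map="auto" handles weight
# placement, so no explicit .to(self.device) call is used for the 8-bit path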
self.model = AutoModelForSeq2SeqLM.from_pretrained(
model_name_or_path,
load_in_8bit=load_in_8bit,
device_map="auto",
)
else:
self.model = AutoModelForSeq2SeqLM.from_pretrained(
self.model_name_or_path,
).to(self.device)

self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path)
self.is_general_attention_model = (
is_general_attention_model # TODO: add a check later
23 changes: 23 additions & 0 deletions src/textsum/utils.py
@@ -9,6 +9,8 @@
from datetime import datetime
from pathlib import Path

import torch

logging.basicConfig(
level=logging.INFO,
format="%(asctime)s %(levelname)s %(message)s",
@@ -193,3 +195,24 @@ def postprocess_booksummary(text: str, custom_phrases: list = None) -> str:

text = text.replace(pr, "")
return text


def check_bitsandbytes_available():
"""
check_bitsandbytes_available - check if the bitsandbytes library is available
"""
try:
import bitsandbytes
except ImportError:
return False
return True


def enable_tf32():
"""
enable_tf32 - enables computation in tf32 precision. (requires ampere series GPU or newer)
See https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/ for details
"""
logging.debug("Enabling TF32 computation")
torch.backends.cuda.matmul.allow_tf32 = True
