Commit a2cde4b: Update text generation section of README. Fix line continuation in example scripts.

jaredcasper committed Oct 8, 2021
1 parent b31e129 · commit a2cde4b
Showing 4 changed files with 44 additions and 79 deletions.
30 changes: 10 additions & 20 deletions README.md
@@ -426,33 +426,23 @@ WORLD_SIZE=$TENSOR_MODEL_PARALLEL_SIZE python tools/merge_mp_partitions.py \
 Several downstream tasks are described for both GPT and BERT models below. They can be run in distributed and model parallel modes with the same changes used in the training scripts.

 ## GPT Text Generation
-`bash examples/generate_text.sh`
-
-We generate text samples using largely the GPT pretraining script. Few changes need to make, such as we need to provide the path to the pretrained checkpoint, the length of the output samples, whether to generate texts unconditionally (`--num-samples` to denote how many samples to generate) or conditional (need to pass `--sample-input-file <filename>` where each line of the file will be used as the conditional texts). There are few optional parameters to play, e.g. `top-k`, `top-p`, or `greedy` (set top-k and top-p to 0) sampling..
+We have included a simple REST server for text generation in `tools/run_text_generation_server.py`. You run it much as you would start a pretraining job, specifying an appropriate pretrained checkpoint. There are also a few optional sampling parameters: `temperature`, `top-k`, `top-p`, and `greedy`. See `--help` or the source file for more information, and see [examples/run_text_generation_server_345M.sh](examples/run_text_generation_server_345M.sh) for an example of how to run the server.
+
+Once the server is running, you can use `tools/text_generation_cli.py` to query it; it takes a single argument, the host the server is running on.

 <pre>
-CHECKPOINT_PATH=checkpoints/gpt2_345m
-VOCAB_FILE=gpt2-vocab.json
-MERGE_FILE=gpt2-merges.txt
-GPT_ARGS=&#60;same as those in <a href="#gpt-pretraining">GPT pretraining</a> above&#62;
-
-MAX_OUTPUT_SEQUENCE_LENGTH=1024
-TEMPERATURE=1.0
-TOP_P=0.9
-NUMBER_OF_SAMPLES=2
-OUTPUT_FILE=samples.json
-
-python tools/generate_samples_gpt.py \
-       $GPT_ARGS \
-       --load $CHECKPOINT_PATH \
-       --out-seq-length $MAX_OUTPUT_SEQUENCE_LENGTH \
-       --temperature $TEMPERATURE \
-       --genfile $OUTPUT_FILE \
-       --num-samples $NUMBER_OF_SAMPLES \
-       --top_p $TOP_P \
-       --recompute
+tools/text_generation_cli.py localhost
 </pre>

+You can also use curl or any other tool to query the server directly:
+
+<pre>
+curl 'http://localhost:5000/api' -X 'PUT' -H 'Content-Type: application/json; charset=UTF-8' -d '{"prompts":["Hello world"], "tokens_to_generate":1}'
+</pre>
+
+See [megatron/text_generation_server.py](megatron/text_generation_server.py) for more API options.
+
 ## GPT Evaluation
 We include example scripts for GPT evaluation on WikiText perplexity evaluation and LAMBADA Cloze accuracy.
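As an illustration of the API added above (this example is not part of the commit's diff; the prompts and token count are made up): the endpoint accepts a JSON body with a *list* of prompts, so a batched query with pretty-printed output might look like the following. `python3 -m json.tool` is used only for formatting.

<pre>
curl 'http://localhost:5000/api' -X 'PUT' \
  -H 'Content-Type: application/json; charset=UTF-8' \
  -d '{"prompts": ["Hello world", "The quick brown fox"], "tokens_to_generate": 32}' \
  | python3 -m json.tool
</pre>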
25 changes: 0 additions & 25 deletions examples/generate_text.sh

This file was deleted.

34 changes: 17 additions & 17 deletions examples/run_text_generation_server_345M.sh
@@ -12,21 +12,21 @@ MERGE_FILE=<Path to merges.txt (e.g. /gpt2-merges.txt)>

 pip install flask-restful

-python -m torch.distributed.launch $DISTRIBUTED_ARGS tools/run_text_generation_server.py /
-       --tensor-model-parallel-size 1 /
-       --pipeline-model-parallel-size 1 /
-       --num-layers 24 /
-       --hidden-size 1024 /
-       --load ${CHECKPOINT} /
-       --num-attention-heads 16 /
-       --max-position-embeddings 1024 /
-       --tokenizer-type GPT2BPETokenizer /
-       --fp16 /
-       --micro-batch-size 1 /
-       --seq-length 1024 /
-       --out-seq-length 1024 /
-       --temperature 1.0 /
-       --vocab-file $VOCAB_FILE /
-       --merge-file $MERGE_FILE /
-       --top_p 0.9 /
+python -m torch.distributed.run $DISTRIBUTED_ARGS tools/run_text_generation_server.py \
+       --tensor-model-parallel-size 1 \
+       --pipeline-model-parallel-size 1 \
+       --num-layers 24 \
+       --hidden-size 1024 \
+       --load ${CHECKPOINT} \
+       --num-attention-heads 16 \
+       --max-position-embeddings 1024 \
+       --tokenizer-type GPT2BPETokenizer \
+       --fp16 \
+       --micro-batch-size 1 \
+       --seq-length 1024 \
+       --out-seq-length 1024 \
+       --temperature 1.0 \
+       --vocab-file $VOCAB_FILE \
+       --merge-file $MERGE_FILE \
+       --top_p 0.9 \
        --seed 42
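The line-continuation fix these scripts receive: in a POSIX shell a trailing `\` escapes the newline so the command continues on the next line, while `/` is just an ordinary character, so the command ends at the newline, `/` is passed as a literal argument, and every subsequent `--flag` line is executed as a separate (failing) command. A minimal demonstration:

<pre>
echo a \
     b        # one command, continued across the line break: prints "a b"

echo a /      # the command ends at this newline: prints "a /"
echo b        # ...and this runs as a separate command: prints "b"
</pre>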
34 changes: 17 additions & 17 deletions examples/run_text_generation_server_345M_8_tensor_parallel.sh
@@ -12,21 +12,21 @@ MERGE_FILE=<Path to merges.txt (e.g. /gpt2-merges.txt)>

 pip install flask-restful

-python -m torch.distributed.launch $DISTRIBUTED_ARGS tools/run_text_generation_server.py /
-       --tensor-model-parallel-size 8 /
-       --pipeline-model-parallel-size 1 /
-       --num-layers 24 /
-       --hidden-size 1024 /
-       --load ${CHECKPOINT} /
-       --num-attention-heads 16 /
-       --max-position-embeddings 1024 /
-       --tokenizer-type GPT2BPETokenizer /
-       --fp16 /
-       --micro-batch-size 1 /
-       --seq-length 1024 /
-       --out-seq-length 1024 /
-       --temperature 1.0 /
-       --vocab-file $VOCAB_FILE /
-       --merge-file $MERGE_FILE /
-       --top_p 0.9 /
+python -m torch.distributed.launch $DISTRIBUTED_ARGS tools/run_text_generation_server.py \
+       --tensor-model-parallel-size 8 \
+       --pipeline-model-parallel-size 1 \
+       --num-layers 24 \
+       --hidden-size 1024 \
+       --load ${CHECKPOINT} \
+       --num-attention-heads 16 \
+       --max-position-embeddings 1024 \
+       --tokenizer-type GPT2BPETokenizer \
+       --fp16 \
+       --micro-batch-size 1 \
+       --seq-length 1024 \
+       --out-seq-length 1024 \
+       --temperature 1.0 \
+       --vocab-file $VOCAB_FILE \
+       --merge-file $MERGE_FILE \
+       --top_p 0.9 \
        --seed 42
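Both hunks begin below the point where `DISTRIBUTED_ARGS` is defined, so its value is not shown here. For the 8-way tensor-parallel script, the launcher must spawn one process per GPU; a hypothetical single-node definition (variable names and values assumed for illustration, not taken from the file) could look like:

<pre>
GPUS_PER_NODE=8   # world size must equal 8 (tensor parallel 8 x pipeline parallel 1)
MASTER_ADDR=localhost
MASTER_PORT=6000
DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes 1 --node_rank 0 \
                  --master_addr $MASTER_ADDR --master_port $MASTER_PORT"
</pre>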
