[TextGeneration] Fix llama tokenizer (#1635) (#1636)
* [TextGeneration] Fix llama tokenizer (#1635)

* add llama tokenizer fix

* fix generated string

* only run for streaming

* add TODO

---------

Co-authored-by: Dipika Sikka <[email protected]>

* Retire `flaky` in favour of `pytest-rerunfailures` (#1628)

* pick up another fix and bump up version to 1.7.1

---------

Co-authored-by: Dipika Sikka <[email protected]>
Co-authored-by: Dipika Sikka <[email protected]>
Co-authored-by: dbogunowicz <[email protected]>
Co-authored-by: dhuang <[email protected]>
5 people authored Mar 18, 2024
1 parent 5fc5f73 commit 639c9f7
Showing 6 changed files with 53 additions and 11 deletions.
2 changes: 1 addition & 1 deletion setup.py
@@ -99,7 +99,7 @@ def _parse_requirements_file(file_path):
"black==22.12.0",
"flake8>=3.8.3",
"isort>=5.7.0",
"flaky~=3.7.0",
"pytest-rerunfailures>=13.0",
"ndjson>=0.3.1",
"wheel>=0.36.2",
"pytest>=6.0.0",
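As background (this example is not part of the commit): pytest-rerunfailures exposes reruns through a standard pytest marker instead of flaky's decorator, so flaky tests only need the plugin installed plus a marker. A minimal, hypothetical sketch of the pattern the tests below switch to:

```python
import pytest


# With pytest-rerunfailures installed, `pytest.mark.flaky` is provided by the
# plugin; no `flaky` package import is needed. A failing run is retried up to
# `reruns` extra times before the test is reported as failed.
@pytest.mark.flaky(reruns=2)
def test_sometimes_flaky_behavior():  # hypothetical test name
    assert True
```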
@@ -101,6 +101,7 @@ def run(
else [],
"finished_reason": [],
"token_generator": token_generator,
"past_tokens_queue": copy.copy(tokens),
}

if kv_cache is None:
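The queue is seeded with a copy of the prompt tokens because the decoding step later appends and pops on it for every streamed token; copying keeps the caller's token list untouched. A small illustrative sketch (token values are made up):

```python
import copy

tokens = [1, 15043, 29892]  # hypothetical prompt token ids
state = {"past_tokens_queue": copy.copy(tokens)}

# Streaming decode mutates the queue in place ...
state["past_tokens_queue"].append(3186)
state["past_tokens_queue"].pop(0)

# ... while the original prompt tokens stay intact.
assert tokens == [1, 15043, 29892]
```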
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime
-from typing import Optional
+from typing import List, Optional

import numpy

@@ -54,6 +54,33 @@ def _create_generated_text_output(
finished=False,
)

+def _generate_streamed_text_from_past_tokens(
+    self, generated_tokens: numpy.ndarray, past_tokens_queue: List[int]
+) -> List[str]:
+    """
+    Auxiliary method that helps to properly generate the streamed text.
+    Some models, such as Llama 2 and Mistral, use LlamaTokenizer, which is
+    based on the SentencePiece tokenizer. This tokenizer does not emit the
+    appropriate prefix spaces when decoding token by token. It can be made
+    to work if the previously generated tokens are included, letting it
+    infer the appropriate spaces from the last n consecutive tokens.
+
+    :param generated_tokens: the generated tokens from the engine
+    :param past_tokens_queue: the queue of the last n tokens (n is the
+        original prompt length in tokens)
+    :return: a list containing the generated string
+    """
+    string_from_n_tokens = self.tokenizer.decode(
+        past_tokens_queue, skip_special_tokens=True
+    )
+    past_tokens_queue.append(generated_tokens[0])
+    string_from_n_plus_1_tokens = self.tokenizer.decode(
+        past_tokens_queue, skip_special_tokens=True
+    )
+    past_tokens_queue.pop(0)
+    return [string_from_n_plus_1_tokens[len(string_from_n_tokens) :]]

def run(
self,
generated_tokens: numpy.ndarray,
@@ -64,9 +91,24 @@
):
generation_config = inference_state.current_state.get("generation_config")
generated_logits = generated_logits if generation_config.output_scores else None
-sequences = self.tokenizer.batch_decode(
-    generated_tokens, skip_special_tokens=True
-)

+import transformers
+
+# Fix for Llama-specific models when running streaming.
+# TODO: make streaming a conditional input to this operator; using inference
+# state is a quick fix.
+if isinstance(
+    self.tokenizer,
+    (transformers.LlamaTokenizer, transformers.LlamaTokenizerFast),
+) and inference_state.current_state.get("streaming"):
+    past_tokens_queue = inference_state.current_state.get("past_tokens_queue")
+    sequences = self._generate_streamed_text_from_past_tokens(
+        generated_tokens, past_tokens_queue
+    )
+else:
+    sequences = self.tokenizer.batch_decode(
+        generated_tokens, skip_special_tokens=True
+    )

try:
finished_reason = [f[-1] for f in finished_reason]
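To make the decoding strategy concrete, here is a small self-contained sketch (not part of the diff above) that contrasts naive per-token decoding with the past-tokens approach the new helper uses. The tokenizer checkpoint is an illustrative assumption; any SentencePiece-based Llama-style tokenizer shows the same behavior:

```python
from transformers import AutoTokenizer

# Example checkpoint (assumption); ships a Llama/SentencePiece tokenizer.
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")

ids = tokenizer("Hello brave new world", add_special_tokens=False).input_ids
prompt_ids, generated_ids = ids[:-2], ids[-2:]  # treat the last two tokens as generated

# Naive streaming: decode each generated token on its own. The SentencePiece
# decoder drops the leading prefix space, so pieces tend to run together.
naive = "".join(tokenizer.decode([t], skip_special_tokens=True) for t in generated_ids)

# Approach used by the commit: decode the last n tokens with and without the
# new token and keep only the difference, which preserves the spacing.
queue = list(prompt_ids)
streamed = ""
for t in generated_ids:
    before = tokenizer.decode(queue, skip_special_tokens=True)
    queue.append(t)
    after = tokenizer.decode(queue, skip_special_tokens=True)
    queue.pop(0)  # keep the queue at the original prompt length
    streamed += after[len(before):]

print(repr(naive))     # spacing is typically lost here
print(repr(streamed))  # matches the spacing of a full batch decode
```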
2 changes: 1 addition & 1 deletion src/deepsparse/version.py
@@ -39,7 +39,7 @@
from deepsparse.generated_version import is_enterprise, is_release, splash, version
except Exception:
# otherwise, fall back to version info in this file
version = "1.7.0"
version = "1.7.1"
is_release = False
is_enterprise = False
splash = (
3 changes: 1 addition & 2 deletions tests/deepsparse/pipelines/test_pipeline.py
@@ -16,7 +16,6 @@
from concurrent.futures import ThreadPoolExecutor
from unittest import mock

-import flaky
import pytest
from deepsparse.legacy.base_pipeline import BasePipeline

@@ -125,7 +124,7 @@ def test_pipeline_executor_num_workers():
assert executor._max_workers >= 1


-@flaky.flaky(max_runs=2, min_passes=1)
+@pytest.mark.flaky(reruns=2, min_passes=1)
@mock_engine(rng_seed=0)
def test_pipeline_call_is_async(engine_mock):
# attempts to verify that pipeline calls to engine are async
6 changes: 3 additions & 3 deletions tests/server/test_legacy_loggers.py
@@ -16,6 +16,7 @@
from collections import Counter
from unittest import mock

+import pytest
from deepsparse.legacy.loggers import PythonLogger
from deepsparse.legacy.loggers.config import (
PipelineSystemLoggingConfig,
@@ -30,7 +31,6 @@
from deepsparse.server.deepsparse_server import DeepsparseServer
from deepsparse.server.helpers import server_logger_from_config
from fastapi.testclient import TestClient
-from flaky import flaky
from tests.deepsparse.legacy.loggers.helpers import fetch_leaf_logger
from tests.helpers import find_free_port
from tests.test_data.server_test_data import SAMPLE_LOGS_DICT
@@ -106,7 +106,7 @@ def test_data_logging_from_predefined():
assert log == expected_log


-@flaky(max_runs=4, min_passes=3)
+@pytest.mark.flaky(reruns=4, min_passes=3)
def test_logging_only_system_info():
server_config = ServerConfig(
endpoints=[EndpointConfig(task=task, name=name, model=stub)],
@@ -195,7 +195,7 @@ def test_multiple_targets_logging():
)


-@flaky(max_runs=3, min_passes=2)
+@pytest.mark.flaky(reruns=3, min_passes=2)
def test_function_metric_with_target_loggers():
server_config = ServerConfig(
endpoints=[
