Skip to content

Commit

Permalink
Add dynamic_temperature_low parameter (#5198)
Browse files Browse the repository at this point in the history
  • Loading branch information
oobabooga authored Jan 7, 2024
1 parent b8a0b3f commit 0d07b3a
Show file tree
Hide file tree
Showing 11 changed files with 30 additions and 88 deletions.
3 changes: 2 additions & 1 deletion docs/03 - Parameters Tab.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ For more information about the parameters, the [transformers documentation](http
* **mirostat_mode**: Activates the Mirostat sampling technique. It aims to control perplexity during sampling. See the [paper](https://arxiv.org/abs/2007.14966).
* **mirostat_tau**: No idea, see the paper for details. According to the Preset Arena, 8 is a good value.
* **mirostat_eta**: No idea, see the paper for details. According to the Preset Arena, 0.1 is a good value.
* **dynatemp**: Dynamic Temperature is activated when this parameter is greater than 0. The temperature range is determined by adding and subtracting dynatemp from the current temperature.
* **dynamic_temperature_low**: The lower bound for temperature in Dynamic Temperature. Only used when "dynamic_temperature" is checked.
* **dynamic_temperature**: Activates Dynamic Temperature. This modifies temperature to range between "dynamic_temperature_low" (minimum) and "temperature" (maximum), with an entropy-based scaling.
* **temperature_last**: Makes temperature the last sampler instead of the first. With this, you can remove low probability tokens with a sampler like min_p and then use a high temperature to make the model creative without losing coherency.
* **do_sample**: When unchecked, sampling is entirely disabled, and greedy decoding is used instead (the most likely token is always picked).
* **Seed**: Set the Pytorch seed to this number. Note that some loaders do not use Pytorch (notably llama.cpp), and others are not deterministic (notably ExLlama v1 and v2). For these loaders, the seed has no effect.
Expand Down
17 changes: 0 additions & 17 deletions extensions/dynatemp_with_range/README.md

This file was deleted.

51 changes: 0 additions & 51 deletions extensions/dynatemp_with_range/script.py

This file was deleted.

3 changes: 2 additions & 1 deletion extensions/openai/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
class GenerationOptions(BaseModel):
preset: str | None = Field(default=None, description="The name of a file under text-generation-webui/presets (without the .yaml extension). The sampling parameters that get overwritten by this option are the keys in the default_preset() function in modules/presets.py.")
min_p: float = 0
dynatemp: float = 0
dynamic_temperature: bool = False
dynamic_temperature_low: float = 0.1
top_k: int = 0
repetition_penalty: float = 1
repetition_penalty_range: int = 1024
Expand Down
9 changes: 6 additions & 3 deletions modules/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,8 @@ def transformers_samplers():
return {
'temperature',
'temperature_last',
'dynatemp',
'dynamic_temperature',
'dynamic_temperature_low',
'top_p',
'min_p',
'top_k',
Expand Down Expand Up @@ -221,7 +222,8 @@ def transformers_samplers():
'ExLlamav2_HF': {
'temperature',
'temperature_last',
'dynatemp',
'dynamic_temperature',
'dynamic_temperature_low',
'top_p',
'min_p',
'top_k',
Expand Down Expand Up @@ -274,7 +276,8 @@ def transformers_samplers():
'llamacpp_HF': {
'temperature',
'temperature_last',
'dynatemp',
'dynamic_temperature',
'dynamic_temperature_low',
'top_p',
'min_p',
'top_k',
Expand Down
4 changes: 2 additions & 2 deletions modules/presets.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ def default_preset():
return {
'temperature': 1,
'temperature_last': False,
'dynatemp': 0,
'dynamic_temperature': False,
'dynamic_temperature_low': 0.1,
'top_p': 1,
'min_p': 0,
'top_k': 0,
Expand Down Expand Up @@ -53,7 +54,6 @@ def load_preset(name):
for k in preset:
generate_params[k] = preset[k]

generate_params['temperature'] = min(1.99, generate_params['temperature'])
return generate_params


Expand Down
18 changes: 10 additions & 8 deletions modules/sampler_hijack.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@


class TemperatureLogitsWarperWithDynatemp(LogitsWarper):
def __init__(self, temperature: float, dynatemp: float):
def __init__(self, temperature: float, dynamic_temperature: bool, dynamic_temperature_low: float):
if not isinstance(temperature, float) or not (temperature > 0):
except_msg = (
f"`temperature` (={temperature}) has to be a strictly positive float, otherwise your next token "
Expand All @@ -28,19 +28,20 @@ def __init__(self, temperature: float, dynatemp: float):
raise ValueError(except_msg)

self.temperature = temperature
self.dynatemp = dynatemp
self.dynamic_temperature = dynamic_temperature
self.dynamic_temperature_low = dynamic_temperature_low

def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:

# Regular temperature
if self.dynatemp == 0:
if not self.dynamic_temperature:
scores = scores / self.temperature
return scores

# Dynamic temperature
else:
min_temp = max(0.0, self.temperature - self.dynatemp)
max_temp = self.temperature + self.dynatemp
min_temp = self.dynamic_temperature_low
max_temp = self.temperature
exponent_val = 1.0

# Convert logits to probabilities
Expand Down Expand Up @@ -283,15 +284,15 @@ def get_logits_warper_patch(self, generation_config):
generation_config.temperature = float(generation_config.temperature)

temperature = generation_config.temperature
if generation_config.dynatemp > 0:
if generation_config.dynamic_temperature:
# Make sure TemperatureLogitsWarper will be created by temporarily
# setting temperature to a value != 1.
generation_config.temperature = 1.1

warpers = self._get_logits_warper_old(generation_config)
for i in range(len(warpers)):
if warpers[i].__class__.__name__ == 'TemperatureLogitsWarper':
warpers[i] = TemperatureLogitsWarperWithDynatemp(temperature, generation_config.dynatemp)
warpers[i] = TemperatureLogitsWarperWithDynatemp(temperature, generation_config.dynamic_temperature, generation_config.dynamic_temperature_low)

warpers_to_add = LogitsProcessorList()
min_tokens_to_keep = 2 if generation_config.num_beams > 1 else 1
Expand Down Expand Up @@ -359,7 +360,8 @@ def get_logits_processor_patch(self, **kwargs):
def generation_config_init_patch(self, **kwargs):
self.__init___old(**kwargs)
self.min_p = kwargs.pop("min_p", 0.0)
self.dynatemp = kwargs.pop("dynatemp", 0.0)
self.dynamic_temperature = kwargs.pop("dynamic_temperature", False)
self.dynamic_temperature_low = kwargs.pop("dynamic_temperature_low", 0.1)
self.tfs = kwargs.pop("tfs", 1.0)
self.top_a = kwargs.pop("top_a", 0.0)
self.mirostat_mode = kwargs.pop("mirostat_mode", 0)
Expand Down
2 changes: 1 addition & 1 deletion modules/text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ def get_reply_from_output_ids(output_ids, state, starting_from=0):

def generate_reply_HF(question, original_question, seed, state, stopping_strings=None, is_chat=False):
generate_params = {}
for k in ['max_new_tokens', 'temperature', 'temperature_last', 'dynatemp', 'top_p', 'min_p', 'top_k', 'repetition_penalty', 'presence_penalty', 'frequency_penalty', 'repetition_penalty_range', 'typical_p', 'tfs', 'top_a', 'guidance_scale', 'penalty_alpha', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'do_sample', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'min_length', 'num_beams', 'length_penalty', 'early_stopping']:
for k in ['max_new_tokens', 'temperature', 'temperature_last', 'dynamic_temperature', 'dynamic_temperature_low', 'top_p', 'min_p', 'top_k', 'repetition_penalty', 'presence_penalty', 'frequency_penalty', 'repetition_penalty_range', 'typical_p', 'tfs', 'top_a', 'guidance_scale', 'penalty_alpha', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'do_sample', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'min_length', 'num_beams', 'length_penalty', 'early_stopping']:
generate_params[k] = state[k]

if state['negative_prompt'] != '':
Expand Down
3 changes: 2 additions & 1 deletion modules/ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,8 @@ def list_interface_input_elements():
'seed',
'temperature',
'temperature_last',
'dynatemp',
'dynamic_temperature',
'dynamic_temperature_low',
'top_p',
'min_p',
'top_k',
Expand Down
3 changes: 2 additions & 1 deletion modules/ui_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ def create_ui(default_preset):
shared.gradio['mirostat_mode'] = gr.Slider(0, 2, step=1, value=generate_params['mirostat_mode'], label='mirostat_mode', info='mode=1 is for llama.cpp only.')
shared.gradio['mirostat_tau'] = gr.Slider(0, 10, step=0.01, value=generate_params['mirostat_tau'], label='mirostat_tau')
shared.gradio['mirostat_eta'] = gr.Slider(0, 1, step=0.01, value=generate_params['mirostat_eta'], label='mirostat_eta')
shared.gradio['dynatemp'] = gr.Slider(0, 5, value=generate_params['dynatemp'], step=0.01, label='dynatemp')
shared.gradio['dynamic_temperature_low'] = gr.Slider(0.01, 5, value=generate_params['dynamic_temperature_low'], step=0.01, label='dynamic_temperature_low', info='Only used when dynamic_temperature is checked.')
shared.gradio['dynamic_temperature'] = gr.Checkbox(value=generate_params['dynamic_temperature'], label='dynamic_temperature')
shared.gradio['temperature_last'] = gr.Checkbox(value=generate_params['temperature_last'], label='temperature_last', info='Makes temperature the last sampler instead of the first.')
shared.gradio['do_sample'] = gr.Checkbox(value=generate_params['do_sample'], label='do_sample')
shared.gradio['seed'] = gr.Number(value=shared.settings['seed'], label='Seed (-1 for random)')
Expand Down
5 changes: 3 additions & 2 deletions presets/Dynamic Temperature.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
temperature: 1.55
dynamic_temperature: true
dynamic_temperature_low: 0.1
temperature: 3
temperature_last: true
dynatemp: 1.45
min_p: 0.05

0 comments on commit 0d07b3a

Please sign in to comment.