Skip to content

Commit

Permalink
Samplers: Add dynamic temperature
Browse files Browse the repository at this point in the history
Dynamic temperature does not work if max_temp is less than or equal to
min_temp. Sampler validation will have to be refactored in the future, so
the dynamic temperature check will change along with it.

Signed-off-by: kingbri <[email protected]>
  • Loading branch information
bdashore3 committed Jan 31, 2024
1 parent 3605067 commit 4a7b8b1
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 0 deletions.
25 changes: 25 additions & 0 deletions backends/exllamav2/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,6 +515,14 @@ def check_unsupported_settings(self, **kwargs):
"installed ExLlamaV2 version."
)

if (unwrap(kwargs.get("max_temp"), 0.0)) > 0.0 and not hasattr(
ExLlamaV2Sampler.Settings, "max_temp"
):
logger.warning(
"DynaTemp parameters are not supported by the currently "
"installed ExLlamaV2 version."
)

def generate(self, prompt: str, **kwargs):
"""Generate a response to a prompt"""
generation = list(self.generate_gen(prompt, **kwargs))
Expand Down Expand Up @@ -579,6 +587,7 @@ def generate_gen(self, prompt: str, **kwargs):
# Sampler settings
gen_settings = ExLlamaV2Sampler.Settings()

# TODO: Migrate settings validation to a different function
self.check_unsupported_settings(**kwargs)

# Apply settings
Expand All @@ -592,6 +601,22 @@ def generate_gen(self, prompt: str, **kwargs):
gen_settings.typical = unwrap(kwargs.get("typical"), 1.0)
gen_settings.mirostat = unwrap(kwargs.get("mirostat"), False)

# DynaTemp settings
if hasattr(gen_settings, "max_temp"):
max_temp = unwrap(kwargs.get("max_temp"), 0.0)
min_temp = unwrap(kwargs.get("min_temp"), 0.0)

if max_temp < min_temp or (
0 not in {min_temp, max_temp} and max_temp == min_temp
):
logger.warning(
"Max temp is less than or equal to min temp, skipping DynaTemp."
)

gen_settings.max_temp = max_temp
gen_settings.min_temp = min_temp
gen_settings.temp_exponent = kwargs.get("temp_exponent")

# Default tau and eta fallbacks don't matter if mirostat is off
gen_settings.mirostat_tau = unwrap(kwargs.get("mirostat_tau"), 1.5)
gen_settings.mirostat_eta = unwrap(kwargs.get("mirostat_eta"), 0.1)
Expand Down
16 changes: 16 additions & 0 deletions common/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,19 @@ class SamplerParams(BaseModel):
default_factory=lambda: get_default_sampler_value("temperature_last", False)
)

max_temp: Optional[float] = Field(
default_factory=lambda: get_default_sampler_value("max_temp", 0.0),
)

min_temp: Optional[float] = Field(
default_factory=lambda: get_default_sampler_value("min_temp", 0.0),
)

temp_exponent: Optional[float] = Field(
default_factory=lambda: get_default_sampler_value("temp_exponent", 1.0),
examples=[1.0],
)

top_k: Optional[int] = Field(
default_factory=lambda: get_default_sampler_value("top_k", 0)
)
Expand Down Expand Up @@ -157,6 +170,9 @@ def to_gen_params(self):
"logit_bias": self.logit_bias,
"temperature": self.temperature,
"temperature_last": self.temperature_last,
"min_temp": self.min_temp,
"max_temp": self.max_temp,
"temp_exponent": self.temp_exponent,
"top_k": self.top_k,
"top_p": self.top_p,
"top_a": self.top_a,
Expand Down
9 changes: 9 additions & 0 deletions sampler_overrides/sample_preset.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,15 @@ temperature:
temperature_last:
override: false
force: false
min_temp:
override: 0.0
force: false
max_temp:
override: 0.0
force: false
temp_exponent:
override: 0.0
force: false

# MARK: Alphabet soup
top_k:
Expand Down

0 comments on commit 4a7b8b1

Please sign in to comment.