diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py
index ac939d44..52764e22 100644
--- a/backends/exllamav2/model.py
+++ b/backends/exllamav2/model.py
@@ -138,13 +138,25 @@ def progress(loaded_modules: int, total_modules: int,
             kwargs.get("rope_alpha"), self.calculate_rope_alpha(base_seq_len)
         )
 
+        # Enable CFG if present
+        use_cfg = unwrap(kwargs.get("use_cfg"), False)
         if hasattr(ExLlamaV2Sampler.Settings, "cfg_scale"):
-            self.use_cfg = unwrap(kwargs.get("use_cfg"), False)
-        else:
+            self.use_cfg = use_cfg
+        elif use_cfg:
             logger.warning(
                 "CFG is not supported by the currently installed ExLlamaV2 version."
             )
 
+        # Enable fasttensors loading if present
+        use_fasttensors = unwrap(kwargs.get("fasttensors"), False)
+        if hasattr(ExLlamaV2Config, "fasttensors"):
+            self.config.fasttensors = use_fasttensors
+        elif use_fasttensors:
+            logger.warning(
+                "fasttensors is not supported by "
+                "the currently installed ExllamaV2 version."
+            )
+
         # Turn off flash attention if CFG is on
         # Workaround until batched FA2 is fixed in exllamav2 upstream
         self.config.no_flash_attn = (
@@ -668,6 +680,7 @@ def generate_gen(self, prompt: str, **kwargs):
             **vars(gen_settings),
             token_healing=token_healing,
             auto_scale_penalty_range=auto_scale_penalty_range,
+            generate_window=generate_window,
             add_bos_token=add_bos_token,
             ban_eos_token=ban_eos_token,
             stop_conditions=stop_conditions,
diff --git a/common/sampling.py b/common/sampling.py
index 53defcc1..8c28002d 100644
--- a/common/sampling.py
+++ b/common/sampling.py
@@ -17,7 +17,13 @@ class SamplerParams(BaseModel):
     """Common class for sampler params that are used in APIs"""
 
     max_tokens: Optional[int] = Field(
-        default_factory=lambda: get_default_sampler_value("max_tokens", 150)
+        default_factory=lambda: get_default_sampler_value("max_tokens", 150),
+        examples=[150],
+    )
+
+    generate_window: Optional[int] = Field(
+        default_factory=lambda: get_default_sampler_value("generate_window"),
+        examples=[512],
     )
 
     stop: Optional[Union[str, List[str]]] = Field(
@@ -29,7 +35,8 @@ class SamplerParams(BaseModel):
     )
 
     temperature: Optional[float] = Field(
-        default_factory=lambda: get_default_sampler_value("temperature", 1.0)
+        default_factory=lambda: get_default_sampler_value("temperature", 1.0),
+        examples=[1.0],
     )
 
     temperature_last: Optional[bool] = Field(
@@ -41,7 +48,7 @@ class SamplerParams(BaseModel):
     )
 
     top_p: Optional[float] = Field(
-        default_factory=lambda: get_default_sampler_value("top_p", 1.0)
+        default_factory=lambda: get_default_sampler_value("top_p", 1.0), examples=[1.0]
     )
 
     top_a: Optional[float] = Field(
@@ -65,7 +72,8 @@ class SamplerParams(BaseModel):
     )
 
     repetition_penalty: Optional[float] = Field(
-        default_factory=lambda: get_default_sampler_value("repetition_penalty", 1.0)
+        default_factory=lambda: get_default_sampler_value("repetition_penalty", 1.0),
+        examples=[1.0],
     )
 
     repetition_decay: Optional[int] = Field(
@@ -77,11 +85,13 @@ class SamplerParams(BaseModel):
     )
 
     mirostat_tau: Optional[float] = Field(
-        default_factory=lambda: get_default_sampler_value("mirostat_tau", 1.5)
+        default_factory=lambda: get_default_sampler_value("mirostat_tau", 1.5),
+        examples=[1.5],
     )
 
     mirostat_eta: Optional[float] = Field(
-        default_factory=lambda: get_default_sampler_value("mirostat_eta", 0.3)
+        default_factory=lambda: get_default_sampler_value("mirostat_eta", 0.3),
+        examples=[0.3],
     )
 
     add_bos_token: Optional[bool] = Field(
@@ -89,7 +99,8 @@ class SamplerParams(BaseModel):
     )
 
     ban_eos_token: Optional[bool] = Field(
-        default_factory=lambda: get_default_sampler_value("ban_eos_token", False)
+        default_factory=lambda: get_default_sampler_value("ban_eos_token", False),
+        examples=[False],
     )
 
     logit_bias: Optional[Dict[int, float]] = Field(
@@ -106,6 +117,7 @@
         default_factory=lambda: get_default_sampler_value("typical", 1.0),
         validation_alias=AliasChoices("typical", "typical_p"),
         description="Aliases: typical_p",
+        examples=[1.0],
     )
 
     penalty_range: Optional[int] = Field(
@@ -122,6 +134,7 @@
         default_factory=lambda: get_default_sampler_value("cfg_scale", 1.0),
         validation_alias=AliasChoices("cfg_scale", "guidance_scale"),
         description="Aliases: guidance_scale",
+        examples=[1.0],
     )
 
     def to_gen_params(self):
@@ -135,8 +148,9 @@ def to_gen_params(self):
             self.stop = [self.stop]
 
         return {
-            "stop": self.stop,
             "max_tokens": self.max_tokens,
+            "generate_window": self.generate_window,
+            "stop": self.stop,
             "add_bos_token": self.add_bos_token,
             "ban_eos_token": self.ban_eos_token,
             "token_healing": self.token_healing,
diff --git a/config_sample.yml b/config_sample.yml
index 89368acf..cf1ddb53 100644
--- a/config_sample.yml
+++ b/config_sample.yml
@@ -97,6 +97,9 @@ model:
   # WARNING: This flag disables Flash Attention! (a stopgap fix until it's fixed in upstream)
   #use_cfg: False
 
+  # Enables fasttensors to possibly increase model loading speeds (default: False)
+  #fasttensors: true
+
   # Options for draft models (speculative decoding). This will use more VRAM!
   #draft:
     # Overrides the directory to look for draft (default: models)
diff --git a/sampler_overrides/sample_preset.yml b/sampler_overrides/sample_preset.yml
index 9c661a14..eae17ab4 100644
--- a/sampler_overrides/sample_preset.yml
+++ b/sampler_overrides/sample_preset.yml
@@ -18,6 +18,11 @@ token_healing:
     override: false
     force: false
 
+# Commented out because the default is dynamically scaled
+#generate_window:
+    #override: 512
+    #force: false
+
 # MARK: Temperature
 temperature:
     override: 1.0
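
Usage note (not part of the patch): a minimal client-side sketch of exercising the new generate_window sampler parameter added above. The /v1/completions path, host, and port are assumptions about how the server is exposed, not something this diff confirms; the payload keys come from the SamplerParams fields shown in the hunks.

import requests  # assumes the requests package is available

# Hypothetical request body -- generate_window is the parameter this diff
# threads through SamplerParams.to_gen_params() and generate_gen().
payload = {
    "prompt": "Once upon a time",
    "max_tokens": 150,
    "generate_window": 512,
    "temperature": 1.0,
}

# Endpoint URL is an assumption; adjust host/port to the running instance.
response = requests.post("http://localhost:5000/v1/completions", json=payload)
print(response.json())

If generate_window is omitted, the server falls back to the default from get_default_sampler_value, which the sample preset notes is dynamically scaled.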