Skip to content

Commit

Permalink
Fix problem with zeros at assistant options (#26)
Browse files Browse the repository at this point in the history
  • Loading branch information
vhaldemar authored Nov 8, 2024
1 parent 628e166 commit 1e97ef1
Show file tree
Hide file tree
Showing 6 changed files with 693 additions and 34 deletions.
32 changes: 15 additions & 17 deletions src/yandex_cloud_ml_sdk/_assistants/assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,11 @@ def _kwargs_from_message(cls, proto: ProtoAssistant, sdk: BaseSDK) -> dict[str,
kwargs = super()._kwargs_from_message(proto, sdk=sdk)

model = sdk.models.completions(proto.model_uri)
if max_tokens := proto.completion_options.max_tokens.value:
model = model.configure(max_tokens=max_tokens)
if temperature := proto.completion_options.temperature.value:
model = model.configure(temperature=temperature)
completion_options = proto.completion_options
if completion_options.HasField('max_tokens'):
model = model.configure(max_tokens=completion_options.max_tokens.value)
if completion_options.HasField('temperature'):
model = model.configure(temperature=completion_options.temperature.value)
kwargs['model'] = model

kwargs['tools'] = tuple(
Expand Down Expand Up @@ -80,23 +81,20 @@ async def _update(
expiration_policy=expiration_policy
)

model_uri: str | None = None
model_temperature: float | None = self.model.config.temperature
model_max_tokens: int | None = self.model.config.max_tokens
model_uri: UndefinedOr[str] | None = UNDEFINED

if is_defined(model):
if isinstance(model, str):
model_uri = self._sdk.models.completions(model).uri
elif isinstance(model, BaseGPTModel):
model_uri = model.uri
model_temperature = model.config.temperature
model_max_tokens = model.config.max_tokens
if not is_defined(temperature) and model.config.temperature is not None:
temperature = model.config.temperature
if not is_defined(max_tokens) and model.config.max_tokens is not None:
max_tokens = model.config.max_tokens
else:
raise TypeError('model argument must be str, GPTModel object either undefined')

model_temperature = get_defined_value(temperature, model_temperature)
model_max_tokens = get_defined_value(max_tokens, model_max_tokens)

request = UpdateAssistantRequest(
assistant_id=self.id,
name=get_defined_value(name, ''),
Expand All @@ -108,11 +106,11 @@ async def _update(
max_prompt_tokens=get_defined_value(max_prompt_tokens, None)
),
completion_options=get_completion_options(
temperature=model_temperature,
max_tokens=model_max_tokens
temperature=temperature,
max_tokens=max_tokens,
)
)
if model_uri:
if model_uri and is_defined(model_uri):
request.model_uri = model_uri

self._fill_update_mask(
Expand All @@ -125,8 +123,8 @@ async def _update(
'expiration_config.expiration_policy': expiration_policy,
'instruction': instruction,
'model_uri': model_uri,
'completion_options.temperature': model_temperature,
'completion_options.max_tokens': model_max_tokens,
'completion_options.temperature': temperature,
'completion_options.max_tokens': max_tokens,
'prompt_truncation_options.max_prompt_tokens': max_prompt_tokens,
}
)
Expand Down
15 changes: 6 additions & 9 deletions src/yandex_cloud_ml_sdk/_assistants/domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,20 +48,17 @@ async def _create(
expiration_config = ExpirationConfig.coerce(ttl_days=ttl_days, expiration_policy=expiration_policy)

model_uri: str = ''
model_temperature: float | None = None
model_max_tokens: int | None = None
if isinstance(model, str):
model_uri = self._sdk.models.completions(model).uri
elif isinstance(model, BaseGPTModel):
model_uri = model.uri
model_temperature = model.config.temperature
model_max_tokens = model.config.max_tokens
if not is_defined(temperature) and model.config.temperature is not None:
temperature = model.config.temperature
if not is_defined(max_tokens) and model.config.max_tokens is not None:
max_tokens = model.config.max_tokens
else:
raise TypeError('model argument must be str, GPTModel object either undefined')

model_temperature = get_defined_value(temperature, model_temperature)
model_max_tokens = get_defined_value(max_tokens, model_max_tokens)

tools_: tuple[BaseTool, ...] = ()
if is_defined(tools):
# NB: mypy doesn't love abstract class used as TypeVar substitution here
Expand All @@ -79,8 +76,8 @@ async def _create(
),
model_uri=model_uri,
completion_options=get_completion_options(
temperature=model_temperature,
max_tokens=model_max_tokens
temperature=temperature,
max_tokens=max_tokens
),
tools=[tool._to_proto() for tool in tools_]
)
Expand Down
14 changes: 8 additions & 6 deletions src/yandex_cloud_ml_sdk/_assistants/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,27 +3,29 @@

from yandex.cloud.ai.assistants.v1.common_pb2 import CompletionOptions, PromptTruncationOptions

from yandex_cloud_ml_sdk._types.misc import UndefinedOr, is_defined


def get_completion_options(
*,
temperature: float | None,
max_tokens: int | None,
temperature: UndefinedOr[float] | None,
max_tokens: UndefinedOr[int] | None,
) -> CompletionOptions:
options = CompletionOptions()
if temperature is not None:
if temperature is not None and is_defined(temperature):
options.temperature.value = temperature
if max_tokens is not None:
if max_tokens is not None and is_defined(max_tokens):
options.max_tokens.value = max_tokens

return options


def get_prompt_trunctation_options(
*,
max_prompt_tokens: int | None
max_prompt_tokens: UndefinedOr[int] | None
) -> PromptTruncationOptions:
options = PromptTruncationOptions()
if max_prompt_tokens is not None:
if max_prompt_tokens is not None and is_defined(max_prompt_tokens):
options.max_prompt_tokens.value = max_prompt_tokens

return options
2 changes: 1 addition & 1 deletion src/yandex_cloud_ml_sdk/_types/resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def _update_from_proto(self, proto: Message) -> Self:

def _fill_update_mask(self, mask: FieldMask, fields: dict[str, Any]) -> None:
for key, value in fields.items():
if is_defined(value) and value is not None:
if is_defined(value):
mask.paths.append(key)


Expand Down
Loading

0 comments on commit 1e97ef1

Please sign in to comment.