Merge pull request #63 from roboflow/fix/type_not_yet_supported_error
fix for `RuntimeError: Type not yet supported: typing.Literal['sgd', 'adamw', 'adam']`
onuralpszr authored Sep 24, 2024
2 parents 379b658 + 9d31dbc commit 3e0c681
Showing 1 changed file with 10 additions and 10 deletions.
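
For context, a minimal sketch of the failure mode (assumed, not taken from this repo): typer builds click parameters from type annotations and does not support `typing.Literal`, so an option declared with a `Literal` annotation raises the error quoted in the commit message when the CLI runs. Widening the annotations to `str` sidesteps this; the narrower `Literal` types remain on `Configuration` for static type checking, hence the `# type: ignore` comments in the diff below.

```python
# Minimal repro sketch (hypothetical, not from this repo): typer cannot
# build a CLI parameter from a typing.Literal annotation.
from typing import Annotated, Literal

import typer

app = typer.Typer()


@app.command()
def train(
    optimizer: Annotated[
        Literal["sgd", "adamw", "adam"],
        typer.Option("--optimizer", help="Optimizer to use for training"),
    ] = "adamw",
) -> None:
    print(f"training with optimizer={optimizer}")


if __name__ == "__main__":
    # Running this raises:
    # RuntimeError: Type not yet supported: typing.Literal['sgd', 'adamw', 'adam']
    app()
```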
20 changes: 10 additions & 10 deletions maestro/trainer/models/florence_2/entrypoint.py
@@ -1,5 +1,5 @@
import dataclasses
-from typing import Annotated, Literal, Optional, Union
+from typing import Annotated, Optional

import rich
import torch
@@ -16,7 +16,7 @@
DEFAULT_FLORENCE2_MODEL_REVISION,
DEVICE,
)
-from maestro.trainer.models.florence_2.core import Configuration, LoraInitLiteral
+from maestro.trainer.models.florence_2.core import Configuration
from maestro.trainer.models.florence_2.core import evaluate as florence2_evaluate
from maestro.trainer.models.florence_2.core import train as florence2_train

@@ -70,15 +70,15 @@ def train(
typer.Option("--epochs", help="Number of training epochs"),
] = 10,
optimizer: Annotated[
Literal["sgd", "adamw", "adam"],
str,
typer.Option("--optimizer", help="Optimizer to use for training"),
] = "adamw",
lr: Annotated[
float,
typer.Option("--lr", help="Learning rate for the optimizer"),
] = 1e-5,
lr_scheduler: Annotated[
Literal["linear", "cosine", "polynomial"],
str,
typer.Option("--lr_scheduler", help="Learning rate scheduler"),
] = "linear",
batch_size: Annotated[
@@ -110,15 +110,15 @@ def train(
typer.Option("--lora_dropout", help="Dropout probability for LoRA layers"),
] = 0.05,
bias: Annotated[
Literal["none", "all", "lora_only"],
str,
typer.Option("--bias", help="Which bias to train"),
] = "none",
use_rslora: Annotated[
bool,
typer.Option("--use_rslora/--no_use_rslora", help="Whether to use RSLoRA"),
] = True,
init_lora_weights: Annotated[
-        Union[bool, LoraInitLiteral],
+        str,
typer.Option("--init_lora_weights", help="How to initialize LoRA weights"),
] = "gaussian",
output_dir: Annotated[
Expand All @@ -138,19 +138,19 @@ def train(
device=torch.device(device),
cache_dir=cache_dir,
epochs=epochs,
-        optimizer=optimizer,
+        optimizer=optimizer,  # type: ignore
lr=lr,
-        lr_scheduler=lr_scheduler,
+        lr_scheduler=lr_scheduler,  # type: ignore
batch_size=batch_size,
val_batch_size=val_batch_size,
num_workers=num_workers,
val_num_workers=val_num_workers,
lora_r=lora_r,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
-        bias=bias,
+        bias=bias,  # type: ignore
use_rslora=use_rslora,
-        init_lora_weights=init_lora_weights,
+        init_lora_weights=init_lora_weights,  # type: ignore
output_dir=output_dir,
metrics=metric_objects,
)
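With `str` annotations, invalid values are no longer rejected at the CLI boundary, and the `# type: ignore` comments silence the mismatch against the narrower types that `Configuration` still declares. One hedged alternative, not what this PR does: typer natively supports `enum.Enum` parameters, which would keep CLI-level choice validation while avoiding `Literal`:

```python
# Alternative sketch (hypothetical, not from this repo): an Enum gives
# typer a supported type and preserves CLI-side validation of choices.
import enum
from typing import Annotated

import typer


class Optimizer(str, enum.Enum):
    sgd = "sgd"
    adamw = "adamw"
    adam = "adam"


app = typer.Typer()


@app.command()
def train(
    optimizer: Annotated[
        Optimizer,
        typer.Option("--optimizer", help="Optimizer to use for training"),
    ] = Optimizer.adamw,
) -> None:
    # Invalid values (e.g. --optimizer bogus) are rejected by the parser
    # before this function runs; optimizer.value is the plain string.
    print(f"training with optimizer={optimizer.value}")


if __name__ == "__main__":
    app()
```

With this approach, `--optimizer bogus` fails at parse time and the help text lists the valid choices, at the cost of converting `optimizer.value` back to the plain string that `Configuration` expects.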
