
Commit

Minor update to docs
anwai98 committed Jul 24, 2024
1 parent a77a5e8 commit 903e2d8
Showing 2 changed files with 12 additions and 6 deletions.
2 changes: 1 addition & 1 deletion in finetuning/specialists/resource-efficient/README.md
@@ -33,7 +33,7 @@ Fixed parameters:
 - training and validation batch size - `1`
 - minimum number of training "samples" for training on the provided images - min. **`50`** (we oversample until at least 50 training samples are available; this avoids excessive training times when training on only 1 image)
 - learning rate: `1e-5`
-- optimizer: `Adam`
+- optimizer: `AdamW`
 - lr scheduler: `ReduceLROnPlateau`
 - early stopping: `10`
 - patch shape: `(512, 512)`
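
For context, a minimal sketch of how these fixed hyperparameters could be wired up in PyTorch. The `model` is a stand-in (the actual setup finetunes a Segment Anything model), and the scheduler's `factor` and `patience` values are assumptions not stated in the README:

```python
import torch

model = torch.nn.Linear(4, 4)  # stand-in model for illustration only

# Fixed parameters from the README: AdamW at lr 1e-5 with ReduceLROnPlateau.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.9, patience=3,  # factor/patience assumed
)

val_loss = 0.5  # placeholder validation metric for one epoch
scheduler.step(val_loss)  # reduces the lr once the metric stops improving
```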
16 changes: 11 additions & 5 deletions in the batch-submission script
@@ -15,6 +15,7 @@ def base_slurm_script(env_name, partition, cpu_mem, cpu_cores, gpu_name=None):
 #SBATCH --mem {cpu_mem}
 #SBATCH -p {partition}
 #SBATCH -t 2-00:00:00
+#SBATCH --job-name micro-sam-resource-efficient-finetuning
 """
     if gpu_name is not None:
         base_script += f"#SBATCH -G {gpu_name}:1 \n"
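
The directives above live inside an f-string template that `base_slurm_script` assembles. A self-contained approximation of that function is sketched below; the `-c {cpu_cores}` directive and the environment-activation line are assumptions, since the diff only shows part of the template:

```python
def base_slurm_script(env_name, partition, cpu_mem, cpu_cores, gpu_name=None):
    # Assumed approximation: collect the SBATCH directives in one template.
    base_script = f"""#!/bin/bash
#SBATCH -c {cpu_cores}
#SBATCH --mem {cpu_mem}
#SBATCH -p {partition}
#SBATCH -t 2-00:00:00
#SBATCH --job-name micro-sam-resource-efficient-finetuning
"""
    # The GPU request is appended only when a GPU is actually asked for.
    if gpu_name is not None:
        base_script += f"#SBATCH -G {gpu_name}:1 \n"
    base_script += f"\nsource activate {env_name}\n"  # assumed env activation
    return base_script

print(base_slurm_script("sam", "gpu", "64G", 16, gpu_name="A100"))
```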
@@ -29,7 +30,7 @@ def base_slurm_script(env_name, partition, cpu_mem, cpu_cores, gpu_name=None):

 def write_batch_sript(
     env_name, partition, cpu_mem, cpu_cores, gpu_name, input_path, save_root,
-    model_type, n_objects, n_images, script_name, freeze, lora,
+    model_type, n_objects, n_images, script_name, freeze, lora, dry,
 ):
     assert model_type in ["vit_t", "vit_b", "vit_t_lm", "vit_b_lm"]
@@ -52,7 +53,7 @@ def write_batch_sript(

     # Whether to use LoRA-based finetuning
     # NOTE: We use rank as 4 for LoRA.
-    if lora is not None:
+    if lora:
         python_script += "--lora_rank 4 "

     if gpu_name is not None:
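
This change is a behavior fix, not cosmetics: `lora` is a boolean drawn from `use_lora = [False, True]` further down, so the old check `lora is not None` was `True` for both values and `--lora_rank 4` was appended to every job. A two-line illustration:

```python
for lora in [False, True]:
    print(lora, lora is not None, bool(lora))
# False True False  -> the old check enabled LoRA even when lora=False
# True  True True   -> the new check only enables LoRA when requested
```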
@@ -82,8 +83,9 @@ def write_batch_sript(
     with open(script_name, "w") as f:
         f.write(batch_script)

-    cmd = ["sbatch", script_name]
-    subprocess.run(cmd)
+    if not dry:
+        cmd = ["sbatch", script_name]
+        subprocess.run(cmd)


 def get_batch_script_names(tmp_folder):
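
With this guard, a dry run still writes every batch script via `f.write(batch_script)`, so the generated files can be reviewed before anything reaches the queue; only the `sbatch` hand-off is skipped. A minimal standalone sketch of the pattern, with a hypothetical script path:

```python
import subprocess

def submit(script_name, dry=False):
    # The batch script itself is assumed to be written already; a dry
    # run only skips handing it to the SLURM scheduler.
    if not dry:
        subprocess.run(["sbatch", script_name])
    else:
        print(f"dry run: {script_name} written but not submitted")

submit("./gpu_jobs/finetune_vit_b.sh", dry=True)  # hypothetical path
```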
@@ -102,6 +104,7 @@ def main(args):

     all_n_images = [1, 2, 5, 10]
     use_lora = [False, True]
+
     for (n_images, lora) in itertools.product(all_n_images, use_lora):
         # We cannot use LoRA and freeze the image encoder at the same time.
         if lora and args.freeze == "image_encoder":
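
For reference, this loop enumerates 4 x 2 = 8 configurations and skips the invalid ones. A standalone sketch of the enumeration, with an illustrative `freeze` value standing in for `args.freeze`:

```python
import itertools

all_n_images = [1, 2, 5, 10]
use_lora = [False, True]
freeze = "image_encoder"  # illustrative value for args.freeze

for n_images, lora in itertools.product(all_n_images, use_lora):
    # We cannot use LoRA and freeze the image encoder at the same time.
    if lora and freeze == "image_encoder":
        continue
    print(f"would submit: n_images={n_images}, lora={lora}")
```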
@@ -121,6 +124,7 @@
             script_name=get_batch_script_names(tmp_folder),
             freeze=args.freeze,
             lora=lora,
+            dry=args.dry,
         )


@@ -141,7 +145,9 @@ def main(args):
     parser.add_argument("--partition", type=str, required=True, help="Name of the partition for running the job.")
     parser.add_argument("--mem", type=str, required=True, help="Amount of cpu memory.")
     parser.add_argument("-c", "--cpu_cores", type=int, required=True, help="Number of cpu cores.")
-    parser.add_argument("-G", "--gpu_name", type=str, default=None, help="The GPI resources used for finetuning.")
+    parser.add_argument("-G", "--gpu_name", type=str, default=None, help="The GPU resources used for finetuning.")
+
+    parser.add_argument("--dry", action="store_true", help="Whether to avoid submitting the configured scripts.")

     args = parser.parse_args()
     main(args)
