Skip to content

Commit

Permalink
[JAX] Update use of JAX internals in preparation for JAX change.
Browse files Browse the repository at this point in the history
jax-ml/jax#17028 would break this code without this change.

PiperOrigin-RevId: 557540605
  • Loading branch information
hawkinsp authored and learned_optimization authors committed Aug 16, 2023
1 parent a49615f commit ae70b23
Showing 1 changed file with 21 additions and 14 deletions.
35 changes: 21 additions & 14 deletions learned_optimization/setup_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,21 +25,25 @@
import jax
from learned_optimization import filesystem

flags.DEFINE_multi_string("gin_bindings", None,
"Newline separated list of Gin parameter bindings.")
flags.DEFINE_multi_string(
"gin_bindings", None, "Newline separated list of Gin parameter bindings."
)

flags.DEFINE_multi_string("gin_import", None, "List of modules to import")

flags.DEFINE_multi_string("config_file", None,
"List of paths to the config files for Gin.")
flags.DEFINE_multi_string(
"config_file", None, "List of paths to the config files for Gin."
)

flags.DEFINE_integer("task", 0, "Task / index of the replica for this job.")

flags.DEFINE_string("train_log_dir", None,
"Training directory to save summaries/checkpoints.")
flags.DEFINE_string(
"train_log_dir", None, "Training directory to save summaries/checkpoints."
)

flags.DEFINE_string("train_log_dir_suffix", None,
"suffix to add to train_log_dir path.")
flags.DEFINE_string(
"train_log_dir_suffix", None, "suffix to add to train_log_dir path."
)

FLAGS = flags.FLAGS

Expand Down Expand Up @@ -72,7 +76,7 @@ def parse_and_set_gin_config(finalize: bool, skip_unknown: bool):
split = g.split("=")
key, value = split[0], "=".join(split[1:])
new_v = value.strip()
if new_v[0:2] in ["\"@"]:
if new_v[0:2] in ['"@']:
new_v = new_v[1:-1] # strip quotes
FLAGS.gin_bindings[i] = key.strip() + "=" + new_v

Expand All @@ -91,9 +95,11 @@ def parse_and_set_gin_config(finalize: bool, skip_unknown: bool):



def setup_experiment(gin_finalize: bool = True,
gin_skip_unknown: bool = True,
make_dir: bool = False) -> Optional[str]:
def setup_experiment(
gin_finalize: bool = True,
gin_skip_unknown: bool = True,
make_dir: bool = False,
) -> Optional[str]:
"""Setup an experiment.
This function manages flags ensuring gin flags are parsed correctly,
Expand All @@ -118,8 +124,9 @@ def setup_experiment(gin_finalize: bool = True,
filesystem.make_dirs(FLAGS.train_log_dir)

if FLAGS.train_log_dir:
logging.info("Setup experiment! Training directory located: %s",
FLAGS.train_log_dir)
logging.info(
"Setup experiment! Training directory located: %s", FLAGS.train_log_dir
)
return FLAGS.train_log_dir
else:
logging.info("Setup experiment! No training directory specified")
Expand Down

0 comments on commit ae70b23

Please sign in to comment.