From 43c777bdc04d7c747fa929e047f7dea9935c1f62 Mon Sep 17 00:00:00 2001 From: gqcpm <63070177+gqcpm@users.noreply.github.com> Date: Fri, 8 Nov 2024 00:19:26 -0800 Subject: [PATCH] add data loaders --- sleap_nn/config/trainer_config.py | 112 ++++++++++++++++++++++++++++++ 1 file changed, 112 insertions(+) diff --git a/sleap_nn/config/trainer_config.py b/sleap_nn/config/trainer_config.py index e7ed15a4..5ffd55ba 100644 --- a/sleap_nn/config/trainer_config.py +++ b/sleap_nn/config/trainer_config.py @@ -81,6 +81,118 @@ def __attrs_post_init__(self): else: self.wandb = None + @classmethod + def from_dict(cls, config_dict: Dict[Text, Any]) -> "TrainerJobConfig": + """Create a TrainingJobConfig from a Python dictionary. + + Arguments: + config_dict: python dictionary that specifies the configurations. + + Returns: + A TrainingJobConfig instance parsed from the python dictionary. + """ + # Convert dictionary to an OmegaConf config, then instantiate from it. + config = OmegaConf.create(config_dict) + return Omega.to_object(config, cls) + + @classmethod + def from_json(cls, json_data: Text) -> "TrainingJobConfig": + """Create TrainingJobConfig from JSON-formatted string + + Arguments: + json_data: JSON-formatted string that specifies the configurations. + + Returns: + A TrainingJobConfig instance parsed from the JSON text. + """ + config_dict = json.loads(json_data) + return cls.from_dict(config_dict) + + @classmethod + def from_yaml(cls, yaml_data: Text) -> "TrainingJobConfig": + """Create TrainingJobConfig from YAML-formatted string. + + Arguments: + yaml_data: YAML-formatted string that specifies the configurations. + + Returns: + A TrainingJobConfig instance parsed from the YAML text. + """ + config = OmegaConf.create(yaml_data) + return OmegaConf.to_object(config, cls) + + @classmethod + def load_json(cls, filename: Text) -> "TrainingJobConfig": + """Load a training job configuration from a json file. + + Arguments: + filename: Path to a training job configuration JSON file or a directory + containing `"training_job.json"`. + + Returns: + A TrainingJobConfig instance parsed from the json file. + """ + with open(filename, "r") as f: + json_data = f.read() + return cls.from_json(json_data) + + @classmethod + def load_yaml(cls, filename:Text) -> "TrainingJobConfig": + """Load a training job configuration from a yaml file. + + Arguments: + filename: Path to a training job configuration YAML file or a directory + containing `"training_job.yaml"`. + + Returns: + A TrainingJobConfig instance parsed from the YAML file. + """ + config = OmegaConf.load(filename) + return OmegaConf.to_object(config, cls) + + def to_dict(self) -> DictConfig: + """Serialize the configuration into a python dictionary format. + + Returns: + The dictionary representation of the configuration. + """ + return OmegaConf.structured(self) + + def to_json(self) -> str: + """Serialize the configuration into JSON-encoded string format. + + Returns: + The JSON encoded string representation of the configuration. + """ + config = self.to_dict() + return json.dumps(OmegaConf.to_container(config), indent=4) + + def to_yaml(self) -> str: + """Serialize the configuration into YAML-encoded string format. + + Returns: + The YAML encoded string representation of the configuration. + """ + config = self.to_dict() + return OmegaConf.to_yaml(config) + + def save_json(self, filename: Text): + """Save the configuration to a JSON file. + + Arguments: + filename: Path to save the training job file to. + """ + with open(filename, "w") as f: + f.write(self.to_json()) + + def save_yaml(self, filename: Text): + """Save the configuration to a YAML file. + + Arguments: + filename: Path to save the training job file to. + """ + with open(filename, "w") as f: + f.write(self.to_yaml()) @attrs.define class DataLoaderConfig: