Skip to content

Commit

Permalink
add data loaders
Browse files Browse the repository at this point in the history
  • Loading branch information
gqcpm committed Nov 8, 2024
1 parent 16c8f26 commit 43c777b
Showing 1 changed file with 112 additions and 0 deletions.
112 changes: 112 additions & 0 deletions sleap_nn/config/trainer_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,118 @@ def __attrs_post_init__(self):
else:
self.wandb = None

@classmethod
def from_dict(cls, config_dict: Dict[Text, Any]) -> "TrainerJobConfig":
"""Create a TrainingJobConfig from a Python dictionary.
Arguments:
config_dict: python dictionary that specifies the configurations.
Returns:
A TrainingJobConfig instance parsed from the python dictionary.
"""
# Convert dictionary to an OmegaConf config, then instantiate from it.
config = OmegaConf.create(config_dict)
return Omega.to_object(config, cls)

@classmethod
def from_json(cls, json_data: Text) -> "TrainingJobConfig":
"""Create TrainingJobConfig from JSON-formatted string
Arguments:
json_data: JSON-formatted string that specifies the configurations.
Returns:
A TrainingJobConfig instance parsed from the JSON text.
"""
config_dict = json.loads(json_data)
return cls.from_dict(config_dict)

@classmethod
def from_yaml(cls, yaml_data: Text) -> "TrainingJobConfig":
"""Create TrainingJobConfig from YAML-formatted string.
Arguments:
yaml_data: YAML-formatted string that specifies the configurations.
Returns:
A TrainingJobConfig instance parsed from the YAML text.
"""
config = OmegaConf.create(yaml_data)
return OmegaConf.to_object(config, cls)

@classmethod
def load_json(cls, filename: Text) -> "TrainingJobConfig":
"""Load a training job configuration from a json file.
Arguments:
filename: Path to a training job configuration JSON file or a directory
containing `"training_job.json"`.
Returns:
A TrainingJobConfig instance parsed from the json file.
"""
with open(filename, "r") as f:
json_data = f.read()
return cls.from_json(json_data)

@classmethod
def load_yaml(cls, filename:Text) -> "TrainingJobConfig":
"""Load a training job configuration from a yaml file.
Arguments:
filename: Path to a training job configuration YAML file or a directory
containing `"training_job.yaml"`.
Returns:
A TrainingJobConfig instance parsed from the YAML file.
"""
config = OmegaConf.load(filename)
return OmegaConf.to_object(config, cls)

def to_dict(self) -> DictConfig:
"""Serialize the configuration into a python dictionary format.
Returns:
The dictionary representation of the configuration.
"""
return OmegaConf.structured(self)

def to_json(self) -> str:
"""Serialize the configuration into JSON-encoded string format.
Returns:
The JSON encoded string representation of the configuration.
"""
config = self.to_dict()
return json.dumps(OmegaConf.to_container(config), indent=4)

def to_yaml(self) -> str:
"""Serialize the configuration into YAML-encoded string format.
Returns:
The YAML encoded string representation of the configuration.
"""
config = self.to_dict()
return OmegaConf.to_yaml(config)

def save_json(self, filename: Text):
"""Save the configuration to a JSON file.
Arguments:
filename: Path to save the training job file to.
"""
with open(filename, "w") as f:
f.write(self.to_json())

def save_yaml(self, filename: Text):
"""Save the configuration to a YAML file.
Arguments:
filename: Path to save the training job file to.
"""
with open(filename, "w") as f:
f.write(self.to_yaml())

@attrs.define
class DataLoaderConfig:
Expand Down

0 comments on commit 43c777b

Please sign in to comment.