diff --git a/hannah/nas/search/search_old.py b/hannah/nas/search/search_old.py index 6806954d..c41410cd 100644 --- a/hannah/nas/search/search_old.py +++ b/hannah/nas/search/search_old.py @@ -44,27 +44,17 @@ from ...callbacks.optimization import HydraOptCallback from ...callbacks.summaries import MacSummaryCallback from ...utils import clear_outputs, common_callbacks, fullname -from .sampler.aging_evolution import AgingEvolutionSampler from ..graph_conversion import model_to_graph from ..parametrization import SearchSpace +from .sampler.aging_evolution import AgingEvolutionSampler msglogger = logging.getLogger(__name__) - - -import logging import pickle -import shutil +from typing import List, Optional, Union -from typing import Dict, Any, List, Union, Optional -from dataclasses import dataclass -from pathlib import Path import pandas as pd - -import numpy as np -import yaml - logger = logging.getLogger(__name__) @@ -92,7 +82,6 @@ def __init__(self, bounds, random_state): self.lambdas = random_state.uniform(low=0.0, high=1.0, size=len(self.bounds)) def __call__(self, values): - result = 0.0 for num, key in enumerate(self.bounds.keys()): if key in values: @@ -198,6 +187,14 @@ def tell_result(self, parameters, metrics): return None + def save(self): + history_file = self.output_folder / "history.yml" + history_file_tmp = history_file.with_suffix(".tmp") + + with history_file_tmp.open("w") as history_data: + yaml.dump(self.history, history_data) + shutil.move(history_file_tmp, history_file) + def load(self): # suffixes = [".pkl", ".yml"] suffixes = [".pkl"] @@ -458,8 +455,8 @@ def __init__( self.worklist = [] self.presample = presample - if self.config.get('backend', None): - self.backend = instantiate(self.config.backend) + if self.config.get("backend", None): + self.backend = instantiate(self.config.backend) else: self.backend = None