Skip to content

Commit

Permalink
Merge branch 'f/legacy_nas' into 'main'
Browse files Browse the repository at this point in the history
Added save function

See merge request es/ai/hannah/hannah!357
  • Loading branch information
FrischAd committed Nov 24, 2023
2 parents 35343ba + 9d3d344 commit 5dce3b0
Showing 1 changed file with 12 additions and 15 deletions.
27 changes: 12 additions & 15 deletions hannah/nas/search/search_old.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,27 +44,17 @@
from ...callbacks.optimization import HydraOptCallback
from ...callbacks.summaries import MacSummaryCallback
from ...utils import clear_outputs, common_callbacks, fullname
from .sampler.aging_evolution import AgingEvolutionSampler
from ..graph_conversion import model_to_graph
from ..parametrization import SearchSpace
from .sampler.aging_evolution import AgingEvolutionSampler

msglogger = logging.getLogger(__name__)



import logging
import pickle
import shutil
from typing import List, Optional, Union

from typing import Dict, Any, List, Union, Optional
from dataclasses import dataclass
from pathlib import Path
import pandas as pd


import numpy as np
import yaml

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -92,7 +82,6 @@ def __init__(self, bounds, random_state):
self.lambdas = random_state.uniform(low=0.0, high=1.0, size=len(self.bounds))

def __call__(self, values):

result = 0.0
for num, key in enumerate(self.bounds.keys()):
if key in values:
Expand Down Expand Up @@ -198,6 +187,14 @@ def tell_result(self, parameters, metrics):

return None

def save(self):
history_file = self.output_folder / "history.yml"
history_file_tmp = history_file.with_suffix(".tmp")

with history_file_tmp.open("w") as history_data:
yaml.dump(self.history, history_data)
shutil.move(history_file_tmp, history_file)

def load(self):
# suffixes = [".pkl", ".yml"]
suffixes = [".pkl"]
Expand Down Expand Up @@ -458,8 +455,8 @@ def __init__(
self.worklist = []
self.presample = presample

if self.config.get('backend', None):
self.backend = instantiate(self.config.backend)
if self.config.get("backend", None):
self.backend = instantiate(self.config.backend)
else:
self.backend = None

Expand Down

0 comments on commit 5dce3b0

Please sign in to comment.