Skip to content

Commit

Permalink
Fixed 'model' error
Browse files Browse the repository at this point in the history
  • Loading branch information
PandapowrTR authored Jan 18, 2024
1 parent 056b367 commit 82964ad
Showing 1 changed file with 47 additions and 29 deletions.
76 changes: 47 additions & 29 deletions Training/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,12 +403,24 @@ def __saveParams(

values = {
"modelTrainValues": GridSearchTrain.__cleanDictForJson(modelTrainValues),
"modelSaveMethodValues": GridSearchTrain.__cleanDictForJson(modelSaveMethodValues),
"modelLoadMethodValues": GridSearchTrain.__cleanDictForJson(modelLoadMethodValues),
"hardwareSetupmethodValues": GridSearchTrain.__cleanDictForJson(hardwareSetupmethodValues),
"loadDatamethodValues": GridSearchTrain.__cleanDictForJson(loadDatamethodValues),
"splitDatamethodValues": GridSearchTrain.__cleanDictForJson(splitDatamethodValues),
"modelTestmethodValues": GridSearchTrain.__cleanDictForJson(modelTestmethodValues),
"modelSaveMethodValues": GridSearchTrain.__cleanDictForJson(
modelSaveMethodValues
),
"modelLoadMethodValues": GridSearchTrain.__cleanDictForJson(
modelLoadMethodValues
),
"hardwareSetupmethodValues": GridSearchTrain.__cleanDictForJson(
hardwareSetupmethodValues
),
"loadDatamethodValues": GridSearchTrain.__cleanDictForJson(
loadDatamethodValues
),
"splitDatamethodValues": GridSearchTrain.__cleanDictForJson(
splitDatamethodValues
),
"modelTestmethodValues": GridSearchTrain.__cleanDictForJson(
modelTestmethodValues
),
}
for key, value in values.copy().items():
try:
Expand Down Expand Up @@ -641,19 +653,19 @@ def __gridLoop(
# save last params
GridSearchTrain.__saveParams(
modelTrainmethod,
modelTrainValues,
modelTrainValues.copy(),
modelSaveMethod,
modelSaveMethodValues,
modelSaveMethodValues.copy(),
modelLoadMethod,
modelLoadMethodValues,
modelLoadMethodValues.copy(),
hardwareSetupmethod,
hardwareSetupmethodValues,
hardwareSetupmethodValues.copy(),
loadDatamethod,
loadDatamethodValues,
loadDatamethodValues.copy(),
splitDatamethod,
splitDatamethodValues,
splitDatamethodValues.copy(),
modelTestmethod,
modelTestmethodValues,
modelTestmethodValues.copy(),
copy.deepcopy(paramsIter),
copy.deepcopy(usedParams),
copy.deepcopy(data),
Expand Down Expand Up @@ -705,7 +717,10 @@ def __gridLoop(
modelTrainValues,
)
modelSaveMethodValues.update(
{"saveToPath": os.path.join(currentSaveFolder, "lastModel.h5")}
{
"saveToPath": os.path.join(currentSaveFolder, "lastModel.h5"),
"model": trainedModel,
}
)
modelSaveMethod(modelSaveMethodValues)
modelTestmethodValues.update({"model": trainedModel})
Expand All @@ -714,19 +729,19 @@ def __gridLoop(
# save usedParams
GridSearchTrain.__saveParams(
modelTrainmethod,
modelTrainValues,
modelTrainValues.copy(),
modelSaveMethod,
modelSaveMethodValues,
modelSaveMethodValues.copy(),
modelLoadMethod,
modelLoadMethodValues,
modelLoadMethodValues.copy(),
hardwareSetupmethod,
hardwareSetupmethodValues,
hardwareSetupmethodValues.copy(),
loadDatamethod,
loadDatamethodValues,
loadDatamethodValues.copy(),
splitDatamethod,
splitDatamethodValues,
splitDatamethodValues.copy(),
modelTestmethod,
modelTestmethodValues,
modelTestmethodValues.copy(),
copy.deepcopy(paramsIter),
copy.deepcopy(usedParams),
copy.deepcopy(data),
Expand Down Expand Up @@ -781,7 +796,10 @@ def __gridLoop(
"log.log",
)
modelSaveMethodValues.update(
{"saveToPath": os.path.join(saveToPath, "bestModel.h5")}
{
"saveToPath": os.path.join(saveToPath, "bestModel.h5"),
"model": bestModel,
}
)
modelSaveMethod(modelSaveMethodValues)
BurobotOther.zipFolder(
Expand All @@ -797,19 +815,19 @@ def __gridLoop(
# save usedParams
GridSearchTrain.__saveParams(
modelTrainmethod,
modelTrainValues,
modelTrainValues.copy(),
modelSaveMethod,
modelSaveMethodValues,
modelSaveMethodValues.copy(),
modelLoadMethod,
modelLoadMethodValues,
modelLoadMethodValues.copy(),
hardwareSetupmethod,
hardwareSetupmethodValues,
hardwareSetupmethodValues.copy(),
loadDatamethod,
loadDatamethodValues,
loadDatamethodValues.copy(),
splitDatamethod,
splitDatamethodValues,
splitDatamethodValues.copy(),
modelTestmethod,
modelTestmethodValues,
modelTestmethodValues.copy(),
copy.deepcopy(paramsIter),
copy.deepcopy(usedParams),
copy.deepcopy(data),
Expand Down

0 comments on commit 82964ad

Please sign in to comment.