Skip to content

Commit

Permalink
another minor change
Browse files Browse the repository at this point in the history
  • Loading branch information
voorhs committed Oct 1, 2024
1 parent c2f9218 commit aa07a99
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 6 deletions.
2 changes: 1 addition & 1 deletion autointent/context/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,5 +48,5 @@ def get_inference_config(self) -> dict[str, Any]:
"seed": self.seed,
"db_dir": self.vector_index.db_dir,
},
"modules": self.optimization_info.get_best_modules()
"nodes_configs": self.optimization_info.get_best_trials()
}
9 changes: 4 additions & 5 deletions autointent/context/optimization_info/optimization_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,17 +83,16 @@ def dump_evaluation_results(self):
"configs": self.trials.model_dump(),
}

def get_best_modules(self) -> list[dict[str, Any]]:
def get_best_trials(self) -> list[dict[str, Any]]:
node_types = ["regexp", "retrieval", "scoring", "prediction"]
trial_ids = [self._get_best_trial_idx(node_type) for node_type in node_types]
res = []
res = {nt: {} for nt in node_types}
for idx, node_type in zip(trial_ids, node_types, strict=True):
if idx is None:
continue
trial = self.trials[node_type][idx]
res.append({
"node_type": node_type,
res[node_type] = {
"module_type": trial.module_type,
"module_params": trial.module_params
})
}
return res

0 comments on commit aa07a99

Please sign in to comment.