Skip to content

Commit

Permalink
Fixes and Graphs
Browse files Browse the repository at this point in the history
  • Loading branch information
CVelizR committed Jan 8, 2025
1 parent b647bb6 commit 06a99aa
Show file tree
Hide file tree
Showing 32 changed files with 628 additions and 178 deletions.
Binary file added DashAI/back.rar
Binary file not shown.
17 changes: 15 additions & 2 deletions DashAI/back/api/api_v1/endpoints/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,10 +112,11 @@ async def get_run_by_id(
return run


@router.get("/plot/{run_id}")
@router.get("/plot/{run_id}/{plot_type}")
@inject
async def get_hyperparameter_optimization_plot(
run_id: int,
plot_type: int,
session_factory: sessionmaker = Depends(lambda: di["session_factory"]),
):
with session_factory() as db:
Expand All @@ -134,7 +135,14 @@ async def get_hyperparameter_optimization_plot(
detail="Run hyperaparameter plot not found",
)

plot_path = run_model[0].plot_path
if plot_type == 1:
plot_path = run_model[0].plot_history_path
elif plot_type == 2:
plot_path = run_model[0].plot_slice_path
elif plot_type == 3:
plot_path = run_model[0].plot_contour_path
else:
plot_path = run_model[0].plot_importance_path

with open(plot_path, "rb") as file:
plot = pickle.load(file)
Expand Down Expand Up @@ -189,6 +197,11 @@ async def upload_run(
parameters=params.parameters,
optimizer_name=params.optimizer_name,
optimizer_parameters=params.optimizer_parameters,
plot_history_path=params.plot_history_path,
plot_slice_path=params.plot_slice_path,
plot_contour_path=params.plot_contour_path,
plot_importance_path=params.plot_importance_path,
goal_metric=params.goal_metric,
name=params.name,
description=params.description,
)
Expand Down
5 changes: 5 additions & 0 deletions DashAI/back/api/api_v1/schemas/runs_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,9 @@ class RunParams(BaseModel):
parameters: dict
optimizer_name: str
optimizer_parameters: dict
plot_history_path: str
plot_slice_path: str
plot_contour_path: str
plot_importance_path: str
goal_metric: str
description: Union[str, None] = None
7 changes: 6 additions & 1 deletion DashAI/back/dependencies/database/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,12 @@ class Run(Base):
# optimizer
optimizer_name: Mapped[str] = mapped_column(String)
optimizer_parameters: Mapped[JSON] = mapped_column(JSON)
plot_path: Mapped[str] = mapped_column(String, nullable=True)
plot_history_path: Mapped[str] = mapped_column(String, nullable=True)
plot_slice_path: Mapped[str] = mapped_column(String, nullable=True)
plot_contour_path: Mapped[str] = mapped_column(String, nullable=True)
plot_importance_path: Mapped[str] = mapped_column(String, nullable=True)
# goal metrics
goal_metric: Mapped[str] = mapped_column(String)
# metrics
train_metrics: Mapped[JSON] = mapped_column(JSON, nullable=True)
test_metrics: Mapped[JSON] = mapped_column(JSON, nullable=True)
Expand Down
51 changes: 34 additions & 17 deletions DashAI/back/job/model_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,9 +234,7 @@ def run(
) from e

try:
run.optimizer_parameters["metric"] = selected_metrics[
run.optimizer_parameters["metric"]
]
goal_metric = selected_metrics[run.goal_metric]
except Exception as e:
log.exception(e)
raise JobError(
Expand Down Expand Up @@ -269,30 +267,49 @@ def run(
x,
y,
run_optimizable_parameters,
goal_metric,
experiment.task_name,
)
model = optimizer.get_model()
# Generate hyperparameter plot
X, Y = optimizer.get_metrics()
plot = optimizer.create_plot(X, Y)
plot_filename = f"hyperparameter_optimization_plot_{run_id}.pickle"
plot_path = os.path.join(config["RUNS_PATH"], plot_filename)
with open(plot_path, "wb") as file:
pickle.dump(plot, file)
trials = optimizer.get_trials_values()
plot_filenames, plots = optimizer.create_plots(
trials, run_id, n_params=len(run_optimizable_parameters)
)
plot_paths = []
for filename, plot in zip(plot_filenames, plots):
plot_path = os.path.join(config["RUNS_PATH"], filename)
with open(plot_path, "wb") as file:
pickle.dump(plot, file)
plot_paths.append(plot_path)
except Exception as e:
log.exception(e)
raise JobError(
"Model training failed",
) from e
if run_optimizable_parameters != {}:
try:
run.plot_path = plot_path
db.commit()
except Exception as e:
log.exception(e)
raise JobError(
"Hyperparameter plot path saving failed",
) from e
if len(run_optimizable_parameters) >= 2:
try:
run.plot_history_path = plot_paths[0]
run.plot_slice_path = plot_paths[1]
run.plot_contour_path = plot_paths[2]
run.plot_importance_path = plot_paths[3]
db.commit()
except Exception as e:
log.exception(e)
raise JobError(
"Hyperparameter plot path saving failed",
) from e
else:
try:
run.plot_history_path = plot_paths[0]
run.plot_slice_path = plot_paths[1]
db.commit()
except Exception as e:
log.exception(e)
raise JobError(
"Hyperparameter plot path saving failed",
) from e
try:
run.set_status_as_finished()
db.commit()
Expand Down
Loading

0 comments on commit 06a99aa

Please sign in to comment.