Skip to content

Commit

Permalink
resolving review comments
Browse files Browse the repository at this point in the history
Signed-off-by: akhilsaivenkata <[email protected]>
  • Loading branch information
akhilsaivenkata committed Jul 16, 2024
1 parent e0ea84e commit 2ef15e1
Showing 1 changed file with 26 additions and 25 deletions.
51 changes: 26 additions & 25 deletions sdk/python/v1beta1/kubeflow/katib/api/katib_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,30 @@ def tune(
if max_failed_trial_count is not None:
experiment.spec.max_failed_trial_count = max_failed_trial_count

# Iterate over input parameters.
input_params = {}
experiment_params = []
trial_params = []
base_image = constants.BASE_IMAGE_TENSORFLOW,

for p_name, p_value in parameters.items():
# If input parameter value is Katib Experiment parameter sample.
if isinstance(p_value, models.V1beta1ParameterSpec):
# Wrap value for the function input.
input_params[p_name] = f"${{trialParameters.{p_name}}}"

# Add value to the Katib Experiment parameters.
p_value.name = p_name
experiment_params.append(p_value)

# Add value to the Katib Experiment's Trial parameters.
trial_params.append(
models.V1beta1TrialParameterSpec(name=p_name, reference=p_name)
)
else:
# Otherwise, add value to the function input.
input_params[p_name] = p_value

# Handle different types of objective input
if callable(objective):
# Validate objective function.
Expand All @@ -295,29 +319,6 @@ def tune(
# (e.g. in another function). We need to dedent the function code.
objective_code = textwrap.dedent(objective_code)

# Iterate over input parameters.
input_params = {}
experiment_params = []
trial_params = []
base_image = constants.BASE_IMAGE_TENSORFLOW,
for p_name, p_value in parameters.items():
# If input parameter value is Katib Experiment parameter sample.
if isinstance(p_value, models.V1beta1ParameterSpec):
# Wrap value for the function input.
input_params[p_name] = f"${{trialParameters.{p_name}}}"

# Add value to the Katib Experiment parameters.
p_value.name = p_name
experiment_params.append(p_value)

# Add value to the Katib Experiment's Trial parameters.
trial_params.append(
models.V1beta1TrialParameterSpec(name=p_name, reference=p_name)
)
else:
# Otherwise, add value to the function input.
input_params[p_name] = p_value

# Wrap objective function to execute it from the file. For example
# def objective(parameters):
# print(f'Parameters are {parameters}')
Expand Down Expand Up @@ -407,12 +408,12 @@ def tune(
trial_template = models.V1beta1TrialTemplate(
primary_container_name=constants.DEFAULT_PRIMARY_CONTAINER_NAME,
retain=retain_trials,
trial_parameters=trial_params if callable(objective) else [],
trial_parameters=trial_params,
trial_spec=trial_spec,
)

# Add parameters to the Katib Experiment.
experiment.spec.parameters = experiment_params if callable(objective) else []
experiment.spec.parameters = experiment_params

# Add Trial template to the Katib Experiment.
experiment.spec.trial_template = trial_template
Expand Down

0 comments on commit 2ef15e1

Please sign in to comment.