Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
olegkkruglov committed Aug 29, 2024
1 parent 1c0764d commit 692775f
Showing 1 changed file with 20 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def test_incremental_basic_statistics_fit_spmd_gold(dataframe, queue, weighted,
)

if weighted:
# Create weights array containing the weight for each sample in the data
weights = np.array([0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0], dtype=dtype)
dpt_weights = _convert_to_dataframe(
weights, sycl_queue=queue, target_df=dataframe
Expand All @@ -88,7 +89,11 @@ def test_incremental_basic_statistics_fit_spmd_gold(dataframe, queue, weighted,
)

for option, _, _ in options_and_tests:
assert_allclose(getattr(incbs_spmd, option), getattr(incbs, option))
assert_allclose(
getattr(incbs_spmd, option),
getattr(incbs, option),
err_msg=f"Result for {option} is incorrect",
)


@pytest.mark.skipif(
Expand Down Expand Up @@ -131,6 +136,7 @@ def test_incremental_basic_statistics_partial_fit_spmd_gold(
split_local_data = np.array_split(local_data, num_blocks)

if weighted:
# Create weights array containing the weight for each sample in the data
weights = np.array([0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0], dtype=dtype)
dpt_weights = _convert_to_dataframe(
weights, sycl_queue=queue, target_df=dataframe
Expand All @@ -156,7 +162,11 @@ def test_incremental_basic_statistics_partial_fit_spmd_gold(
incbs.fit(dpt_data, sample_weight=dpt_weights if weighted else None)

for option, _, _ in options_and_tests:
assert_allclose(getattr(incbs_spmd, option), getattr(incbs, option))
assert_allclose(
getattr(incbs_spmd, option),
getattr(incbs, option),
err_msg=f"Result for {option} is incorrect",
)


@pytest.mark.skipif(
Expand Down Expand Up @@ -200,6 +210,7 @@ def test_incremental_basic_statistics_single_option_partial_fit_spmd_gold(
split_local_data = np.array_split(local_data, num_blocks)

if weighted:
# Create weights array containing the weight for each sample in the data
weights = np.array([0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0], dtype=dtype)
dpt_weights = _convert_to_dataframe(
weights, sycl_queue=queue, target_df=dataframe
Expand Down Expand Up @@ -260,6 +271,7 @@ def test_incremental_basic_statistics_partial_fit_spmd_synthetic(
split_data = np.array_split(data, num_blocks)

if weighted:
# Create weights array containing the weight for each sample in the data
weights = _generate_weights(n_samples, dtype=dtype)
local_weights = _get_local_tensor(weights)
split_local_weights = np.array_split(local_weights, num_blocks)
Expand Down Expand Up @@ -288,4 +300,9 @@ def test_incremental_basic_statistics_partial_fit_spmd_synthetic(
incbs.partial_fit(dpt_data, sample_weight=dpt_weights if weighted else None)

for option, _, _ in options_and_tests:
assert_allclose(getattr(incbs_spmd, option), getattr(incbs, option), atol=tol)
assert_allclose(
getattr(incbs_spmd, option),
getattr(incbs, option),
atol=tol,
err_msg=f"Result for {option} is incorrect",
)

0 comments on commit 692775f

Please sign in to comment.