Skip to content

Commit

Permalink
Remove generate_weights
Browse files Browse the repository at this point in the history
  • Loading branch information
olegkkruglov committed Aug 29, 2024
1 parent 692775f commit 7e0400e
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
)
from sklearnex.tests._utils_spmd import (
_generate_statistic_data,
_generate_weights,
_get_local_tensor,
_mpi_libs_and_gpu_available,
)
Expand Down Expand Up @@ -272,7 +271,7 @@ def test_incremental_basic_statistics_partial_fit_spmd_synthetic(

if weighted:
# Create weights array containing the weight for each sample in the data
weights = _generate_weights(n_samples, dtype=dtype)
weights = _generate_statistic_data(n_samples, dtype=dtype)
local_weights = _get_local_tensor(weights)
split_local_weights = np.array_split(local_weights, num_blocks)
split_weights = np.array_split(weights, num_blocks)
Expand Down
17 changes: 8 additions & 9 deletions sklearnex/tests/_utils_spmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,16 @@ def _generate_classification_data(
return X_train, X_test, y_train, y_test


def _generate_statistic_data(n_samples, n_features, dtype=np.float64, random_state=42):
def _generate_statistic_data(
n_samples, n_features=None, dtype=np.float64, random_state=42
):
# Generates statistical data
gen = np.random.default_rng(random_state)
data = gen.uniform(low=-0.3, high=+0.7, size=(n_samples, n_features)).astype(dtype)
data = gen.uniform(
low=-0.3,
high=+0.7,
size=(n_samples, n_features) if n_features is not None else (n_samples,),
).astype(dtype)
return data


Expand All @@ -111,13 +117,6 @@ def _generate_clustering_data(
return X_train, X_test


def _generate_weights(n_samples, dtype=np.float64, random_state=42):
# Generates weights
gen = np.random.default_rng(random_state)
weights = gen.uniform(low=-0.3, high=+0.7, size=(n_samples)).astype(dtype)
return weights


def _spmd_assert_allclose(spmd_result, batch_result, **kwargs):
"""Calls assert_allclose on spmd and batch results.
Expand Down

0 comments on commit 7e0400e

Please sign in to comment.