diff --git a/src/estimagic/inference/bootstrap_samples.py b/src/estimagic/inference/bootstrap_samples.py index c1fdfa15e..bb631908a 100644 --- a/src/estimagic/inference/bootstrap_samples.py +++ b/src/estimagic/inference/bootstrap_samples.py @@ -26,52 +26,46 @@ def get_bootstrap_indices( """ n_obs = len(data) + probs = _get_probs_for_bootstrap_indices(data, weight_by, cluster_by) - if weight_by is None: + if cluster_by is None: + bootstrap_indices = list( + rng.choice(n_obs, size=(n_draws, n_obs), replace=True, p=probs) + ) + else: + clusters = data[cluster_by].unique() + drawn_clusters = rng.choice( + clusters, size=(n_draws, len(clusters)), replace=True, p=probs + ) - if cluster_by is None: - bootstrap_indices = list(rng.integers(0, n_obs, size=(n_draws, n_obs))) - else: - clusters = data[cluster_by].unique() - drawn_clusters = rng.choice( - clusters, size=(n_draws, len(clusters)), replace=True - ) + bootstrap_indices = _convert_cluster_ids_to_indices( + data[cluster_by], drawn_clusters + ) + + return bootstrap_indices - bootstrap_indices = _convert_cluster_ids_to_indices( - data[cluster_by], drawn_clusters - ) - else: +def _get_probs_for_bootstrap_indices(data, weight_by, cluster_by): + """Calculate probabilities for drawing bootstrap indices. + Args: + data (pandas.DataFrame): original dataset. + weight_by (str): column name of the variable with weights. + cluster_by (str): column name of the variable to cluster by. + + Returns: + list: numpy array with probabilities. + + """ + if weight_by is None: + probs = None + else: if cluster_by is None: probs = data[weight_by] / data[weight_by].sum() - bootstrap_indices = list( - rng.choice( - n_obs, - size=(n_draws, n_obs), - replace=True, - p=probs, - ) - ) else: - clusters_and_weights = ( - data.groupby(cluster_by)[weight_by].sum().reset_index() - ) - clusters = clusters_and_weights[cluster_by] - weights = clusters_and_weights[weight_by] - probs = weights / weights.sum() - drawn_clusters = rng.choice( - clusters, - size=(n_draws, len(clusters)), - replace=True, - p=probs, - ) - - bootstrap_indices = _convert_cluster_ids_to_indices( - data[cluster_by], drawn_clusters - ) - - return bootstrap_indices + cluster_weights = data.groupby(cluster_by, sort=False)[weight_by].sum() + probs = cluster_weights / cluster_weights.sum() + return probs def _convert_cluster_ids_to_indices(cluster_col, drawn_clusters): diff --git a/tests/inference/test_bootstrap_samples.py b/tests/inference/test_bootstrap_samples.py index 69c09221d..7b9283f6a 100644 --- a/tests/inference/test_bootstrap_samples.py +++ b/tests/inference/test_bootstrap_samples.py @@ -47,6 +47,23 @@ def test_get_bootstrap_indices_randomization_works_with_weights_and_clustering(d assert set(res[0]) != set(res[1]) +def test_get_bootstrap_indices_randomization_works_with_and_without_weights(data): + rng1 = get_rng(seed=12345) + rng2 = get_rng(seed=12345) + res1 = get_bootstrap_indices(data, n_draws=1, rng=rng1) + res2 = get_bootstrap_indices(data, weight_by="weights", n_draws=1, rng=rng2) + assert not np.array_equal(res1, res2) + + +def test_get_boostrap_indices_randomization_works_with_extreme_case(data): + rng = get_rng(seed=12345) + weights = np.zeros(900) + weights[0] = 1.0 + data["weights"] = weights + res = get_bootstrap_indices(data, weight_by="weights", n_draws=1, rng=rng) + assert len(np.unique(res)) == 1 + + def test_clustering_leaves_households_intact(data): rng = get_rng(seed=12345) indices = get_bootstrap_indices(data, cluster_by="hh", n_draws=1, rng=rng)[0]