Commit

# Conflicts:
#	shap_select/select.py
EgorKraevTransferwise committed Oct 2, 2024
2 parents 3c0c9d7 + 7325a88 commit 925a196
Showing 6 changed files with 161 additions and 12 deletions.
34 changes: 34 additions & 0 deletions .github/workflows/test.yml
@@ -0,0 +1,34 @@
name: Run tests on merge

on:
  pull_request:
    branches:
      - main

jobs:
  test:
    runs-on: ubuntu-latest

    steps:
      # Checkout the repository code
      - name: Checkout code
        uses: actions/checkout@v3

      # Set up Python environment
      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: '3.10'  # Use the Python version your project needs

      # Install dependencies
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install -r requirements.txt
          pip install lightgbm xgboost catboost  # Install the libraries required for tests
          pip install pytest

      # Run tests using pytest
      - name: Run tests
        run: |
          pytest --maxfail=1 --disable-warnings
8 changes: 7 additions & 1 deletion .gitignore
@@ -23,4 +23,10 @@ hs_err_pid*
 build/
 out/
 .gradle/
-bin/
+bin/
+
+# Python cache files
+__pycache__/
+*.py[cod]
+*.pyo
+*.pyd
25 changes: 15 additions & 10 deletions shap_select/select.py
@@ -47,7 +47,7 @@ def create_shap_features(


 def binary_classifier_significance(
-    shap_features: pd.DataFrame, target: pd.Series
+    shap_features: pd.DataFrame, target: pd.Series, alpha: float
 ) -> pd.DataFrame:
     """
     Fits a logistic regression model using the features from `shap_features` to predict the binary `target`.

@@ -70,7 +70,7 @@ def binary_classifier_significance(

     # Fit the logistic regression model that will generate confidence intervals
     logit_model = sm.Logit(target, shap_features_with_constant)
-    result = logit_model.fit_regularized(disp=False, alpha=1e-6)
+    result = logit_model.fit_regularized(disp=False, alpha=alpha)

     # Extract the results
     summary_frame = result.summary2().tables[1]
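
For context, a minimal sketch of what the new `alpha` argument controls: in statsmodels, `Logit.fit_regularized` applies an L1 penalty whose strength grows with `alpha`, so the previously hard-coded `1e-6` was effectively an unregularized fit. The data and variable names below are illustrative, not the repository's.

    import numpy as np
    import pandas as pd
    import statsmodels.api as sm

    # Illustrative stand-in for a SHAP-values frame: 100 rows, 3 feature columns
    rng = np.random.default_rng(0)
    shap_like = pd.DataFrame(rng.normal(size=(100, 3)), columns=["x0", "x1", "x2"])
    target = (shap_like["x0"] + 0.1 * rng.normal(size=100) > 0).astype(int)

    features_with_constant = sm.add_constant(shap_like)
    # A tiny alpha is close to a plain logistic fit; a larger alpha shrinks
    # coefficients toward zero.
    weak = sm.Logit(target, features_with_constant).fit_regularized(disp=False, alpha=1e-6)
    strong = sm.Logit(target, features_with_constant).fit_regularized(disp=False, alpha=10.0)
    print(weak.params)
    print(strong.params)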
@@ -92,6 +92,7 @@ def binary_classifier_significance(
 def multi_classifier_significance(
     shap_features: Dict[Any, pd.DataFrame],
     target: pd.Series,
+    alpha: float,
     return_individual_significances: bool = False,
 ) -> (pd.DataFrame, list):
     """

@@ -111,7 +112,7 @@

     # Iterate through each class and perform binary classification (one-vs-all)
     for cls, feature_df in shap_features.items():
         binary_target = (target == cls).astype(int)
-        significance_df = binary_classifier_significance(feature_df, binary_target)
+        significance_df = binary_classifier_significance(feature_df, binary_target, alpha)
         significance_dfs.append(significance_df)

     # Combine results into a single DataFrame with the max significance value for each feature
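
The combination step itself is collapsed in this view. As a hypothetical sketch only (column names follow the tests further down; the repository's actual code may differ), taking the per-feature maximum p-value across the one-vs-all frames could look like:

    import pandas as pd

    # Hypothetical: keep, for each feature, the row with the largest (i.e.
    # least significant) p-value across the collected one-vs-all frames.
    combined = pd.concat(significance_dfs, ignore_index=True)
    max_significance_df = (
        combined.sort_values("stat.significance", ascending=False)
        .groupby("feature name", as_index=False)
        .first()
    )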
@@ -139,7 +140,7 @@ def multi_classifier_significance(


 def regression_significance(
-    shap_features: pd.DataFrame, target: pd.Series
+    shap_features: pd.DataFrame, target: pd.Series, alpha: float
 ) -> pd.DataFrame:
     """
     Fits a linear regression model using the features from `shap_features` to predict the continuous `target`.

@@ -158,7 +159,7 @@

     """
     # Fit the linear regression model that will generate confidence intervals
     ols_model = sm.OLS(target, shap_features)
-    result = ols_model.fit_regularized(alpha=1e-6, refit=True)
+    result = ols_model.fit_regularized(alpha=alpha, refit=True)

     # Extract the results
     summary_frame = result.summary2().tables[1]
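
The regression path is analogous; a minimal sketch with illustrative data (in statsmodels, `refit=True` refits an unpenalized OLS on the coefficients that survive the elastic-net fit, which is what makes the usual inference table available downstream):

    import numpy as np
    import pandas as pd
    import statsmodels.api as sm

    # Illustrative stand-in for SHAP value columns and a continuous target
    rng = np.random.default_rng(0)
    shap_like = pd.DataFrame(rng.normal(size=(100, 3)), columns=["x0", "x1", "x2"])
    target = 2.0 * shap_like["x0"] + rng.normal(scale=0.1, size=100)

    # alpha sets the elastic-net penalty strength; 1e-6 is near-unregularized.
    result = sm.OLS(target, shap_like).fit_regularized(alpha=1e-6, refit=True)
    print(result.params)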
@@ -186,6 +187,7 @@ def shap_features_to_significance(
     shap_features: pd.DataFrame | List[pd.DataFrame],
     target: pd.Series,
     task: str,
+    alpha: float,
 ) -> pd.DataFrame:
     """
     Determines the task (regression, binary, or multi-class classification) based on the target and calls the appropriate
@@ -205,11 +207,11 @@

     # Call the appropriate function based on the task
     if task == "regression":
-        result_df = regression_significance(shap_features, target)
+        result_df = regression_significance(shap_features, target, alpha)
     elif task == "binary":
-        result_df = binary_classifier_significance(shap_features, target)
+        result_df = binary_classifier_significance(shap_features, target, alpha)
     elif task == "multiclass":
-        result_df = multi_classifier_significance(shap_features, target)
+        result_df = multi_classifier_significance(shap_features, target, alpha)
     else:
         raise ValueError("`task` must be 'regression', 'binary', 'multiclass' or None.")
@@ -225,13 +227,14 @@ def iterative_shap_feature_reduction(
     shap_features: pd.DataFrame | List[pd.DataFrame],
     target: pd.Series,
     task: str,
+    alpha: float=1e-6,
 ) -> pd.DataFrame:
     collected_rows = []  # List to store the rows we collect during each iteration

     features_left = True
     while features_left:
         # Call the original shap_features_to_significance function
-        significance_df = shap_features_to_significance(shap_features, target, task)
+        significance_df = shap_features_to_significance(shap_features, target, task, alpha)

         # Find the feature with the lowest t-value
         min_t_value_row = significance_df.loc[significance_df["t-value"].idxmin()]
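
The remainder of the loop is collapsed here; schematically, each pass records the weakest feature and refits on the rest until nothing is left. A hypothetical sketch of that shape (helper names are illustrative, not the repository's):

    import pandas as pd

    def backward_ablation(score_features, feature_frame: pd.DataFrame) -> pd.DataFrame:
        # score_features: callable returning a significance frame with
        # "feature name" and "t-value" columns, as in the diff above.
        collected_rows = []
        while feature_frame.shape[1] > 0:
            significance_df = score_features(feature_frame)
            weakest = significance_df.loc[significance_df["t-value"].idxmin()]
            collected_rows.append(weakest)
            feature_frame = feature_frame.drop(columns=[weakest["feature name"]])
        return pd.DataFrame(collected_rows)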
@@ -268,6 +271,7 @@ def shap_select(
     task: str | None = None,
     threshold: float = 0.05,
     return_extended_data: bool = False,
+    alpha: float = 1e-6,
 ) -> pd.DataFrame | Tuple[pd.DataFrame, pd.DataFrame]:
     """
     Select features based on their SHAP values and statistical significance.
@@ -280,6 +284,7 @@
     - task (str | None): The task type ('regression', 'binary', or 'multiclass'). If None, it is inferred automatically.
     - threshold (float): Significance threshold to select features. Default is 0.05.
     - return_extended_data (bool): Whether to also return the shapley values dataframe(s) and some extra columns
+    - alpha (float): Controls the regularization strength for the regression

     Returns:
     - pd.DataFrame: A DataFrame containing the feature names, statistical significance, and a 'Selected' column
@@ -310,7 +315,7 @@
     shap_features = create_shap_features(tree_model, validation_df[feature_names])

     # Compute statistical significance of each feature, recursively ablating
-    significance_df = iterative_shap_feature_reduction(shap_features, target, task)
+    significance_df = iterative_shap_feature_reduction(shap_features, target, task, alpha)

     # Add 'Selected' column based on the threshold
     significance_df["selected"] = (
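
End to end, the new knob is used as in this minimal sketch; the dataset and model are illustrative, and the positional argument order is assumed from the context above:

    import lightgbm as lgb
    import numpy as np
    import pandas as pd
    from shap_select.select import shap_select

    rng = np.random.default_rng(42)
    df = pd.DataFrame(rng.normal(size=(500, 5)), columns=[f"x{i}" for i in range(5)])
    target = 2.0 * df["x0"] + rng.normal(size=500)

    model = lgb.LGBMRegressor().fit(df, target)

    # alpha propagates down to the statsmodels fits; the default 1e-6 keeps
    # the previous, effectively unregularized behaviour.
    selected_df = shap_select(model, df, target, task="regression", alpha=1e-4)
    print(selected_df)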
2 changes: 1 addition & 1 deletion tests/test_regression.py
@@ -256,4 +256,4 @@ def test_selected_column_values(model_type, data_fixture, task_type, request):
     ]
     assert (
         other_features_rows["selected"] > 0
-    ).all(), "The Selected column must have positive values for features other than x7, x8, x9"
\ No newline at end of file
+    ).all(), "The Selected column must have positive values for features other than x7, x8, x9"
49 changes: 49 additions & 0 deletions tests/test_shap_feature_generation.py
@@ -0,0 +1,49 @@
import pytest
import pandas as pd
import numpy as np
from shap_select.select import create_shap_features
import lightgbm as lgb


@pytest.fixture
def sample_data_binary():
    """Generate sample data for binary classification."""
    np.random.seed(42)
    X = pd.DataFrame(np.random.normal(size=(100, 5)), columns=[f"x{i}" for i in range(5)])
    y = (X["x0"] > 0).astype(int)
    return X, y


@pytest.fixture
def sample_data_multiclass():
    """Generate sample data for multiclass classification."""
    np.random.seed(42)
    X = pd.DataFrame(np.random.normal(size=(100, 5)), columns=[f"x{i}" for i in range(5)])
    y = np.random.choice([0, 1, 2], size=100)
    return X, y


def test_shap_feature_generation_binary(sample_data_binary):
    """Test SHAP feature generation for binary classification."""
    X, y = sample_data_binary

    model = lgb.LGBMClassifier()
    model.fit(X, y)

    shap_df = create_shap_features(model, X)
    assert isinstance(shap_df, pd.DataFrame), "SHAP output should be a DataFrame"
    assert shap_df.shape == X.shape, "SHAP output shape should match input data"
    assert shap_df.isnull().sum().sum() == 0, "No missing values expected in SHAP output"


def test_shap_feature_generation_multiclass(sample_data_multiclass):
    """Test SHAP feature generation for multiclass classification."""
    X, y = sample_data_multiclass

    model = lgb.LGBMClassifier(objective="multiclass", num_class=3)
    model.fit(X, y)

    shap_df = create_shap_features(model, X, classes=[0, 1, 2])
    assert isinstance(shap_df, dict), "SHAP output should be a dictionary for multiclass"
    assert all(isinstance(v, pd.DataFrame) for v in shap_df.values()), "Each class should have a DataFrame"
    assert shap_df[0].shape == X.shape, "SHAP output shape should match input data for each class"
55 changes: 55 additions & 0 deletions tests/test_significance_calculation.py
@@ -0,0 +1,55 @@
import pytest
import pandas as pd
import numpy as np
from shap_select.select import binary_classifier_significance, regression_significance
import statsmodels.api as sm


@pytest.fixture
def shap_features_binary():
    """Generate sample SHAP values for binary classification."""
    np.random.seed(42)
    return pd.DataFrame(np.random.normal(size=(100, 5)), columns=[f"x{i}" for i in range(5)])


@pytest.fixture
def binary_target():
    """Generate binary target."""
    np.random.seed(42)
    return pd.Series(np.random.choice([0, 1], size=100))


def test_binary_classifier_significance(shap_features_binary, binary_target):
    """Test significance calculation for binary classification."""
    result_df = binary_classifier_significance(shap_features_binary, binary_target, alpha=1e-4)

    assert "feature name" in result_df.columns, "Result should contain feature names"
    assert "coefficient" in result_df.columns, "Result should contain coefficients"
    assert "stat.significance" in result_df.columns, "Result should contain statistical significance"
    assert result_df.shape[0] == shap_features_binary.shape[1], "Each feature should have a row in the output"
    assert (result_df["stat.significance"] > 0).all(), "All p-values should be non-negative"


@pytest.fixture
def shap_features_regression():
    """Generate sample SHAP values for regression."""
    np.random.seed(42)
    return pd.DataFrame(np.random.normal(size=(100, 5)), columns=[f"x{i}" for i in range(5)])


@pytest.fixture
def regression_target():
    """Generate regression target."""
    np.random.seed(42)
    return pd.Series(np.random.normal(size=100))


def test_regression_significance(shap_features_regression, regression_target):
    """Test significance calculation for regression."""
    result_df = regression_significance(shap_features_regression, regression_target, alpha=1e-6)

    assert "feature name" in result_df.columns, "Result should contain feature names"
    assert "coefficient" in result_df.columns, "Result should contain coefficients"
    assert "stat.significance" in result_df.columns, "Result should contain statistical significance"
    assert result_df.shape[0] == shap_features_regression.shape[1], "Each feature should have a row in the output"
    assert (result_df["stat.significance"] > 0).all(), "All p-values should be non-negative"
