diff --git a/lime/lime_base.py b/lime/lime_base.py index 63e62a0c..6b373c45 100644 --- a/lime/lime_base.py +++ b/lime/lime_base.py @@ -68,13 +68,30 @@ def forward_selection(self, data, labels, weights, num_features): used_features.append(best) return np.array(used_features) - def feature_selection(self, data, labels, weights, num_features, method): + def feature_selection(self, + datas, + labels, + weights, + num_features, + method, + feature_names=None, + use_feature_names=None): """Selects features for the model. see explain_instance_with_data to understand the parameters.""" + feature_index = np.array(range(datas.shape[1])) + if use_feature_names is not None: + use_feature_index = [] + for f in use_feature_names: + use_feature_index.append(feature_names.index(f)) + data = datas[:, use_feature_index] + feature_index = feature_index[use_feature_index] + else: + data = datas + if method == 'none': - return np.array(range(data.shape[1])) + return feature_index[list(range(data.shape[1]))] elif method == 'forward_selection': - return self.forward_selection(data, labels, weights, num_features) + return feature_index[self.forward_selection(data, labels, weights, num_features)] elif method == 'highest_weights': clf = Ridge(alpha=0, fit_intercept=True, random_state=self.random_state) @@ -105,14 +122,14 @@ def feature_selection(self, data, labels, weights, num_features, method): else: nnz_indexes = argsort_data[sdata - num_features:sdata][::-1] indices = weighted_data.indices[nnz_indexes] - return indices + return feature_index[list(indices)] else: weighted_data = coef * data[0] feature_weights = sorted( zip(range(data.shape[1]), weighted_data), key=lambda x: np.abs(x[1]), reverse=True) - return np.array([x[0] for x in feature_weights[:num_features]]) + return feature_index[list([x[0] for x in feature_weights[:num_features]])] elif method == 'lasso_path': weighted_data = ((data - np.average(data, axis=0, weights=weights)) * np.sqrt(weights[:, np.newaxis])) @@ -126,14 
+143,15 @@ def feature_selection(self, data, labels, weights, num_features, method): if len(nonzero) <= num_features: break used_features = nonzero - return used_features + return feature_index[list(used_features)] elif method == 'auto': if num_features <= 6: n_method = 'forward_selection' else: n_method = 'highest_weights' - return self.feature_selection(data, labels, weights, - num_features, n_method) + return self.feature_selection(datas, labels, weights, + num_features, n_method, + feature_names, use_feature_names) def explain_instance_with_data(self, neighborhood_data, @@ -142,6 +160,8 @@ def explain_instance_with_data(self, label, num_features, feature_selection='auto', + feature_names=None, + use_feature_names=None, model_regressor=None): """Takes perturbed data, labels and distances, returns explanation. @@ -168,6 +188,7 @@ def explain_instance_with_data(self, Defaults to Ridge regression if None. Must have model_regressor.coef_ and 'sample_weight' as a parameter to model_regressor.fit() + use_feature_names: if provided, restrict feature selection to these named features. Returns: (intercept, exp, score, local_pred): @@ -185,7 +206,9 @@ def explain_instance_with_data(self, labels_column, weights, num_features, - feature_selection) + feature_selection, + feature_names, + use_feature_names) if model_regressor is None: model_regressor = Ridge(alpha=1, fit_intercept=True, random_state=self.random_state) diff --git a/lime/lime_tabular.py b/lime/lime_tabular.py index b1f6f94b..75be9b68 100644 --- a/lime/lime_tabular.py +++ b/lime/lime_tabular.py @@ -297,7 +297,8 @@ def explain_instance(self, num_features=10, num_samples=5000, distance_metric='euclidean', - model_regressor=None): + model_regressor=None, + use_feature_names=None): """Generates explanations for a prediction.
First, we generate neighborhood data by randomly perturbing features @@ -451,7 +452,10 @@ def explain_instance(self, label, num_features, model_regressor=model_regressor, - feature_selection=self.feature_selection) + feature_selection=self.feature_selection, + feature_names=self.feature_names, + use_feature_names=use_feature_names + ) if self.mode == "regression": ret_exp.intercept[1] = ret_exp.intercept[0]