Skip to content

Commit

Permalink
Support calculate feature importance on clf pipelines (#3883)
Browse files Browse the repository at this point in the history
  • Loading branch information
suhaibmujahid authored Dec 5, 2023
1 parent 0b022b3 commit 51d15d9
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 12 deletions.
38 changes: 31 additions & 7 deletions bugbug/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics import precision_recall_fscore_support
from sklearn.model_selection import cross_validate, train_test_split
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import LabelEncoder
from tabulate import tabulate
from xgboost import XGBModel
Expand All @@ -35,6 +36,27 @@
logger = logging.getLogger(__name__)


def get_transformer_pipeline(pipeline: Pipeline) -> Pipeline:
"""Create a pipeline that contains only the transformers.
This will exclude any steps that do not have a transform method, such as a
sampler or estimator.
Args:
pipeline: the pipeline to extract the transformers from.
Returns:
a pipeline that contains only the transformers.
"""
return Pipeline(
[
(name, transformer)
for name, transformer in pipeline.steps
if hasattr(transformer, "transform")
]
)


def classification_report_imbalanced_values(
y_true, y_pred, labels, target_names=None, sample_weight=None, digits=2, alpha=0.1
):
Expand Down Expand Up @@ -399,8 +421,9 @@ def train(self, importance_cutoff=0.15, limit=None):

feature_names = self.get_human_readable_feature_names()
if self.calculate_importance and len(feature_names):
explainer = shap.TreeExplainer(self.clf)
shap_values = explainer.shap_values(X_train)
explainer = shap.TreeExplainer(self.clf.named_steps["estimator"])
_X_train = get_transformer_pipeline(self.clf).transform(X_train)
shap_values = explainer.shap_values(_X_train)

# In the binary case, sometimes shap returns a single shap values matrix.
if is_binary and not isinstance(shap_values, list):
Expand All @@ -413,7 +436,7 @@ def train(self, importance_cutoff=0.15, limit=None):

shap.summary_plot(
summary_plot_value,
to_array(X_train),
to_array(_X_train),
feature_names=feature_names,
class_names=self.class_names,
plot_type=summary_plot_type,
Expand Down Expand Up @@ -628,15 +651,16 @@ def classify(
pred_class = self.le.inverse_transform([pred_class_index])[0]

if background_dataset is None:
explainer = shap.TreeExplainer(self.clf)
explainer = shap.TreeExplainer(self.clf.named_steps["estimator"])
else:
explainer = shap.TreeExplainer(
self.clf,
self.clf.named_steps["estimator"],
to_array(background_dataset(pred_class)),
feature_perturbation="interventional",
)

shap_values = explainer.shap_values(to_array(X))
_X = get_transformer_pipeline(self.clf).transform(X)
shap_values = explainer.shap_values(to_array(_X))

# In the binary case, sometimes shap returns a single shap values matrix.
if len(classes[0]) == 2 and not isinstance(shap_values, list):
Expand All @@ -645,7 +669,7 @@ def classify(
important_features = self.get_important_features(
importance_cutoff, shap_values
)
important_features["values"] = X
important_features["values"] = _X

top_indexes = [
int(index)
Expand Down
13 changes: 8 additions & 5 deletions scripts/commit_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from scipy.stats import spearmanr

from bugbug import db, repository, test_scheduling
from bugbug.model import Model
from bugbug.model import Model, get_transformer_pipeline
from bugbug.models.regressor import RegressorModel
from bugbug.models.testfailure import TestFailureModel
from bugbug.utils import (
Expand Down Expand Up @@ -402,7 +402,10 @@ def load_user(phid):
)

def generate_feature_importance_data(self, probs, importance):
X_shap_values = shap.TreeExplainer(self.model.clf).shap_values(self.X)
_X = get_transformer_pipeline(self.clf).transform(self.X)
X_shap_values = shap.TreeExplainer(
self.clf.named_steps["estimator"]
).shap_values(_X)

pred_class = self.model.le.inverse_transform([probs[0].argmax()])[0]

Expand All @@ -414,8 +417,8 @@ def generate_feature_importance_data(self, probs, importance):
value = importance["importances"]["values"][0, int(feature_index)]

shap.summary_plot(
X_shap_values[:, int(feature_index)].reshape(self.X.shape[0], 1),
self.X[:, int(feature_index)].reshape(self.X.shape[0], 1),
X_shap_values[:, int(feature_index)].reshape(_X.shape[0], 1),
_X[:, int(feature_index)].reshape(_X.shape[0], 1),
feature_names=[""],
plot_type="layered_violin",
show=False,
Expand All @@ -427,7 +430,7 @@ def generate_feature_importance_data(self, probs, importance):
img.seek(0)
base64_img = base64.b64encode(img.read()).decode("ascii")

X = self.X[:, int(feature_index)]
X = _X[:, int(feature_index)]
y = self.y[X != 0]
X = X[X != 0]
spearman = spearmanr(X, y)
Expand Down

0 comments on commit 51d15d9

Please sign in to comment.