Compute SHAP values on the transformed features of a pipeline classifier.

The classifier is accessed as a scikit-learn Pipeline: the tree explainer is
built from the pipeline's "estimator" step, and the inputs are first passed
through the steps that expose a `transform` method (see
`get_transformer_pipeline`).

In scripts/commit_classifier.py the pipeline is reached through
`self.model.clf`, matching the line being replaced and the neighbouring
`self.model.le` usage.

diff --git a/bugbug/model.py b/bugbug/model.py
index 08e474b2f9..ed73b955e4 100644
--- a/bugbug/model.py
+++ b/bugbug/model.py
@@ -22,6 +22,7 @@
 from sklearn.feature_extraction.text import TfidfVectorizer
 from sklearn.metrics import precision_recall_fscore_support
 from sklearn.model_selection import cross_validate, train_test_split
+from sklearn.pipeline import Pipeline
 from sklearn.preprocessing import LabelEncoder
 from tabulate import tabulate
 from xgboost import XGBModel
@@ -35,6 +36,27 @@
 logger = logging.getLogger(__name__)
 
 
+def get_transformer_pipeline(pipeline: Pipeline) -> Pipeline:
+    """Create a pipeline that contains only the transformers.
+
+    This will exclude any steps that do not have a transform method, such as a
+    sampler or estimator.
+
+    Args:
+        pipeline: the pipeline to extract the transformers from.
+
+    Returns:
+        a pipeline that contains only the transformers.
+    """
+    return Pipeline(
+        [
+            (name, transformer)
+            for name, transformer in pipeline.steps
+            if hasattr(transformer, "transform")
+        ]
+    )
+
+
 def classification_report_imbalanced_values(
     y_true, y_pred, labels, target_names=None, sample_weight=None, digits=2, alpha=0.1
 ):
@@ -399,8 +421,9 @@ def train(self, importance_cutoff=0.15, limit=None):
         feature_names = self.get_human_readable_feature_names()
 
         if self.calculate_importance and len(feature_names):
-            explainer = shap.TreeExplainer(self.clf)
-            shap_values = explainer.shap_values(X_train)
+            explainer = shap.TreeExplainer(self.clf.named_steps["estimator"])
+            _X_train = get_transformer_pipeline(self.clf).transform(X_train)
+            shap_values = explainer.shap_values(_X_train)
 
             # In the binary case, sometimes shap returns a single shap values matrix.
             if is_binary and not isinstance(shap_values, list):
@@ -413,7 +436,7 @@ def train(self, importance_cutoff=0.15, limit=None):
 
             shap.summary_plot(
                 summary_plot_value,
-                to_array(X_train),
+                to_array(_X_train),
                 feature_names=feature_names,
                 class_names=self.class_names,
                 plot_type=summary_plot_type,
@@ -628,15 +651,16 @@ def classify(
             pred_class = self.le.inverse_transform([pred_class_index])[0]
 
             if background_dataset is None:
-                explainer = shap.TreeExplainer(self.clf)
+                explainer = shap.TreeExplainer(self.clf.named_steps["estimator"])
             else:
                 explainer = shap.TreeExplainer(
-                    self.clf,
+                    self.clf.named_steps["estimator"],
                     to_array(background_dataset(pred_class)),
                     feature_perturbation="interventional",
                 )
 
-            shap_values = explainer.shap_values(to_array(X))
+            _X = get_transformer_pipeline(self.clf).transform(X)
+            shap_values = explainer.shap_values(to_array(_X))
 
             # In the binary case, sometimes shap returns a single shap values matrix.
             if len(classes[0]) == 2 and not isinstance(shap_values, list):
@@ -645,7 +669,7 @@ def classify(
             important_features = self.get_important_features(
                 importance_cutoff, shap_values
             )
-            important_features["values"] = X
+            important_features["values"] = _X
 
             top_indexes = [
                 int(index)
diff --git a/scripts/commit_classifier.py b/scripts/commit_classifier.py
index 7cc130c044..fc8289208d 100644
--- a/scripts/commit_classifier.py
+++ b/scripts/commit_classifier.py
@@ -26,7 +26,7 @@
 from scipy.stats import spearmanr
 
 from bugbug import db, repository, test_scheduling
-from bugbug.model import Model
+from bugbug.model import Model, get_transformer_pipeline
 from bugbug.models.regressor import RegressorModel
 from bugbug.models.testfailure import TestFailureModel
 from bugbug.utils import (
@@ -402,7 +402,10 @@ def load_user(phid):
         )
 
     def generate_feature_importance_data(self, probs, importance):
-        X_shap_values = shap.TreeExplainer(self.model.clf).shap_values(self.X)
+        _X = get_transformer_pipeline(self.model.clf).transform(self.X)
+        X_shap_values = shap.TreeExplainer(
+            self.model.clf.named_steps["estimator"]
+        ).shap_values(_X)
 
         pred_class = self.model.le.inverse_transform([probs[0].argmax()])[0]
 
@@ -414,8 +417,8 @@ def generate_feature_importance_data(self, probs, importance):
             value = importance["importances"]["values"][0, int(feature_index)]
 
             shap.summary_plot(
-                X_shap_values[:, int(feature_index)].reshape(self.X.shape[0], 1),
-                self.X[:, int(feature_index)].reshape(self.X.shape[0], 1),
+                X_shap_values[:, int(feature_index)].reshape(_X.shape[0], 1),
+                _X[:, int(feature_index)].reshape(_X.shape[0], 1),
                 feature_names=[""],
                 plot_type="layered_violin",
                 show=False,
@@ -427,7 +430,7 @@ def generate_feature_importance_data(self, probs, importance):
             img.seek(0)
             base64_img = base64.b64encode(img.read()).decode("ascii")
 
-            X = self.X[:, int(feature_index)]
+            X = _X[:, int(feature_index)]
             y = self.y[X != 0]
             X = X[X != 0]
             spearman = spearmanr(X, y)