Fit the extraction pipeline based on the training data only #3819
Conversation
X = np.fromiter(X_gen(), dtype=object)
y = np.array(y)
self.le.fit(y)

if limit:
    X = X[:limit]
    y = y[:limit]
logger.info("Number of data points: %d", len(X))

# Split dataset in training and test.
X_train, X_test, y_train, y_test = self.train_test_split(X, y)
With this, the generator does not provide much memory optimization benefit. Before, we only kept the extracted features in memory instead of the original items as dictionaries.
One benefit of the new behaviour is the ability to customize the data splitting if required (e.g., selecting recent data points for the testing set), as sketched below.
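For illustration only, a model that wants time-based evaluation could override `train_test_split` along these lines (the mixin name and the assumption that items are ordered from oldest to newest are mine, not part of this PR):

```python
class RecencySplitMixin:
    """Hypothetical sketch: reserve the most recent items for the test set."""

    test_size = 0.1  # fraction of the newest data points used for testing

    def train_test_split(self, X, y):
        # Assumes X and y are ordered from oldest to newest.
        split_index = int(len(X) * (1 - self.test_size))
        return X[:split_index], X[split_index:], y[:split_index], y[split_index:]
```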
What is the total impact on memory usage? Could you test with one of the models?
Note: if the increase is not too high and we decide to collect items from the generator in an array, we might get rid of the `split_tuple_generator` function and simplify things.
After this PR, the memory usage increased by 67% (based on the spambug model). With a small optimization, preparing the data in a separate function so unnecessary data is not kept in memory, the increase becomes 23% instead of 67% (it was part of this PR, but I dropped it to handle it in a separate PR for simplicity). A rough sketch of that helper is below, before the profiler output.
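Method-level sketch of that separate preparation function, mirroring the existing `train()` code shown in the "After the optimization" profiler output; treat it as an illustration, not the exact implementation:

```python
def __prepare_data(self, limit=None):
    # Keeping this in its own method means the intermediate `classes`, `X`
    # and `y` objects become unreachable as soon as it returns, so only the
    # train/test splits stay alive for the rest of train().
    classes, self.class_names = self.get_labels()
    self.class_names = sort_class_names(self.class_names)

    # Get items and labels, filtering out those for which we have no labels.
    X_gen, y = split_tuple_generator(lambda: self.items_gen(classes))

    X = np.fromiter(X_gen(), dtype=object)
    y = np.array(y)

    if limit:
        X = X[:limit]
        y = y[:limit]

    logger.info("Number of data points: %d", len(X))

    # Split dataset in training and test.
    return self.train_test_split(X, y)
```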
Before the PR (3296.1 MiB)
Line # Mem usage Increment Occurrences Line Contents
=============================================================
345 234.4 MiB 234.4 MiB 1 @profile
346 def train(self, importance_cutoff=0.15, limit=None):
347 276.7 MiB 42.2 MiB 1 classes, self.class_names = self.get_labels()
348 276.7 MiB 0.0 MiB 1 self.class_names = sort_class_names(self.class_names)
349
350 # Get items and labels, filtering out those for which we have no labels.
351 276.7 MiB 0.0 MiB 3 X_gen, y = split_tuple_generator(lambda: self.items_gen(classes))
352
353 # Extract features from the items.
354 794.0 MiB 517.4 MiB 1 X = self.extraction_pipeline.fit_transform(X_gen)
355
356 # Calculate labels.
357 794.0 MiB 0.0 MiB 1 y = np.array(y)
358 794.0 MiB 0.0 MiB 1 self.le.fit(y)
359
360 794.0 MiB 0.0 MiB 1 if limit:
361 X = X[:limit]
362 y = y[:limit]
363
364 794.1 MiB 0.0 MiB 1 logger.info(f"X: {X.shape}, y: {y.shape}")
365
366 794.1 MiB 0.0 MiB 1 is_multilabel = isinstance(y[0], np.ndarray)
367 794.1 MiB 0.0 MiB 1 is_binary = len(self.class_names) == 2
368
369 # Split dataset in training and test.
370 794.3 MiB 0.3 MiB 1 X_train, X_test, y_train, y_test = self.train_test_split(X, y)
371 794.3 MiB 0.0 MiB 1 if self.sampler is not None:
372 794.3 MiB 0.0 MiB 1 pipeline = make_pipeline(self.sampler, self.clf)
373 else:
374 pipeline = self.clf
375
376 794.3 MiB 0.0 MiB 1 tracking_metrics = {}
377
378 # Use k-fold cross validation to evaluate results.
379 794.3 MiB 0.0 MiB 1 if self.cross_validation_enabled:
380 794.3 MiB 0.0 MiB 1 scorings = ["accuracy"]
381 794.3 MiB 0.0 MiB 1 if len(self.class_names) == 2:
382 794.3 MiB 0.0 MiB 1 scorings += ["precision", "recall"]
383
384 2878.5 MiB 2084.1 MiB 2 scores = cross_validate(
385 794.4 MiB 0.0 MiB 1 pipeline, X_train, self.le.transform(y_train), scoring=scorings, cv=5
386 )
387
388 2878.5 MiB 0.0 MiB 1 logger.info("Cross Validation scores:")
389 2878.5 MiB 0.0 MiB 4 for scoring in scorings:
390 2878.5 MiB 0.0 MiB 3 score = scores[f"test_{scoring}"]
391 2878.5 MiB 0.0 MiB 3 tracking_metrics[f"test_{scoring}"] = {
392 2878.5 MiB 0.0 MiB 3 "mean": score.mean(),
393 2878.5 MiB 0.0 MiB 3 "std": score.std() * 2,
394 }
395 2878.5 MiB 0.0 MiB 6 logger.info(
396 2878.5 MiB 0.0 MiB 3 f"{scoring.capitalize()}: f{score.mean()} (+/- {score.std() * 2})"
397 )
398
399 2878.5 MiB 0.0 MiB 1 logger.info(f"X_train: {X_train.shape}, y_train: {y_train.shape}")
400
401 # Training on the resampled dataset if sampler is provided.
402 2878.5 MiB 0.0 MiB 1 if self.sampler is not None:
403 3692.4 MiB 813.9 MiB 1 X_train, y_train = self.sampler.fit_resample(X_train, y_train)
404
405 3692.4 MiB 0.0 MiB 1 logger.info(f"resampled X_train: {X_train.shape}, y_train: {y_train.shape}")
406
407 3692.4 MiB 0.0 MiB 1 logger.info(f"X_test: {X_test.shape}, y_test: {y_test.shape}")
408
409 3291.3 MiB -401.1 MiB 1 self.clf.fit(X_train, self.le.transform(y_train))
410
411 3291.3 MiB 0.0 MiB 1 logger.info("Model trained")
412
413 3291.3 MiB 0.0 MiB 1 feature_names = self.get_human_readable_feature_names()
414 3291.3 MiB 0.0 MiB 1 if self.calculate_importance and len(feature_names):
415 explainer = shap.TreeExplainer(self.clf)
416 shap_values = explainer.shap_values(X_train)
417
418 # In the binary case, sometimes shap returns a single shap values matrix.
419 if is_binary and not isinstance(shap_values, list):
420 shap_values = [-shap_values, shap_values]
421 summary_plot_value = shap_values[1]
422 summary_plot_type = "layered_violin"
423 else:
424 summary_plot_value = shap_values
425 summary_plot_type = None
426
427 shap.summary_plot(
428 summary_plot_value,
429 to_array(X_train),
430 feature_names=feature_names,
431 class_names=self.class_names,
432 plot_type=summary_plot_type,
433 show=False,
434 )
435
436 matplotlib.pyplot.savefig("feature_importance.png", bbox_inches="tight")
437 matplotlib.pyplot.xlabel("Impact on model output")
438 matplotlib.pyplot.clf()
439
440 important_features = self.get_important_features(
441 importance_cutoff, shap_values
442 )
443
444 self.print_feature_importances(important_features)
445
446 # Save the important features in the metric report too
447 feature_report = self.save_feature_importances(
448 important_features, feature_names
449 )
450
451 tracking_metrics["feature_report"] = feature_report
452
453 3291.3 MiB 0.0 MiB 1 logger.info("Training Set scores:")
454 3291.3 MiB 0.0 MiB 1 y_pred = self.clf.predict(X_train)
455 3291.5 MiB 0.3 MiB 1 y_pred = self.le.inverse_transform(y_pred)
456 3291.5 MiB 0.0 MiB 1 if not is_multilabel:
457 3291.6 MiB 0.0 MiB 2 print(
458 3291.6 MiB 0.1 MiB 2 classification_report_imbalanced(
459 3291.5 MiB 0.0 MiB 1 y_train, y_pred, labels=self.class_names
460 )
461 )
462
463 3291.6 MiB 0.0 MiB 1 logger.info("Test Set scores:")
464 # Evaluate results on the test set.
465 3291.6 MiB 0.0 MiB 1 y_pred = self.clf.predict(X_test)
466 3291.6 MiB 0.0 MiB 1 y_pred = self.le.inverse_transform(y_pred)
467
468 3291.6 MiB 0.0 MiB 1 if is_multilabel:
469 assert isinstance(
470 y_pred[0], np.ndarray
471 ), "The predictions should be multilabel"
472
473 3291.6 MiB 0.0 MiB 1 logger.info(f"No confidence threshold - {len(y_test)} classified")
474 3291.6 MiB 0.0 MiB 1 if is_multilabel:
475 confusion_matrix = metrics.multilabel_confusion_matrix(y_test, y_pred)
476 else:
477 3291.6 MiB 0.0 MiB 2 confusion_matrix = metrics.confusion_matrix(
478 3291.6 MiB 0.0 MiB 1 y_test, y_pred, labels=self.class_names
479 )
480
481 3291.6 MiB 0.0 MiB 2 print(
482 3291.6 MiB 0.0 MiB 2 classification_report_imbalanced(
483 3291.6 MiB 0.0 MiB 1 y_test, y_pred, labels=self.class_names
484 )
485 )
486 3291.6 MiB 0.0 MiB 2 report = classification_report_imbalanced_values(
487 3291.6 MiB 0.0 MiB 1 y_test, y_pred, labels=self.class_names
488 )
489
490 3291.6 MiB 0.0 MiB 1 tracking_metrics["report"] = report
491
492 3291.6 MiB 0.0 MiB 2 print_labeled_confusion_matrix(
493 3291.6 MiB 0.0 MiB 1 confusion_matrix, self.class_names, is_multilabel=is_multilabel
494 )
495
496 3291.6 MiB 0.0 MiB 1 tracking_metrics["confusion_matrix"] = confusion_matrix.tolist()
497
498 3291.6 MiB 0.0 MiB 1 confidence_thresholds = [0.6, 0.7, 0.8, 0.9]
499
500 3291.6 MiB 0.0 MiB 1 if is_binary:
501 3291.6 MiB 0.0 MiB 1 confidence_thresholds = [0.1, 0.2, 0.3, 0.4] + confidence_thresholds
502
503 # Evaluate results on the test set for some confidence thresholds.
504 3293.9 MiB 0.0 MiB 9 for confidence_threshold in confidence_thresholds:
505 3293.9 MiB 0.0 MiB 8 y_pred_probas = self.clf.predict_proba(X_test)
506 3293.9 MiB 0.0 MiB 8 confidence_class_names = self.class_names + ["__NOT_CLASSIFIED__"]
507
508 3293.9 MiB 0.0 MiB 8 y_pred_filter = []
509 3293.9 MiB 0.0 MiB 8 classified_indices = []
510 3293.9 MiB 0.0 MiB 40248 for i in range(0, len(y_test)):
511 3293.9 MiB 0.0 MiB 40240 if not is_binary:
512 argmax = np.argmax(y_pred_probas[i])
513 else:
514 3293.9 MiB 0.0 MiB 40240 argmax = 1 if y_pred_probas[i][1] > confidence_threshold else 0
515
516 3293.9 MiB 0.0 MiB 40240 if y_pred_probas[i][argmax] < confidence_threshold:
517 3293.9 MiB 0.0 MiB 262 if not is_multilabel:
518 3293.9 MiB 0.0 MiB 262 y_pred_filter.append("__NOT_CLASSIFIED__")
519 3293.9 MiB 0.0 MiB 262 continue
520
521 3293.9 MiB 0.0 MiB 39978 classified_indices.append(i)
522 3293.9 MiB 0.0 MiB 39978 if is_multilabel:
523 y_pred_filter.append(y_pred[i])
524 else:
525 3293.9 MiB 0.0 MiB 39978 y_pred_filter.append(argmax)
526
527 3293.9 MiB 0.0 MiB 8 if not is_multilabel:
528 3293.9 MiB 0.0 MiB 8 y_pred_filter = np.array(y_pred_filter)
529 3293.9 MiB 0.0 MiB 16 y_pred_filter[classified_indices] = self.le.inverse_transform(
530 3293.9 MiB 0.0 MiB 8 np.array(y_pred_filter[classified_indices], dtype=int)
531 )
532
533 3293.9 MiB 0.0 MiB 80242 classified_num = sum(1 for v in y_pred_filter if v != "__NOT_CLASSIFIED__")
534
535 3293.9 MiB 0.0 MiB 16 logger.info(
536 3293.9 MiB 0.0 MiB 8 f"\nConfidence threshold > {confidence_threshold} - {classified_num} classified"
537 )
538 3293.9 MiB 0.0 MiB 8 if is_multilabel:
539 confusion_matrix = metrics.multilabel_confusion_matrix(
540 y_test[classified_indices], np.asarray(y_pred_filter)
541 )
542 else:
543 3293.9 MiB 0.1 MiB 16 confusion_matrix = metrics.confusion_matrix(
544 3293.9 MiB 0.0 MiB 8 y_test.astype(str),
545 3293.9 MiB 0.0 MiB 8 y_pred_filter.astype(str),
546 3293.9 MiB 0.0 MiB 8 labels=confidence_class_names,
547 )
548 3293.9 MiB 0.0 MiB 16 print(
549 3293.9 MiB 2.2 MiB 16 classification_report_imbalanced(
550 3293.9 MiB 0.0 MiB 8 y_test.astype(str),
551 3293.9 MiB 0.0 MiB 8 y_pred_filter.astype(str),
552 3293.9 MiB 0.0 MiB 8 labels=confidence_class_names,
553 )
554 )
555 3293.9 MiB 0.0 MiB 16 print_labeled_confusion_matrix(
556 3293.9 MiB 0.0 MiB 8 confusion_matrix, confidence_class_names, is_multilabel=is_multilabel
557 )
558
559 3293.9 MiB 0.0 MiB 1 self.evaluation()
560
561 3293.9 MiB 0.0 MiB 1 if self.entire_dataset_training:
562 logger.info("Retraining on the entire dataset...")
563
564 if self.sampler is not None:
565 X_train, y_train = self.sampler.fit_resample(X, y)
566 else:
567 X_train = X
568 y_train = y
569
570 logger.info(f"X_train: {X_train.shape}, y_train: {y_train.shape}")
571
572 self.clf.fit(X_train, self.le.transform(y_train))
573
574 3293.9 MiB 0.0 MiB 1 model_directory = self.__class__.__name__.lower()
575 3293.9 MiB 0.0 MiB 1 makedirs(model_directory, exist_ok=True)
576
577 3293.9 MiB 0.0 MiB 1 if issubclass(type(self.clf), XGBModel):
578 3293.9 MiB 0.0 MiB 1 xgboost_model_path = path.join(model_directory, "xgboost.ubj")
579 3294.0 MiB 0.1 MiB 1 self.clf.save_model(xgboost_model_path)
580
581 # Since we save the classifier separately, we need to clear the clf
582 # attribute to prevent it from being pickled with the model object.
583 3294.0 MiB 0.0 MiB 1 self.clf = self.clf.__class__(**self.hyperparameter)
584
585 3294.0 MiB 0.0 MiB 1 model_path = path.join(model_directory, "model.pkl")
586 3296.1 MiB 0.0 MiB 2 with open(model_path, "wb") as f:
587 3296.1 MiB 2.2 MiB 1 pickle.dump(self, f, protocol=pickle.HIGHEST_PROTOCOL)
588
589 3296.1 MiB 0.0 MiB 1 if self.store_dataset:
590 with open(f"{self.__class__.__name__.lower()}_data_X", "wb") as f:
591 pickle.dump(X, f, protocol=pickle.HIGHEST_PROTOCOL)
592
593 with open(f"{self.__class__.__name__.lower()}_data_y", "wb") as f:
594 pickle.dump(y, f, protocol=pickle.HIGHEST_PROTOCOL)
595
596 3296.1 MiB 0.0 MiB 1 return tracking_metrics
After the PR (5504.9 MiB)
Line # Mem usage Increment Occurrences Line Contents
=============================================================
345 234.4 MiB 234.4 MiB 1 @profile
346 def train(self, importance_cutoff=0.15, limit=None):
347 251.5 MiB 17.0 MiB 1 classes, self.class_names = self.get_labels()
348 251.5 MiB 0.0 MiB 1 self.class_names = sort_class_names(self.class_names)
349
350 # Get items and labels, filtering out those for which we have no labels.
351 251.5 MiB 0.0 MiB 3 X_gen, y = split_tuple_generator(lambda: self.items_gen(classes))
352
353 2357.0 MiB 2105.5 MiB 1 X = np.fromiter(X_gen(), dtype=object)
354 2357.0 MiB 0.0 MiB 1 y = np.array(y)
355 2357.0 MiB 0.0 MiB 1 if limit:
356 X = X[:limit]
357 y = y[:limit]
358 2357.0 MiB 0.0 MiB 1 logger.info("Number of data points: %d", len(X))
359
360 # Split dataset in training and test.
361 2357.1 MiB 0.2 MiB 1 X_train, X_test, y_train, y_test = self.train_test_split(X, y)
362
363 # Calculate labels.
364 2357.1 MiB 0.0 MiB 1 self.le.fit(y_train)
365
366 # Extract features from the items.
367 3063.6 MiB 706.5 MiB 3 X_train = self.extraction_pipeline.fit_transform(lambda: X_train)
368 3063.6 MiB -35.4 MiB 3 X_test = self.extraction_pipeline.transform(lambda: X_test)
369
370 3028.2 MiB -35.4 MiB 1 is_multilabel = isinstance(y_train[0], np.ndarray)
371 3028.2 MiB 0.0 MiB 1 is_binary = len(self.class_names) == 2
372
373 3028.2 MiB 0.0 MiB 1 if self.sampler is not None:
374 3028.2 MiB 0.0 MiB 1 pipeline = make_pipeline(self.sampler, self.clf)
375 else:
376 pipeline = self.clf
377
378 3028.2 MiB 0.0 MiB 1 tracking_metrics = {}
379
380 # Use k-fold cross validation to evaluate results.
381 3028.2 MiB 0.0 MiB 1 if self.cross_validation_enabled:
382 3028.2 MiB 0.0 MiB 1 scorings = ["accuracy"]
383 3028.2 MiB 0.0 MiB 1 if len(self.class_names) == 2:
384 3028.2 MiB 0.0 MiB 1 scorings += ["precision", "recall"]
385
386 4762.8 MiB 1734.3 MiB 2 scores = cross_validate(
387 3028.5 MiB 0.2 MiB 1 pipeline, X_train, self.le.transform(y_train), scoring=scorings, cv=5
388 )
389
390 4762.8 MiB 0.0 MiB 1 logger.info("Cross Validation scores:")
391 4762.8 MiB 0.0 MiB 4 for scoring in scorings:
392 4762.8 MiB 0.0 MiB 3 score = scores[f"test_{scoring}"]
393 4762.8 MiB 0.0 MiB 3 tracking_metrics[f"test_{scoring}"] = {
394 4762.8 MiB 0.0 MiB 3 "mean": score.mean(),
395 4762.8 MiB 0.0 MiB 3 "std": score.std() * 2,
396 }
397 4762.8 MiB 0.0 MiB 6 logger.info(
398 4762.8 MiB 0.0 MiB 3 f"{scoring.capitalize()}: f{score.mean()} (+/- {score.std() * 2})"
399 )
400
401 4762.8 MiB 0.0 MiB 1 logger.info(f"X_train: {X_train.shape}, y_train: {y_train.shape}")
402
403 # Training on the resampled dataset if sampler is provided.
404 4762.8 MiB 0.0 MiB 1 if self.sampler is not None:
405 5543.0 MiB 780.2 MiB 1 X_train, y_train = self.sampler.fit_resample(X_train, y_train)
406
407 5543.0 MiB 0.0 MiB 1 logger.info(f"resampled X_train: {X_train.shape}, y_train: {y_train.shape}")
408
409 5543.0 MiB 0.0 MiB 1 logger.info(f"X_test: {X_test.shape}, y_test: {y_test.shape}")
410
411 5483.6 MiB -59.4 MiB 1 self.clf.fit(X_train, self.le.transform(y_train))
412
413 5483.6 MiB 0.0 MiB 1 logger.info("Model trained")
414
415 5483.7 MiB 0.1 MiB 1 feature_names = self.get_human_readable_feature_names()
416 5483.7 MiB 0.0 MiB 1 if self.calculate_importance and len(feature_names):
417 explainer = shap.TreeExplainer(self.clf)
418 shap_values = explainer.shap_values(X_train)
419
420 # In the binary case, sometimes shap returns a single shap values matrix.
421 if is_binary and not isinstance(shap_values, list):
422 shap_values = [-shap_values, shap_values]
423 summary_plot_value = shap_values[1]
424 summary_plot_type = "layered_violin"
425 else:
426 summary_plot_value = shap_values
427 summary_plot_type = None
428
429 shap.summary_plot(
430 summary_plot_value,
431 to_array(X_train),
432 feature_names=feature_names,
433 class_names=self.class_names,
434 plot_type=summary_plot_type,
435 show=False,
436 )
437
438 matplotlib.pyplot.savefig("feature_importance.png", bbox_inches="tight")
439 matplotlib.pyplot.xlabel("Impact on model output")
440 matplotlib.pyplot.clf()
441
442 important_features = self.get_important_features(
443 importance_cutoff, shap_values
444 )
445
446 self.print_feature_importances(important_features)
447
448 # Save the important features in the metric report too
449 feature_report = self.save_feature_importances(
450 important_features, feature_names
451 )
452
453 tracking_metrics["feature_report"] = feature_report
454
455 5483.7 MiB 0.0 MiB 1 logger.info("Training Set scores:")
456 5495.1 MiB 11.4 MiB 1 y_pred = self.clf.predict(X_train)
457 5495.1 MiB 0.0 MiB 1 y_pred = self.le.inverse_transform(y_pred)
458 5495.1 MiB 0.0 MiB 1 if not is_multilabel:
459 5495.4 MiB 0.0 MiB 2 print(
460 5495.4 MiB 0.3 MiB 2 classification_report_imbalanced(
461 5495.1 MiB 0.0 MiB 1 y_train, y_pred, labels=self.class_names
462 )
463 )
464
465 5495.4 MiB 0.0 MiB 1 logger.info("Test Set scores:")
466 # Evaluate results on the test set.
467 5497.5 MiB 2.1 MiB 1 y_pred = self.clf.predict(X_test)
468 5497.5 MiB 0.0 MiB 1 y_pred = self.le.inverse_transform(y_pred)
469
470 5497.5 MiB 0.0 MiB 1 if is_multilabel:
471 assert isinstance(
472 y_pred[0], np.ndarray
473 ), "The predictions should be multilabel"
474
475 5497.5 MiB 0.0 MiB 1 logger.info(f"No confidence threshold - {len(y_test)} classified")
476 5497.5 MiB 0.0 MiB 1 if is_multilabel:
477 confusion_matrix = metrics.multilabel_confusion_matrix(y_test, y_pred)
478 else:
479 5497.5 MiB 0.0 MiB 2 confusion_matrix = metrics.confusion_matrix(
480 5497.5 MiB 0.0 MiB 1 y_test, y_pred, labels=self.class_names
481 )
482
483 5497.5 MiB 0.0 MiB 2 print(
484 5497.5 MiB 0.0 MiB 2 classification_report_imbalanced(
485 5497.5 MiB 0.0 MiB 1 y_test, y_pred, labels=self.class_names
486 )
487 )
488 5497.5 MiB 0.0 MiB 2 report = classification_report_imbalanced_values(
489 5497.5 MiB 0.0 MiB 1 y_test, y_pred, labels=self.class_names
490 )
491
492 5497.5 MiB 0.0 MiB 1 tracking_metrics["report"] = report
493
494 5497.5 MiB 0.0 MiB 2 print_labeled_confusion_matrix(
495 5497.5 MiB 0.0 MiB 1 confusion_matrix, self.class_names, is_multilabel=is_multilabel
496 )
497
498 5497.5 MiB 0.0 MiB 1 tracking_metrics["confusion_matrix"] = confusion_matrix.tolist()
499
500 5497.5 MiB 0.0 MiB 1 confidence_thresholds = [0.6, 0.7, 0.8, 0.9]
501
502 5497.5 MiB 0.0 MiB 1 if is_binary:
503 5497.5 MiB 0.0 MiB 1 confidence_thresholds = [0.1, 0.2, 0.3, 0.4] + confidence_thresholds
504
505 # Evaluate results on the test set for some confidence thresholds.
506 5497.6 MiB 0.0 MiB 9 for confidence_threshold in confidence_thresholds:
507 5497.6 MiB 0.0 MiB 8 y_pred_probas = self.clf.predict_proba(X_test)
508 5497.6 MiB 0.0 MiB 8 confidence_class_names = self.class_names + ["__NOT_CLASSIFIED__"]
509
510 5497.6 MiB 0.0 MiB 8 y_pred_filter = []
511 5497.6 MiB 0.0 MiB 8 classified_indices = []
512 5497.6 MiB 0.0 MiB 40248 for i in range(0, len(y_test)):
513 5497.6 MiB 0.0 MiB 40240 if not is_binary:
514 argmax = np.argmax(y_pred_probas[i])
515 else:
516 5497.6 MiB 0.0 MiB 40240 argmax = 1 if y_pred_probas[i][1] > confidence_threshold else 0
517
518 5497.6 MiB 0.0 MiB 40240 if y_pred_probas[i][argmax] < confidence_threshold:
519 5497.6 MiB 0.0 MiB 254 if not is_multilabel:
520 5497.6 MiB 0.0 MiB 254 y_pred_filter.append("__NOT_CLASSIFIED__")
521 5497.6 MiB 0.0 MiB 254 continue
522
523 5497.6 MiB 0.0 MiB 39986 classified_indices.append(i)
524 5497.6 MiB 0.0 MiB 39986 if is_multilabel:
525 y_pred_filter.append(y_pred[i])
526 else:
527 5497.6 MiB 0.0 MiB 39986 y_pred_filter.append(argmax)
528
529 5497.6 MiB 0.0 MiB 8 if not is_multilabel:
530 5497.6 MiB 0.0 MiB 8 y_pred_filter = np.array(y_pred_filter)
531 5497.6 MiB 0.0 MiB 16 y_pred_filter[classified_indices] = self.le.inverse_transform(
532 5497.6 MiB 0.0 MiB 8 np.array(y_pred_filter[classified_indices], dtype=int)
533 )
534
535 5497.6 MiB 0.0 MiB 80250 classified_num = sum(1 for v in y_pred_filter if v != "__NOT_CLASSIFIED__")
536
537 5497.6 MiB 0.0 MiB 16 logger.info(
538 5497.6 MiB 0.0 MiB 8 f"\nConfidence threshold > {confidence_threshold} - {classified_num} classified"
539 )
540 5497.6 MiB 0.0 MiB 8 if is_multilabel:
541 confusion_matrix = metrics.multilabel_confusion_matrix(
542 y_test[classified_indices], np.asarray(y_pred_filter)
543 )
544 else:
545 5497.6 MiB 0.0 MiB 16 confusion_matrix = metrics.confusion_matrix(
546 5497.6 MiB 0.0 MiB 8 y_test.astype(str),
547 5497.6 MiB 0.0 MiB 8 y_pred_filter.astype(str),
548 5497.6 MiB 0.0 MiB 8 labels=confidence_class_names,
549 )
550 5497.6 MiB 0.0 MiB 16 print(
551 5497.6 MiB 0.0 MiB 16 classification_report_imbalanced(
552 5497.6 MiB 0.0 MiB 8 y_test.astype(str),
553 5497.6 MiB 0.0 MiB 8 y_pred_filter.astype(str),
554 5497.6 MiB 0.0 MiB 8 labels=confidence_class_names,
555 )
556 )
557 5497.6 MiB 0.0 MiB 16 print_labeled_confusion_matrix(
558 5497.6 MiB 0.0 MiB 8 confusion_matrix, confidence_class_names, is_multilabel=is_multilabel
559 )
560
561 5497.6 MiB 0.0 MiB 1 self.evaluation()
562
563 5497.6 MiB 0.0 MiB 1 if self.entire_dataset_training:
564 logger.info("Retraining on the entire dataset...")
565
566 X = np.concatenate((X_train, X_test))
567 y = np.concatenate((y_train, y_test))
568 if self.sampler is not None:
569 X_train, y_train = self.sampler.fit_resample(X, y)
570 else:
571 X_train = X
572 y_train = y
573
574 logger.info(f"X_train: {X_train.shape}, y_train: {y_train.shape}")
575
576 self.clf.fit(X_train, self.le.fit_transform(y_train))
577
578 5497.6 MiB 0.0 MiB 1 model_directory = self.__class__.__name__.lower()
579 5497.6 MiB 0.0 MiB 1 makedirs(model_directory, exist_ok=True)
580
581 5497.6 MiB 0.0 MiB 1 if issubclass(type(self.clf), XGBModel):
582 5497.6 MiB 0.0 MiB 1 xgboost_model_path = path.join(model_directory, "xgboost.ubj")
583 5497.6 MiB 0.1 MiB 1 self.clf.save_model(xgboost_model_path)
584
585 # Since we save the classifier separately, we need to clear the clf
586 # attribute to prevent it from being pickled with the model object.
587 5497.6 MiB 0.0 MiB 1 self.clf = self.clf.__class__(**self.hyperparameter)
588
589 5497.6 MiB 0.0 MiB 1 model_path = path.join(model_directory, "model.pkl")
590 5504.9 MiB 0.0 MiB 2 with open(model_path, "wb") as f:
591 5504.9 MiB 7.2 MiB 1 pickle.dump(self, f, protocol=pickle.HIGHEST_PROTOCOL)
592
593 5504.9 MiB 0.0 MiB 1 if self.store_dataset:
594 with open(f"{self.__class__.__name__.lower()}_data_X", "wb") as f:
595 pickle.dump(X, f, protocol=pickle.HIGHEST_PROTOCOL)
596
597 with open(f"{self.__class__.__name__.lower()}_data_y", "wb") as f:
598 pickle.dump(y, f, protocol=pickle.HIGHEST_PROTOCOL)
599
600 5504.9 MiB 0.0 MiB 1 return tracking_metrics
After the optimization (4043.4 MiB)
Line # Mem usage Increment Occurrences Line Contents
=============================================================
364 235.7 MiB 235.7 MiB 1 @profile
365 def train(self, importance_cutoff=0.15, limit=None):
366 2372.3 MiB 2136.6 MiB 1 X_train, X_test, y_train, y_test = self.__prepare_data(limit)
367
368 # Calculate labels.
369 2372.3 MiB 0.0 MiB 1 self.le.fit(y_train)
370
371 # Extract features from the items.
372 2910.0 MiB 537.7 MiB 3 X_train = self.extraction_pipeline.fit_transform(lambda: X_train)
373 2910.0 MiB -1824.7 MiB 3 X_test = self.extraction_pipeline.transform(lambda: X_test)
374
375 1085.3 MiB -1824.7 MiB 1 is_multilabel = isinstance(y_train[0], np.ndarray)
376 1085.3 MiB 0.0 MiB 1 is_binary = len(self.class_names) == 2
377
378 1085.3 MiB 0.0 MiB 1 if self.sampler is not None:
379 1085.3 MiB 0.0 MiB 1 pipeline = make_pipeline(self.sampler, self.clf)
380 else:
381 pipeline = self.clf
382
383 1085.3 MiB 0.0 MiB 1 tracking_metrics = {}
384
385 # Use k-fold cross validation to evaluate results.
386 1085.3 MiB 0.0 MiB 1 if self.cross_validation_enabled:
387 1085.3 MiB 0.0 MiB 1 scorings = ["accuracy"]
388 1085.3 MiB 0.0 MiB 1 if len(self.class_names) == 2:
389 1085.3 MiB 0.0 MiB 1 scorings += ["precision", "recall"]
390
391 2980.7 MiB 1895.3 MiB 2 scores = cross_validate(
392 1085.4 MiB 0.0 MiB 1 pipeline, X_train, self.le.transform(y_train), scoring=scorings, cv=5
393 )
394
395 2979.7 MiB -1.0 MiB 1 logger.info("Cross Validation scores:")
396 2979.7 MiB 0.0 MiB 4 for scoring in scorings:
397 2979.7 MiB 0.0 MiB 3 score = scores[f"test_{scoring}"]
398 2979.7 MiB 0.0 MiB 3 tracking_metrics[f"test_{scoring}"] = {
399 2979.7 MiB 0.0 MiB 3 "mean": score.mean(),
400 2979.7 MiB 0.0 MiB 3 "std": score.std() * 2,
401 }
402 2979.7 MiB 0.0 MiB 6 logger.info(
403 2979.7 MiB 0.0 MiB 3 f"{scoring.capitalize()}: f{score.mean()} (+/- {score.std() * 2})"
404 )
405
406 2979.7 MiB 0.0 MiB 1 logger.info(f"X_train: {X_train.shape}, y_train: {y_train.shape}")
407
408 # Training on the resampled dataset if sampler is provided.
409 2979.7 MiB 0.0 MiB 1 if self.sampler is not None:
410 3759.7 MiB 780.0 MiB 1 X_train, y_train = self.sampler.fit_resample(X_train, y_train)
411
412 3759.7 MiB 0.0 MiB 1 logger.info(f"resampled X_train: {X_train.shape}, y_train: {y_train.shape}")
413
414 3759.7 MiB 0.0 MiB 1 logger.info(f"X_test: {X_test.shape}, y_test: {y_test.shape}")
415
416 4045.3 MiB 285.6 MiB 1 self.clf.fit(X_train, self.le.transform(y_train))
417
418 4045.3 MiB 0.0 MiB 1 logger.info("Model trained")
419
420 4043.3 MiB -2.0 MiB 1 feature_names = self.get_human_readable_feature_names()
421 4043.3 MiB 0.0 MiB 1 if self.calculate_importance and len(feature_names):
422 explainer = shap.TreeExplainer(self.clf)
423 shap_values = explainer.shap_values(X_train)
424
425 # In the binary case, sometimes shap returns a single shap values matrix.
426 if is_binary and not isinstance(shap_values, list):
427 shap_values = [-shap_values, shap_values]
428 summary_plot_value = shap_values[1]
429 summary_plot_type = "layered_violin"
430 else:
431 summary_plot_value = shap_values
432 summary_plot_type = None
433
434 shap.summary_plot(
435 summary_plot_value,
436 to_array(X_train),
437 feature_names=feature_names,
438 class_names=self.class_names,
439 plot_type=summary_plot_type,
440 show=False,
441 )
442
443 matplotlib.pyplot.savefig("feature_importance.png", bbox_inches="tight")
444 matplotlib.pyplot.xlabel("Impact on model output")
445 matplotlib.pyplot.clf()
446
447 important_features = self.get_important_features(
448 importance_cutoff, shap_values
449 )
450
451 self.print_feature_importances(important_features)
452
453 # Save the important features in the metric report too
454 feature_report = self.save_feature_importances(
455 important_features, feature_names
456 )
457
458 tracking_metrics["feature_report"] = feature_report
459
460 4043.3 MiB 0.0 MiB 1 logger.info("Training Set scores:")
461 4043.3 MiB 0.0 MiB 1 y_pred = self.clf.predict(X_train)
462 4043.3 MiB 0.0 MiB 1 y_pred = self.le.inverse_transform(y_pred)
463 4043.3 MiB 0.0 MiB 1 if not is_multilabel:
464 4043.3 MiB 0.0 MiB 2 print(
465 4043.3 MiB 0.0 MiB 2 classification_report_imbalanced(
466 4043.3 MiB 0.0 MiB 1 y_train, y_pred, labels=self.class_names
467 )
468 )
469
470 4043.3 MiB 0.0 MiB 1 logger.info("Test Set scores:")
471 # Evaluate results on the test set.
472 4043.3 MiB 0.0 MiB 1 y_pred = self.clf.predict(X_test)
473 4043.3 MiB 0.0 MiB 1 y_pred = self.le.inverse_transform(y_pred)
474
475 4043.3 MiB 0.0 MiB 1 if is_multilabel:
476 assert isinstance(
477 y_pred[0], np.ndarray
478 ), "The predictions should be multilabel"
479
480 4043.3 MiB 0.0 MiB 1 logger.info(f"No confidence threshold - {len(y_test)} classified")
481 4043.3 MiB 0.0 MiB 1 if is_multilabel:
482 confusion_matrix = metrics.multilabel_confusion_matrix(y_test, y_pred)
483 else:
484 4043.3 MiB 0.0 MiB 2 confusion_matrix = metrics.confusion_matrix(
485 4043.3 MiB 0.0 MiB 1 y_test, y_pred, labels=self.class_names
486 )
487
488 4043.3 MiB 0.0 MiB 2 print(
489 4043.3 MiB 0.0 MiB 2 classification_report_imbalanced(
490 4043.3 MiB 0.0 MiB 1 y_test, y_pred, labels=self.class_names
491 )
492 )
493 4043.3 MiB 0.0 MiB 2 report = classification_report_imbalanced_values(
494 4043.3 MiB 0.0 MiB 1 y_test, y_pred, labels=self.class_names
495 )
496
497 4043.3 MiB 0.0 MiB 1 tracking_metrics["report"] = report
498
499 4043.3 MiB 0.0 MiB 2 print_labeled_confusion_matrix(
500 4043.3 MiB 0.0 MiB 1 confusion_matrix, self.class_names, is_multilabel=is_multilabel
501 )
502
503 4043.3 MiB 0.0 MiB 1 tracking_metrics["confusion_matrix"] = confusion_matrix.tolist()
504
505 4043.3 MiB 0.0 MiB 1 confidence_thresholds = [0.6, 0.7, 0.8, 0.9]
506
507 4043.3 MiB 0.0 MiB 1 if is_binary:
508 4043.3 MiB 0.0 MiB 1 confidence_thresholds = [0.1, 0.2, 0.3, 0.4] + confidence_thresholds
509
510 # Evaluate results on the test set for some confidence thresholds.
511 4043.3 MiB 0.0 MiB 9 for confidence_threshold in confidence_thresholds:
512 4043.3 MiB 0.0 MiB 8 y_pred_probas = self.clf.predict_proba(X_test)
513 4043.3 MiB 0.0 MiB 8 confidence_class_names = self.class_names + ["__NOT_CLASSIFIED__"]
514
515 4043.3 MiB 0.0 MiB 8 y_pred_filter = []
516 4043.3 MiB 0.0 MiB 8 classified_indices = []
517 4043.3 MiB 0.0 MiB 40248 for i in range(0, len(y_test)):
518 4043.3 MiB 0.0 MiB 40240 if not is_binary:
519 argmax = np.argmax(y_pred_probas[i])
520 else:
521 4043.3 MiB 0.0 MiB 40240 argmax = 1 if y_pred_probas[i][1] > confidence_threshold else 0
522
523 4043.3 MiB 0.0 MiB 40240 if y_pred_probas[i][argmax] < confidence_threshold:
524 4043.3 MiB 0.0 MiB 254 if not is_multilabel:
525 4043.3 MiB 0.0 MiB 254 y_pred_filter.append("__NOT_CLASSIFIED__")
526 4043.3 MiB 0.0 MiB 254 continue
527
528 4043.3 MiB 0.0 MiB 39986 classified_indices.append(i)
529 4043.3 MiB 0.0 MiB 39986 if is_multilabel:
530 y_pred_filter.append(y_pred[i])
531 else:
532 4043.3 MiB 0.0 MiB 39986 y_pred_filter.append(argmax)
533
534 4043.3 MiB 0.0 MiB 8 if not is_multilabel:
535 4043.3 MiB 0.0 MiB 8 y_pred_filter = np.array(y_pred_filter)
536 4043.3 MiB 0.0 MiB 16 y_pred_filter[classified_indices] = self.le.inverse_transform(
537 4043.3 MiB 0.0 MiB 8 np.array(y_pred_filter[classified_indices], dtype=int)
538 )
539
540 4043.3 MiB 0.0 MiB 80250 classified_num = sum(1 for v in y_pred_filter if v != "__NOT_CLASSIFIED__")
541
542 4043.3 MiB 0.0 MiB 16 logger.info(
543 4043.3 MiB 0.0 MiB 8 f"\nConfidence threshold > {confidence_threshold} - {classified_num} classified"
544 )
545 4043.3 MiB 0.0 MiB 8 if is_multilabel:
546 confusion_matrix = metrics.multilabel_confusion_matrix(
547 y_test[classified_indices], np.asarray(y_pred_filter)
548 )
549 else:
550 4043.3 MiB 0.0 MiB 16 confusion_matrix = metrics.confusion_matrix(
551 4043.3 MiB 0.0 MiB 8 y_test.astype(str),
552 4043.3 MiB 0.0 MiB 8 y_pred_filter.astype(str),
553 4043.3 MiB 0.0 MiB 8 labels=confidence_class_names,
554 )
555 4043.3 MiB 0.0 MiB 16 print(
556 4043.3 MiB 0.0 MiB 16 classification_report_imbalanced(
557 4043.3 MiB 0.0 MiB 8 y_test.astype(str),
558 4043.3 MiB 0.0 MiB 8 y_pred_filter.astype(str),
559 4043.3 MiB 0.0 MiB 8 labels=confidence_class_names,
560 )
561 )
562 4043.3 MiB 0.0 MiB 16 print_labeled_confusion_matrix(
563 4043.3 MiB 0.0 MiB 8 confusion_matrix, confidence_class_names, is_multilabel=is_multilabel
564 )
565
566 4043.3 MiB 0.0 MiB 1 self.evaluation()
567
568 4043.3 MiB 0.0 MiB 1 if self.entire_dataset_training:
569 logger.info("Retraining on the entire dataset...")
570
571 X = np.concatenate((X_train, X_test))
572 y = np.concatenate((y_train, y_test))
573 if self.sampler is not None:
574 X_train, y_train = self.sampler.fit_resample(X, y)
575 else:
576 X_train = X
577 y_train = y
578
579 logger.info(f"X_train: {X_train.shape}, y_train: {y_train.shape}")
580
581 self.clf.fit(X_train, self.le.fit_transform(y_train))
582
583 4043.3 MiB 0.0 MiB 1 model_directory = self.__class__.__name__.lower()
584 4043.3 MiB 0.0 MiB 1 makedirs(model_directory, exist_ok=True)
585
586 4043.3 MiB 0.0 MiB 1 if issubclass(type(self.clf), XGBModel):
587 4043.3 MiB 0.0 MiB 1 xgboost_model_path = path.join(model_directory, "xgboost.ubj")
588 4043.4 MiB 0.1 MiB 1 self.clf.save_model(xgboost_model_path)
589
590 # Since we save the classifier separately, we need to clear the clf
591 # attribute to prevent it from being pickled with the model object.
592 4043.4 MiB 0.0 MiB 1 self.clf = self.clf.__class__(**self.hyperparameter)
593
594 4043.4 MiB 0.0 MiB 1 model_path = path.join(model_directory, "model.pkl")
595 4043.4 MiB 0.0 MiB 2 with open(model_path, "wb") as f:
596 4043.4 MiB 0.0 MiB 1 pickle.dump(self, f, protocol=pickle.HIGHEST_PROTOCOL)
597
598 4043.4 MiB 0.0 MiB 1 if self.store_dataset:
599 with open(f"{self.__class__.__name__.lower()}_data_X", "wb") as f:
600 pickle.dump(X, f, protocol=pickle.HIGHEST_PROTOCOL)
601
602 with open(f"{self.__class__.__name__.lower()}_data_y", "wb") as f:
603 pickle.dump(y, f, protocol=pickle.HIGHEST_PROTOCOL)
604
605 4043.4 MiB 0.0 MiB 1 return tracking_metrics
I'm not sure why the memory usage would be higher with the optimization. I would expect it to be higher only temporarily, but then after the lines
372 2910.0 MiB 537.7 MiB 3 X_train = self.extraction_pipeline.fit_transform(lambda: X_train)
373 2910.0 MiB -1824.7 MiB 3 X_test = self.extraction_pipeline.transform(lambda: X_test)
it should be equal to "Before the PR"!
Another optimization option would be to `del X` right after `X_train, X_test, y_train, y_test = self.train_test_split(X, y)`.
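A minimal standalone illustration of the `del X` idea (scikit-learn's `train_test_split` is used here only as a stand-in for the model's own helper):

```python
import numpy as np
from sklearn.model_selection import train_test_split

X = np.arange(1_000_000, dtype=np.float64).reshape(-1, 10)
y = np.zeros(len(X))

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1)

# The split returns copies (fancy indexing), not views, so dropping the last
# reference to X lets the full array be reclaimed right away instead of
# staying alive for the rest of the training function.
del X
```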
Could you try to rerun the training in the three cases using `/usr/bin/time -vv python ...`? This should give us some more precise numbers.
> Could you try to rerun the training in the three cases using `/usr/bin/time -vv python ...`? This should give us some more precise numbers.
Command being timed: python -m scripts.trainer spambug
| | Before the PR | After the PR | After the optimization |
|---|---|---|---|
User time (seconds) | 532.45 | 563.36 | 550.28 |
System time (seconds) | 60.06 | 73.35 | 79.71 |
Percent of CPU this job got | 476% | 491% | 476% |
Elapsed (wall clock) time (h:mm:ss or m:ss) | 2:04.41 | 2:09.51 | 2:12.23 |
Average shared text size (kbytes) | 0 | 0 | 0 |
Average unshared data size (kbytes) | 0 | 0 | 0 |
Average stack size (kbytes) | 0 | 0 | 0 |
Average total size (kbytes) | 0 | 0 | 0 |
Maximum resident set size (kbytes) | 5454032 | 6996208 | 5620096 |
Average resident set size (kbytes) | 0 | 0 | 0 |
Major (requiring I/O) page faults | 2311 | 515 | 465 |
Minor (reclaiming a frame) page faults | 2253813 | 2961066 | 2824487 |
Voluntary context switches | 21225 | 21420 | 19188 |
Involuntary context switches | 2573661 | 4018905 | 6129784 |
Swaps | 0 | 0 | 0 |
File system inputs | 0 | 0 | 0 |
File system outputs | 0 | 0 | 0 |
Socket messages sent | 83 | 83 | 83 |
Socket messages received | 129 | 129 | 129 |
Signals delivered | 0 | 0 | 0 |
Page size (bytes) | 16384 | 16384 | 16384 |
Exit status | 0 | 0 | 0 |
> it should be equal to "Before the PR"!
I guess it depends on how active the GC is. I have a 64 GB RAM machine. So, there is not much pressure on the GC to be greedy.
The RSS seems almost the same, so I'd say let's land this with the optimization, so we don't have to change anything in the models (unlike #3819 (comment)). Hopefully the findings on the spambug model will apply to other models too. Maybe try with testgroupselect too, as it's currently the one using the most memory on TC: spambug is using compute-small, testgroupselect is using compute-super-large. No need to run the whole thing, just return early right after `logger.info("Model trained")`.
I would also take the opportunity to remove `split_tuple_generator`, as it doesn't make much sense anymore.
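For the testgroupselect check mentioned above, a temporary early return right after the model is fitted should be enough; a sketch of that throwaway change inside `train()`:

```python
self.clf.fit(X_train, self.le.transform(y_train))

logger.info("Model trained")

# TEMPORARY: stop here when only comparing peak memory of data preparation,
# feature extraction and fitting across the three variants.
return tracking_metrics
```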
Command being timed: python -m scripts.trainer testgroupselect
| | Before the PR | After the PR | After the optimization |
|---|---|---|---|
User time (seconds) | 356.94 | 413.23 | 409.01 |
System time (seconds) | 12.73 | 257.04 | 246.91 |
Percent of CPU this job got | 100% | 82% | 81% |
Elapsed (wall clock) time (h:mm:ss or m:ss) | 6:08.22 | 13:36.06 | 13:24.37 |
Average shared text size (kbytes) | 0 | 0 | 0 |
Average unshared data size (kbytes) | 0 | 0 | 0 |
Average stack size (kbytes) | 0 | 0 | 0 |
Average total size (kbytes) | 0 | 0 | 0 |
Maximum resident set size (kbytes) | 27329648 | 31941216 | 31173472 |
Average resident set size (kbytes) | 0 | 0 | 0 |
Major (requiring I/O) page faults | 78 | 75 | 16 |
Minor (reclaiming a frame) page faults | 3252276 | 74624195 | 73167637 |
Voluntary context switches | 2868 | 245713 | 335880 |
Involuntary context switches | 63405 | 3485322 | 3015781 |
Swaps | 0 | 0 | 0 |
File system inputs | 0 | 0 | 0 |
File system outputs | 0 | 0 | 0 |
Socket messages sent | 147 | 147 | 148 |
Socket messages received | 227 | 225 | 228 |
Signals delivered | 0 | 0 | 0 |
Page size (bytes) | 16384 | 16384 | 16384 |
Exit status | 1 | 1 | 1 |
Interesting that it took longer; I would have expected the performance to be almost the same, but maybe there was something else running on your machine (I see before the PR the CPU usage was 100%, then around 80%).
There was a 4 GB memory regression, but then again we don't know if it would have been lower if the GC was more aggressive (perhaps we can explicitly call the GC with `gc.collect()`?).
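A small helper like the one below (the name is made up) could be called after the memory-heavy steps to check whether the difference is just GC laziness rather than live data:

```python
import gc
import logging

logger = logging.getLogger(__name__)


def collect_and_log(step: str) -> None:
    # Force a full collection so that memory measured right after a heavy step
    # reflects live data rather than garbage the GC has not visited yet.
    unreachable = gc.collect()
    logger.info("gc.collect() after %s found %d unreachable objects", step, unreachable)
```

For example, calling `collect_and_log("feature extraction")` right after the two `extraction_pipeline` calls.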
Tried locally, it was slower for me too, not sure why :/
Anyway, the top memory usage for me was around half of what you are seeing, with little difference between before and after the patch.
Just to confirm everything is OK before landing, you could apply the optimization (either the one you already have or the `del X` one), and then trigger a Taskcluster training of that model in this PR.
X_train = self.extraction_pipeline.fit_transform(lambda: X_train)
X_test = self.extraction_pipeline.transform(lambda: X_test)
After this PR, we may want to start passing X to the pipeline as an array-like instead of a function.
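In other words, instead of the current callable-based pattern, the pipeline would accept the items directly; whether the current transformers can already handle that is exactly what the follow-up would need to check:

```python
# Current pattern: the extraction pipeline is fed a function returning the items.
X_train = self.extraction_pipeline.fit_transform(lambda: X_train)
X_test = self.extraction_pipeline.transform(lambda: X_test)

# Possible follow-up: pass the array-like directly, as scikit-learn
# transformers usually expect.
X_train = self.extraction_pipeline.fit_transform(X_train)
X_test = self.extraction_pipeline.transform(X_test)
```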
I filed a follow-up issue: #3848
X = np.concatenate((X_train, X_test))
y = np.concatenate((y_train, y_test))
Alternatively, we could transform the data from scratch. Here, we have the opposite of leaking: the transformer is fitted on only part of the training data.
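A hypothetical sketch of that alternative for the entire-dataset retraining branch; it assumes the raw (untransformed) items and their labels are still available as `X_raw` and `y_raw`, which the current code does not keep around:

```python
if self.entire_dataset_training:
    logger.info("Retraining on the entire dataset...")

    # Refit the feature extractor on all items, so it is no longer fitted on
    # only the training split (the "opposite of leaking" mentioned above).
    X_all = self.extraction_pipeline.fit_transform(lambda: X_raw)
    y_all = y_raw

    if self.sampler is not None:
        X_all, y_all = self.sampler.fit_resample(X_all, y_all)

    logger.info(f"X_all: {X_all.shape}, y_all: {y_all.shape}")

    self.clf.fit(X_all, self.le.fit_transform(y_all))
```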
I filed an issue to follow up on this: #3849
Another option could be to do something similar to https://github.com/mozilla/bugbug/pull/3761/files, where we no longer fit anything in the extraction_pipeline; we only use it to extract features from the bugs/patches/whatever. WDYT?
I mentioned that in #3818 (comment). Currently, we could use the extraction class (e.g.,
Closing in favour of implementing the solution described in #3818 (comment)
Fixes #3818