diff --git a/BDD_OIA/BDD/data_processing.py b/BDD_OIA/BDD/data_processing.py index 3bc1b83..fc4adb7 100644 --- a/BDD_OIA/BDD/data_processing.py +++ b/BDD_OIA/BDD/data_processing.py @@ -3,14 +3,14 @@ Each dataset is a list of metadata, each includes official image id, full image path, class label, attribute labels, attribute certainty scores, and attribute labels calibrated for uncertainty """ -import sys +import argparse import os -import random import pickle -import argparse -from os import listdir -from os.path import isfile, isdir, join +import random +import sys from collections import defaultdict as ddict +from os import listdir +from os.path import isdir, isfile, join def extract_data(data_dir): @@ -44,20 +44,38 @@ def extract_data(data_dir): }, # calibrate main label based on uncertainty label 0: {1: 0, 2: 0.5, 3: 0.25, 4: 0}, } - with open(join(cwd, data_dir + "/attributes/image_attribute_labels.txt"), "r") as f: + with open( + join( + cwd, data_dir + "/attributes/image_attribute_labels.txt" + ), + "r", + ) as f: for line in f: - file_idx, attribute_idx, attribute_label, attribute_certainty = ( - line.strip().split()[:4] - ) + ( + file_idx, + attribute_idx, + attribute_label, + attribute_certainty, + ) = line.strip().split()[:4] attribute_label = int(attribute_label) attribute_certainty = int(attribute_certainty) - uncertain_label = uncertainty_map[attribute_label][attribute_certainty] - attribute_labels_all[int(file_idx)].append(attribute_label) - attribute_uncertain_labels_all[int(file_idx)].append(uncertain_label) - attribute_certainties_all[int(file_idx)].append(attribute_certainty) + uncertain_label = uncertainty_map[attribute_label][ + attribute_certainty + ] + attribute_labels_all[int(file_idx)].append( + attribute_label + ) + attribute_uncertain_labels_all[int(file_idx)].append( + uncertain_label + ) + attribute_certainties_all[int(file_idx)].append( + attribute_certainty + ) is_train_test = dict() # map from image id to 0 / 1 (1 = train) - with open(join(cwd, data_dir + "/train_test_split.txt"), "r") as f: + with open( + join(cwd, data_dir + "/train_test_split.txt"), "r" + ) as f: for line in f: idx, is_train = line.strip().split() is_train_test[int(idx)] = int(is_train) @@ -68,7 +86,9 @@ def extract_data(data_dir): train_val_data, test_data = [], [] train_data, val_data = [], [] - folder_list = [f for f in listdir(data_path) if isdir(join(data_path, f))] + folder_list = [ + f for f in listdir(data_path) if isdir(join(data_path, f)) + ] folder_list.sort() # sort by class index for i, folder in enumerate(folder_list): folder_path = join(data_path, folder) @@ -86,8 +106,12 @@ def extract_data(data_dir): "img_path": img_path, "class_label": i, "attribute_label": attribute_labels_all[img_id], - "attribute_certainty": attribute_certainties_all[img_id], - "uncertain_attribute_label": attribute_uncertain_labels_all[img_id], + "attribute_certainty": attribute_certainties_all[ + img_id + ], + "uncertain_attribute_label": attribute_uncertain_labels_all[ + img_id + ], } if is_train_test[img_id]: train_val_data.append(metadata) @@ -107,9 +131,15 @@ def extract_data(data_dir): if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Dataset preparation") - parser.add_argument("--save_dir", "-d", help="Where to save the new datasets") - parser.add_argument("--data_dir", help="Where to load the datasets") + parser = argparse.ArgumentParser( + description="Dataset preparation" + ) + parser.add_argument( + "--save_dir", "-d", help="Where to save the new datasets" + ) + 
parser.add_argument( + "--data_dir", help="Where to load the datasets" + ) args = parser.parse_args() print(args.data_dir) train_data, val_data, test_data = extract_data(args.data_dir) diff --git a/BDD_OIA/BDD/dataset.py b/BDD_OIA/BDD/dataset.py index 7de1904..b3ee6d1 100644 --- a/BDD_OIA/BDD/dataset.py +++ b/BDD_OIA/BDD/dataset.py @@ -2,17 +2,16 @@ General utils for training, evaluation and data loading """ -import sys import os -import torch import pickle +import sys + import numpy as np +import torch import torchvision.transforms as transforms - -from PIL import Image from BDD.config import BASE_DIR, N_ATTRIBUTES -from torch.utils.data import BatchSampler -from torch.utils.data import Dataset, DataLoader +from PIL import Image +from torch.utils.data import BatchSampler, DataLoader, Dataset class BDDDataset(Dataset): @@ -41,9 +40,16 @@ def __init__( transform: whether to apply any special transformation. Default = None, i.e. use standard ImageNet preprocessing """ self.data = [] - self.is_train = any(["train" in path for path in pkl_file_paths]) + self.is_train = any( + ["train" in path for path in pkl_file_paths] + ) if not self.is_train: - assert any([("test" in path) or ("val" in path) for path in pkl_file_paths]) + assert any( + [ + ("test" in path) or ("val" in path) + for path in pkl_file_paths + ] + ) for file_path in pkl_file_paths: self.data.extend(pickle.load(open(file_path, "rb"))) self.transform = transform @@ -133,7 +139,9 @@ class ImbalancedDatasetSampler(torch.utils.data.sampler.Sampler): def __init__(self, dataset, indices=None): # if indices is not provided, # all elements in the dataset will be considered - self.indices = list(range(len(dataset))) if indices is None else indices + self.indices = ( + list(range(len(dataset))) if indices is None else indices + ) # if num_samples is not provided, # draw `len(indices)` samples in each iteration @@ -150,17 +158,22 @@ def __init__(self, dataset, indices=None): # weight for each sample weights = [ - 1.0 / label_to_count[self._get_label(dataset, idx)] for idx in self.indices + 1.0 / label_to_count[self._get_label(dataset, idx)] + for idx in self.indices ] self.weights = torch.DoubleTensor(weights) - def _get_label(self, dataset, idx): # Note: for single attribute dataset + def _get_label( + self, dataset, idx + ): # Note: for single attribute dataset return dataset.data[idx]["attribute_label"][0] def __iter__(self): idx = ( self.indices[i] - for i in torch.multinomial(self.weights, self.num_samples, replacement=True) + for i in torch.multinomial( + self.weights, self.num_samples, replacement=True + ) ) return idx @@ -218,7 +231,10 @@ def load_data( loader = DataLoader(dataset, batch_sampler=sampler) else: loader = DataLoader( - dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last + dataset, + batch_size=batch_size, + shuffle=shuffle, + drop_last=drop_last, ) return loader diff --git a/BDD_OIA/BDD/template_model.py b/BDD_OIA/BDD/template_model.py index 3306e95..a757fce 100644 --- a/BDD_OIA/BDD/template_model.py +++ b/BDD_OIA/BDD/template_model.py @@ -1,11 +1,12 @@ +import math import os + import torch import torch.nn as nn -from torch.nn import Parameter import torch.nn.functional as F import torch.utils.model_zoo as model_zoo import torchvision -import math +from torch.nn import Parameter __all__ = ["FRCNN", "RCNN_global"] @@ -28,7 +29,9 @@ def __init__(self, cfg=None, random_select=False): self.conv_glob2 = nn.Conv2d(128, 64, 3, padding=1) self.relu_glob2 = nn.ReLU(inplace=True) - self.lin_glob = 
nn.Linear(in_features=3136, out_features=2048, bias=True) + self.lin_glob = nn.Linear( + in_features=3136, out_features=2048, bias=True + ) self.relu_glob = nn.ReLU() def forward(self, x): diff --git a/BDD_OIA/DPL/dpl.py b/BDD_OIA/DPL/dpl.py index c99fa05..f0ac2c9 100644 --- a/BDD_OIA/DPL/dpl.py +++ b/BDD_OIA/DPL/dpl.py @@ -3,14 +3,16 @@ Detail of forwarding of our model """ +import pdb + # -*- coding: utf-8 -*- import sys -import pdb + import numpy as np import torch +import torch.autograd as autograd import torch.nn as nn import torch.nn.functional as F -import torch.autograd as autograd from torch.autograd import Variable # if you set True, many print function is used to debug @@ -20,13 +22,17 @@ def expand_normalize_concepts(concepts: torch.Tensor): assert ( - len(concepts[concepts < 0]) == 0 and len(concepts[concepts > 1]) == 0 + len(concepts[concepts < 0]) == 0 + and len(concepts[concepts > 1]) == 0 ), concepts[:10, :, 0] pC = [] for i in range(concepts.size(1)): # add offset - c = torch.cat((1 - concepts[:, i], concepts[:, i]), dim=1) + 1e-5 + c = ( + torch.cat((1 - concepts[:, i], concepts[:, i]), dim=1) + + 1e-5 + ) with torch.no_grad(): Z = torch.sum(c, dim=1, keepdim=True) pC.append(c / Z) @@ -36,17 +42,19 @@ def expand_normalize_concepts(concepts: torch.Tensor): def create_w_to_y(): - three_bits_or = torch.cat((torch.zeros((8, 1)), torch.ones((8, 1))), dim=1).to( - dtype=torch.float - ) + three_bits_or = torch.cat( + (torch.zeros((8, 1)), torch.ones((8, 1))), dim=1 + ).to(dtype=torch.float) three_bits_or[0] = torch.tensor([1, 0]) - six_bits_or = torch.cat((torch.zeros((64, 1)), torch.ones((64, 1))), dim=1).to( - dtype=torch.float - ) + six_bits_or = torch.cat( + (torch.zeros((64, 1)), torch.ones((64, 1))), dim=1 + ).to(dtype=torch.float) six_bits_or[0] = torch.tensor([1, 0]) - and_not_for_stop = torch.tensor([[0, 1], [0, 1], [1, 0], [0, 1]], dtype=torch.float) + and_not_for_stop = torch.tensor( + [[0, 1], [0, 1], [1, 0], [0, 1]], dtype=torch.float + ) and_not = torch.tensor([[0], [0], [1], [0]], dtype=torch.float) @@ -69,7 +77,9 @@ def __init__: None """ - def __init__(self, conceptizer, parametrizer, aggregator, cbm, senn, device): + def __init__( + self, conceptizer, parametrizer, aggregator, cbm, senn, device + ): super(DPL, self).__init__() self.cbm = cbm self.senn = senn @@ -100,7 +110,9 @@ def forward(self, x): # Get concepts, h_x_labeled is known, h_x is unknown concepts h_x_labeled_raw, h_x, _ = self.conceptizer(x) - h_x_labeled = h_x_labeled_raw.view(-1, h_x_labeled_raw.shape[1], 1) + h_x_labeled = h_x_labeled_raw.view( + -1, h_x_labeled_raw.shape[1], 1 + ) # self.h_norm_l1 = h_x.norm(p=1) @@ -138,7 +150,9 @@ def compute_logic_forward(self, concepts: torch.Tensor): poss_worlds = A.multiply(B).multiply(C).view(-1, 2 * 2 * 2) - active = torch.einsum("bi,ik->bk", poss_worlds, self.or_three_bits) + active = torch.einsum( + "bi,ik->bk", poss_worlds, self.or_three_bits + ) # assert torch.abs(active.sum() / len(active)- 1) < 0.001, (active, active.sum() / len(active) ) @@ -195,10 +209,17 @@ def compute_logic_stop(self, concepts: torch.Tensor): ) poss_worlds = ( - A.multiply(B).multiply(C).multiply(D).multiply(E).multiply(F).view(-1, 64) + A.multiply(B) + .multiply(C) + .multiply(D) + .multiply(E) + .multiply(F) + .view(-1, 64) ) - active = torch.einsum("bi,ik->bk", poss_worlds, self.or_six_bits) + active = torch.einsum( + "bi,ik->bk", poss_worlds, self.or_six_bits + ) return active @@ -209,7 +230,9 @@ def compute_logic_left(self, concepts: torch.Tensor): poss_worlds = 
A.multiply(B).multiply(C).view(-1, 8) - active = torch.einsum("bi,ik->bk", poss_worlds, self.or_three_bits) + active = torch.einsum( + "bi,ik->bk", poss_worlds, self.or_three_bits + ) return active @@ -220,7 +243,9 @@ def compute_logic_no_left(self, concepts: torch.Tensor): poss_worlds = A.multiply(B).multiply(C).view(-1, 8) - active = torch.einsum("bi,ik->bk", poss_worlds, self.or_three_bits) + active = torch.einsum( + "bi,ik->bk", poss_worlds, self.or_three_bits + ) return active @@ -231,7 +256,9 @@ def compute_logic_right(self, concepts: torch.Tensor): poss_worlds = A.multiply(B).multiply(C).view(-1, 8) - active = torch.einsum("bi,ik->bk", poss_worlds, self.or_three_bits) + active = torch.einsum( + "bi,ik->bk", poss_worlds, self.or_three_bits + ) return active @@ -242,7 +269,9 @@ def compute_logic_no_right(self, concepts: torch.Tensor): poss_worlds = A.multiply(B).multiply(C).view(-1, 8) - active = torch.einsum("bi,ik->bk", poss_worlds, self.or_three_bits) + active = torch.einsum( + "bi,ik->bk", poss_worlds, self.or_three_bits + ) return active @@ -262,7 +291,9 @@ def proglob_pred(self): F_pred = F_pred[..., None] S_pred = S_pred[:, None, :] w_FS = F_pred.multiply(S_pred).view(-1, 4) - labels_01 = torch.einsum("bi,ik->bk", w_FS, self.rule_for_stop) + labels_01 = torch.einsum( + "bi,ik->bk", w_FS, self.rule_for_stop + ) L_pred = L_pred[..., None] NL_pred = NL_pred[:, None, :] @@ -275,15 +306,17 @@ def proglob_pred(self): w_R = R_pred.multiply(NR_pred).view(-1, 4) labels_3 = torch.einsum("bi,il->bl", w_R, self.rule_lr_move) - labels_4 = torch.sigmoid(self.pred_5(self.concepts_labeled[:, :, 0])).view( - -1, 1 - ) + labels_4 = torch.sigmoid( + self.pred_5(self.concepts_labeled[:, :, 0]) + ).view(-1, 1) labels_2 = torch.cat([labels_2, 1 - labels_2], dim=1) labels_3 = torch.cat([labels_3, 1 - labels_3], dim=1) labels_4 = torch.cat([labels_4, 1 - labels_4], dim=1) - pred = torch.cat([labels_01, labels_2, labels_3, labels_4], dim=1) + pred = torch.cat( + [labels_01, labels_2, labels_3, labels_4], dim=1 + ) # avoid overflow pred = (pred + 1e-5) / (1 + 2 * 1e-5) diff --git a/BDD_OIA/DPL/dpl_auc.py b/BDD_OIA/DPL/dpl_auc.py index 7b786ad..2255e90 100644 --- a/BDD_OIA/DPL/dpl_auc.py +++ b/BDD_OIA/DPL/dpl_auc.py @@ -3,20 +3,22 @@ Detail of forwarding of our model """ +import pdb + # -*- coding: utf-8 -*- import sys -import pdb + import numpy as np import torch +import torch.autograd as autograd import torch.nn as nn import torch.nn.functional as F -import torch.autograd as autograd -from torch.autograd import Variable from DPL.utils_problog import ( build_world_queries_matrix_FS, build_world_queries_matrix_L, build_world_queries_matrix_R, ) +from torch.autograd import Variable # if you set True, many print function is used to debug DEBUG = False @@ -25,13 +27,17 @@ def expand_normalize_concepts(concepts: torch.Tensor): assert ( - len(concepts[concepts < 0]) == 0 and len(concepts[concepts > 1]) == 0 + len(concepts[concepts < 0]) == 0 + and len(concepts[concepts > 1]) == 0 ), concepts[:10, :, 0] pC = [] for i in range(concepts.size(1)): # add offset - c = torch.cat((1 - concepts[:, i], concepts[:, i]), dim=1) + 1e-5 + c = ( + torch.cat((1 - concepts[:, i], concepts[:, i]), dim=1) + + 1e-5 + ) with torch.no_grad(): Z = torch.sum(c, dim=1, keepdim=True) pC.append(c / Z) @@ -42,9 +48,9 @@ def expand_normalize_concepts(concepts: torch.Tensor): def create_w_to_y(): - four_bits_or = torch.cat((torch.zeros((16, 1)), torch.ones((16, 1))), dim=1).to( - dtype=torch.float - ) + four_bits_or = torch.cat( + 
(torch.zeros((16, 1)), torch.ones((16, 1))), dim=1 + ).to(dtype=torch.float) four_bits_or[0] = torch.tensor([1, 0]) return four_bits_or @@ -66,7 +72,9 @@ def __init__: None """ - def __init__(self, conceptizer, parametrizer, aggregator, cbm, senn, device): + def __init__( + self, conceptizer, parametrizer, aggregator, cbm, senn, device + ): super(DPL_AUC, self).__init__() self.cbm = cbm self.senn = senn @@ -100,7 +108,9 @@ def forward(self, x): # Get concepts, h_x_labeled is known, h_x is unknown concepts h_x_labeled_raw, h_x, _ = self.conceptizer(x) - h_x_labeled = h_x_labeled_raw.view(-1, h_x_labeled_raw.shape[1], 1) + h_x_labeled = h_x_labeled_raw.view( + -1, h_x_labeled_raw.shape[1], 1 + ) # self.h_norm_l1 = h_x.norm(p=1) if DEBUG: @@ -136,24 +146,37 @@ def compute_logic_no_left_lane(self): obs_worlds = A.multiply(B).view(-1, 4) - no_left_lane = torch.einsum("bi,ik->bk", obs_worlds, self.or_two_bits) + no_left_lane = torch.einsum( + "bi,ik->bk", obs_worlds, self.or_two_bits + ) return no_left_lane def compute_logic_obstacle(self): - o_car = self.pC[:, 10:12].unsqueeze(2).unsqueeze(3).unsqueeze(4) # car - o_person = self.pC[:, 12:14].unsqueeze(1).unsqueeze(3).unsqueeze(4) # person - o_rider = self.pC[:, 14:16].unsqueeze(1).unsqueeze(2).unsqueeze(4) # rider + o_car = ( + self.pC[:, 10:12].unsqueeze(2).unsqueeze(3).unsqueeze(4) + ) # car + o_person = ( + self.pC[:, 12:14].unsqueeze(1).unsqueeze(3).unsqueeze(4) + ) # person + o_rider = ( + self.pC[:, 14:16].unsqueeze(1).unsqueeze(2).unsqueeze(4) + ) # rider o_other = ( self.pC[:, 16:18].unsqueeze(1).unsqueeze(2).unsqueeze(3) ) # other obstacle obs_worlds = ( - o_car.multiply(o_person).multiply(o_rider).multiply(o_other).view(-1, 16) + o_car.multiply(o_person) + .multiply(o_rider) + .multiply(o_other) + .view(-1, 16) ) - obs_active = torch.einsum("bi,ik->bk", obs_worlds, self.or_four_bits) + obs_active = torch.einsum( + "bi,ik->bk", obs_worlds, self.or_four_bits + ) return obs_active @@ -170,15 +193,56 @@ def proglob_pred(self): t_sign = self.pC[:, 8:10] # traffic sign present obs = self.compute_logic_obstacle() # generic obstacle - A = tl_green.unsqueeze(2).unsqueeze(3).unsqueeze(4).unsqueeze(5).unsqueeze(6) - B = follow.unsqueeze(1).unsqueeze(3).unsqueeze(4).unsqueeze(5).unsqueeze(6) - C = clear.unsqueeze(1).unsqueeze(2).unsqueeze(4).unsqueeze(5).unsqueeze(6) - D = tl_red.unsqueeze(1).unsqueeze(2).unsqueeze(3).unsqueeze(5).unsqueeze(6) - E = t_sign.unsqueeze(1).unsqueeze(2).unsqueeze(3).unsqueeze(4).unsqueeze(6) - F = obs.unsqueeze(1).unsqueeze(2).unsqueeze(3).unsqueeze(4).unsqueeze(5) + A = ( + tl_green.unsqueeze(2) + .unsqueeze(3) + .unsqueeze(4) + .unsqueeze(5) + .unsqueeze(6) + ) + B = ( + follow.unsqueeze(1) + .unsqueeze(3) + .unsqueeze(4) + .unsqueeze(5) + .unsqueeze(6) + ) + C = ( + clear.unsqueeze(1) + .unsqueeze(2) + .unsqueeze(4) + .unsqueeze(5) + .unsqueeze(6) + ) + D = ( + tl_red.unsqueeze(1) + .unsqueeze(2) + .unsqueeze(3) + .unsqueeze(5) + .unsqueeze(6) + ) + E = ( + t_sign.unsqueeze(1) + .unsqueeze(2) + .unsqueeze(3) + .unsqueeze(4) + .unsqueeze(6) + ) + F = ( + obs.unsqueeze(1) + .unsqueeze(2) + .unsqueeze(3) + .unsqueeze(4) + .unsqueeze(5) + ) w_FS = ( - A.multiply(B).multiply(C).multiply(D).multiply(E).multiply(F).view(-1, 64) + A.multiply(B) + .multiply(C) + .multiply(D) + .multiply(E) + .multiply(F) + .view(-1, 64) ) # labels_FS = torch.einsum("bi,ik->bk", w_FS, self.FS_w_q) @@ -194,7 +258,13 @@ def proglob_pred(self): l_obs = self.pC[:, 26:28] # LEFT obstacle left_line = self.pC[:, 28:30] # solid line on LEFT - AL = 
left_lane.unsqueeze(2).unsqueeze(3).unsqueeze(4).unsqueeze(5).unsqueeze(6) + AL = ( + left_lane.unsqueeze(2) + .unsqueeze(3) + .unsqueeze(4) + .unsqueeze(5) + .unsqueeze(6) + ) BL = ( tl_green_left.unsqueeze(1) .unsqueeze(3) @@ -203,7 +273,11 @@ def proglob_pred(self): .unsqueeze(6) ) CL = ( - follow_left.unsqueeze(1).unsqueeze(2).unsqueeze(4).unsqueeze(5).unsqueeze(6) + follow_left.unsqueeze(1) + .unsqueeze(2) + .unsqueeze(4) + .unsqueeze(5) + .unsqueeze(6) ) DL = ( no_left_lane.unsqueeze(1) @@ -212,8 +286,20 @@ def proglob_pred(self): .unsqueeze(5) .unsqueeze(6) ) - EL = l_obs.unsqueeze(1).unsqueeze(2).unsqueeze(3).unsqueeze(4).unsqueeze(6) - FL = left_line.unsqueeze(1).unsqueeze(2).unsqueeze(3).unsqueeze(4).unsqueeze(5) + EL = ( + l_obs.unsqueeze(1) + .unsqueeze(2) + .unsqueeze(3) + .unsqueeze(4) + .unsqueeze(6) + ) + FL = ( + left_line.unsqueeze(1) + .unsqueeze(2) + .unsqueeze(3) + .unsqueeze(4) + .unsqueeze(5) + ) w_L = ( AL.multiply(BL) @@ -237,7 +323,13 @@ def proglob_pred(self): r_obs = self.pC[:, 38:40] # RIGHT obstacle rigt_line = self.pC[:, 40:42] # solid line on RIGHT - AL = rigt_lane.unsqueeze(2).unsqueeze(3).unsqueeze(4).unsqueeze(5).unsqueeze(6) + AL = ( + rigt_lane.unsqueeze(2) + .unsqueeze(3) + .unsqueeze(4) + .unsqueeze(5) + .unsqueeze(6) + ) BL = ( tl_green_rigt.unsqueeze(1) .unsqueeze(3) @@ -246,7 +338,11 @@ def proglob_pred(self): .unsqueeze(6) ) CL = ( - follow_rigt.unsqueeze(1).unsqueeze(2).unsqueeze(4).unsqueeze(5).unsqueeze(6) + follow_rigt.unsqueeze(1) + .unsqueeze(2) + .unsqueeze(4) + .unsqueeze(5) + .unsqueeze(6) ) DL = ( no_rigt_lane.unsqueeze(1) @@ -255,8 +351,20 @@ def proglob_pred(self): .unsqueeze(5) .unsqueeze(6) ) - EL = r_obs.unsqueeze(1).unsqueeze(2).unsqueeze(3).unsqueeze(4).unsqueeze(6) - FL = rigt_line.unsqueeze(1).unsqueeze(2).unsqueeze(3).unsqueeze(4).unsqueeze(5) + EL = ( + r_obs.unsqueeze(1) + .unsqueeze(2) + .unsqueeze(3) + .unsqueeze(4) + .unsqueeze(6) + ) + FL = ( + rigt_line.unsqueeze(1) + .unsqueeze(2) + .unsqueeze(3) + .unsqueeze(4) + .unsqueeze(5) + ) w_R = ( AL.multiply(BL) @@ -269,7 +377,9 @@ def proglob_pred(self): label_R = torch.einsum("bi,ik->bk", w_R, self.R_w_q) - pred = torch.cat([labels_FS, label_L, label_R], dim=1) # this is 8 dim + pred = torch.cat( + [labels_FS, label_L, label_R], dim=1 + ) # this is 8 dim # avoid overflow pred = (pred + 1e-5) / (1 + 2 * 1e-5) diff --git a/BDD_OIA/DPL/dpl_auc_pcbm.py b/BDD_OIA/DPL/dpl_auc_pcbm.py index f58c629..cc1e5ee 100644 --- a/BDD_OIA/DPL/dpl_auc_pcbm.py +++ b/BDD_OIA/DPL/dpl_auc_pcbm.py @@ -1,10 +1,12 @@ -from DPL.dpl_auc import DPL_AUC import torch import torch.nn.functional as F +from DPL.dpl_auc import DPL_AUC class DPL_AUC_PCBM(DPL_AUC): - def __init__(self, conceptizer, parametrizer, aggregator, cbm, senn, device): + def __init__( + self, conceptizer, parametrizer, aggregator, cbm, senn, device + ): self.gaussian_vars = False super(DPL_AUC_PCBM, self).__init__( conceptizer, parametrizer, aggregator, cbm, senn, device @@ -29,8 +31,12 @@ def forward(self, x): if self.gaussian_vars: return latents, logsigma - pred_embeddings = self._sample_gaussian_tensors(latents, logsigma, 100) - concept_logit, concept_prob = self._compute_distance(pred_embeddings, z_tot) + pred_embeddings = self._sample_gaussian_tensors( + latents, logsigma, 100 + ) + concept_logit, concept_prob = self._compute_distance( + pred_embeddings, z_tot + ) # prob_Cs = concept_prob[..., 1] # print(concept_logit.shape) @@ -39,7 +45,9 @@ def forward(self, x): # store known concepts # if not self.senn: - squeezed_c_logits = 
torch.unsqueeze(concept_prob[..., 1], dim=-1) + squeezed_c_logits = torch.unsqueeze( + concept_prob[..., 1], dim=-1 + ) # store (known+unknown) concepts self.concepts = torch.cat((squeezed_c_logits, h_x), dim=1) # store known concepts @@ -59,7 +67,9 @@ def forward(self, x): return out def _batchwise_cdist(self, samples1, samples2, eps=1e-6): - if len(samples1.size()) not in [3, 4, 5] or len(samples2.size()) not in [ + if len(samples1.size()) not in [3, 4, 5] or len( + samples2.size() + ) not in [ 3, 4, 5, @@ -75,7 +85,9 @@ def _batchwise_cdist(self, samples1, samples2, eps=1e-6): samples2 = samples2.unsqueeze(3) samples1 = samples1.unsqueeze(1) samples2 = samples2.unsqueeze(0) - result = torch.sqrt(((samples1 - samples2) ** 2).sum(-1) + eps) + result = torch.sqrt( + ((samples1 - samples2) ** 2).sum(-1) + eps + ) return result.view(*result.shape[:-2], -1) else: raise RuntimeError( @@ -84,7 +96,12 @@ def _batchwise_cdist(self, samples1, samples2, eps=1e-6): ) def _compute_distance( - self, pred_embeddings, z_tot, negative_scale=None, shift=None, reduction="mean" + self, + pred_embeddings, + z_tot, + negative_scale=None, + shift=None, + reduction="mean", ): negative_scale = ( self.conceptizer.negative_scale @@ -111,6 +128,8 @@ def _sample_gaussian_tensors(self, mu, logsigma, num_samples): dtype=mu.dtype, device=mu.device, ) - samples_sigma = eps.mul(torch.exp(logsigma.unsqueeze(2) * 0.5)) + samples_sigma = eps.mul( + torch.exp(logsigma.unsqueeze(2) * 0.5) + ) samples = samples_sigma.add_(mu.unsqueeze(2)) return samples diff --git a/BDD_OIA/DPL/utils_problog.py b/BDD_OIA/DPL/utils_problog.py index 2229d80..59fc68b 100644 --- a/BDD_OIA/DPL/utils_problog.py +++ b/BDD_OIA/DPL/utils_problog.py @@ -1,6 +1,7 @@ -import torch from itertools import product +import torch + def build_world_queries_matrix_complete_FS(): @@ -11,7 +12,17 @@ def build_world_queries_matrix_complete_FS(): w_q = torch.zeros(n_worlds, n_queries) # (100, 20) for w in range(n_worlds): - tl_green, follow, clear, tl_red, t_sign, ob1, ob2, ob3, ob4 = look_up[w] + ( + tl_green, + follow, + clear, + tl_red, + t_sign, + ob1, + ob2, + ob3, + ob4, + ) = look_up[w] obs = min(ob1 + ob2 + ob3 + ob4, 1) @@ -71,9 +82,15 @@ def build_world_queries_matrix_LR(): w_q = torch.zeros(n_worlds, n_queries) # (100, 20) for w in range(n_worlds): - tl_red, no_left_lane, left_solid_line, obs, left_lane, tl_green, follow = ( - look_up[w] - ) + ( + tl_red, + no_left_lane, + left_solid_line, + obs, + left_lane, + tl_green, + follow, + ) = look_up[w] if left_lane + tl_green + follow > 0: if tl_green + tl_red == 2 or no_left_lane == 1: @@ -98,9 +115,24 @@ def build_world_queries_matrix_L(): w_q = torch.zeros(n_worlds, n_queries) for w in range(n_worlds): - left_lane, tl_green, follow, no_left_lane, obs, left_solid_line = look_up[w] - - if left_lane + tl_green + follow + no_left_lane + obs + left_solid_line == 0: + ( + left_lane, + tl_green, + follow, + no_left_lane, + obs, + left_solid_line, + ) = look_up[w] + + if ( + left_lane + + tl_green + + follow + + no_left_lane + + obs + + left_solid_line + == 0 + ): w_q[w, 0] = 0.5 w_q[w, 1] = 0.5 elif left_lane + tl_green + follow > 0: @@ -119,9 +151,24 @@ def build_world_queries_matrix_R(): w_q = torch.zeros(n_worlds, n_queries) for w in range(n_worlds): - right_lane, tl_green, follow, no_right_lane, obs, right_solid_line = look_up[w] - - if right_lane + tl_green + follow + no_right_lane + obs + right_solid_line == 0: + ( + right_lane, + tl_green, + follow, + no_right_lane, + obs, + right_solid_line, + ) = 
look_up[w] + + if ( + right_lane + + tl_green + + follow + + no_right_lane + + obs + + right_solid_line + == 0 + ): w_q[w, 0] = 0.5 w_q[w, 1] = 0.5 elif right_lane + tl_green + follow > 0: @@ -199,7 +246,12 @@ def compute_logic_stop(or_six_bits, concepts: torch.Tensor): ) poss_worlds = ( - A.multiply(B).multiply(C).multiply(D).multiply(E).multiply(F).view(-1, 64) + A.multiply(B) + .multiply(C) + .multiply(D) + .multiply(E) + .multiply(F) + .view(-1, 64) ) active = torch.einsum("bi,ik->bk", poss_worlds, or_six_bits) @@ -265,7 +317,10 @@ def compute_logic_obstacle(or_four_bits, pC: torch.Tensor): o_other = pC[:, 16:18].unsqueeze(1).unsqueeze(2).unsqueeze(3) obs_worlds = ( - o_car.multiply(o_person).multiply(o_rider).multiply(o_other).view(-1, 16) + o_car.multiply(o_person) + .multiply(o_rider) + .multiply(o_other) + .view(-1, 16) ) obs_active = torch.einsum("bi,ik->bk", obs_worlds, or_four_bits) diff --git a/BDD_OIA/SENN/arglist.py b/BDD_OIA/SENN/arglist.py index aea0149..816c07b 100644 --- a/BDD_OIA/SENN/arglist.py +++ b/BDD_OIA/SENN/arglist.py @@ -4,8 +4,9 @@ Example is in parameter.txt so please see it """ -import pdb import argparse +import pdb + import torch @@ -45,7 +46,10 @@ def get_senn_parser(): ### Save Paths parser.add_argument( - "--model_path", type=str, default="models", help="where to save the snapshot" + "--model_path", + type=str, + default="models", + help="where to save the snapshot", ) parser.add_argument( "--results_path", @@ -68,16 +72,25 @@ def get_senn_parser(): ### Device parser.add_argument( - "--cuda", action="store_true", default=False, help="enable the gpu" + "--cuda", + action="store_true", + default=False, + help="enable the gpu", + ) + parser.add_argument( + "--num_gpus", type=int, default=2, help="Num GPUs to use." + ) + parser.add_argument( + "--seed", type=int, default=2018, help="Set random seed." ) - parser.add_argument("--num_gpus", type=int, default=2, help="Num GPUs to use.") - parser.add_argument("--seed", type=int, default=2018, help="Set random seed.") ### Model # Concept Encoder (H) parser.add_argument( - "--cbm", default=True, help="type of conceptizer (learnt or input)" + "--cbm", + default=True, + help="type of conceptizer (learnt or input)", ) # newly added parser.add_argument( "--senn", @@ -92,13 +105,26 @@ def get_senn_parser(): help="type of conceptizer (learnt or input)", ) # Don's change!! parser.add_argument( - "--concept_dim", type=int, default=1, help="concept dimension. dont change" + "--concept_dim", + type=int, + default=1, + help="concept dimension. 
dont change", ) parser.add_argument( - "--nconcepts_labeled", type=int, default=21, help="number of labeled concepts" + "--nconcepts_labeled", + type=int, + default=21, + help="number of labeled concepts", ) # newly added - parser.add_argument("--nconcepts", type=int, default=30, help="number of concepts") - parser.add_argument("--h_sparsity", type=int, default=7, help="kWTA hyperparameter") + parser.add_argument( + "--nconcepts", type=int, default=30, help="number of concepts" + ) + parser.add_argument( + "--h_sparsity", + type=int, + default=7, + help="kWTA hyperparameter", + ) # Added for weak supervision parser.add_argument( @@ -120,7 +146,10 @@ def get_senn_parser(): help="parameter for learning h [default: 1-e4]", ) parser.add_argument( - "--info_hypara", type=float, default=0.5, help="hyperparameter of info loss" + "--info_hypara", + type=float, + default=0.5, + help="hyperparameter of info loss", ) # newly added # Parametrizing Function (Theta) @@ -141,7 +170,14 @@ def get_senn_parser(): type=str, default="simple", help="Parametrizer architecture", - choices=["simple", "alexnet", "vgg8", "vgg11_bn", "vgg11", "vgg16"], + choices=[ + "simple", + "alexnet", + "vgg8", + "vgg11_bn", + "vgg11", + "vgg16", + ], ) parser.add_argument( "--theta_dim", @@ -164,10 +200,16 @@ def get_senn_parser(): ### Learning parser.add_argument( - "--opt", type=str, default="adam", help="optim method [default: adam]" + "--opt", + type=str, + default="adam", + help="optim method [default: adam]", ) parser.add_argument( - "--lr", type=float, default=0.001, help="initial learning rate [default: 0.001]" + "--lr", + type=float, + default=0.001, + help="initial learning rate [default: 0.001]", ) parser.add_argument( "--epochs", @@ -201,16 +243,25 @@ def get_senn_parser(): ### Data --- FIXME: Not used yet. Maybe use to avoid duplication of main scripts for similar tasks (e.g. MNIST, CIFAR) parser.add_argument( - "--dataset", default="pathology", help="choose which dataset to run on" + "--dataset", + default="pathology", + help="choose which dataset to run on", + ) + parser.add_argument( + "--embedding", + default="pathology", + help="choose what embeddings to use", ) parser.add_argument( - "--embedding", default="pathology", help="choose what embeddings to use" + "--nclasses", type=int, default=2, help="number of classes" ) - parser.add_argument("--nclasses", type=int, default=2, help="number of classes") ### Misc parser.add_argument( - "--num_workers", type=int, default=30, help="num workers for data loader" + "--num_workers", + type=int, + default=30, + help="num workers for data loader", ) parser.add_argument( "--print_freq", @@ -219,7 +270,10 @@ def get_senn_parser(): help="print frequency during train (in batches)", ) parser.add_argument( - "--debug", action="store_true", default=False, help="debug mode" + "--debug", + action="store_true", + default=False, + help="debug mode", ) return parser @@ -252,23 +306,40 @@ def parse_args(): # device parser.add_argument( - "--cuda", action="store_true", default=False, help="enable the gpu" + "--cuda", + action="store_true", + default=False, + help="enable the gpu", + ) + parser.add_argument( + "--num_gpus", type=int, default=1, help="Num GPUs to use." 
) - parser.add_argument("--num_gpus", type=int, default=1, help="Num GPUs to use.") parser.add_argument( - "--debug", action="store_true", default=False, help="debug mode" + "--debug", + action="store_true", + default=False, + help="debug mode", ) # learning parser.add_argument( - "--opt", type=str, default="adam", help="optim method [default: adam]" + "--opt", + type=str, + default="adam", + help="optim method [default: adam]", ) parser.add_argument( - "--lr", type=float, default=0.001, help="initial learning rate [default: 0.001]" + "--lr", + type=float, + default=0.001, + help="initial learning rate [default: 0.001]", ) parser.add_argument( - "--epochs", type=int, default=3, help="number of epochs for train [default: 10]" + "--epochs", + type=int, + default=3, + help="number of epochs for train [default: 10]", ) parser.add_argument( "--batch_size", @@ -284,7 +355,10 @@ def parse_args(): # pathsn parser.add_argument( - "--model_path", type=str, default="models", help="where to save the snapshot" + "--model_path", + type=str, + default="models", + help="where to save the snapshot", ) parser.add_argument( "--results_path", @@ -327,12 +401,20 @@ def parse_args(): # parser.add_argument('--learn_h', type='str', default='learnt', help='type of conceptizer (learnt or input)' ) parser.add_argument( - "--concept_dim", type=int, default=1, help="concept dimension. dont change" + "--concept_dim", + type=int, + default=1, + help="concept dimension. dont change", ) parser.add_argument( - "--nconcepts_labeled", type=int, default=6, help="number of labeled concepts" + "--nconcepts_labeled", + type=int, + default=6, + help="number of labeled concepts", ) # newly added - parser.add_argument("--nconcepts", type=int, default=20, help="number of concepts") + parser.add_argument( + "--nconcepts", type=int, default=20, help="number of concepts" + ) parser.add_argument( "--nobias", action="store_true", @@ -340,7 +422,12 @@ def parse_args(): help="do not add a bias term theta_0", ) - parser.add_argument("--h_sparsity", type=int, default=7, help="kWTA hyperparameter") + parser.add_argument( + "--h_sparsity", + type=int, + default=7, + help="kWTA hyperparameter", + ) parser.add_argument( "--positive_theta", @@ -381,16 +468,25 @@ def parse_args(): # data parser.add_argument( - "--dataset", default="pathology", help="choose which dataset to run on" + "--dataset", + default="pathology", + help="choose which dataset to run on", ) parser.add_argument( - "--embedding", default="pathology", help="choose what embeddings to use" + "--embedding", + default="pathology", + help="choose what embeddings to use", + ) + parser.add_argument( + "--nclasses", type=int, default=2, help="number of classes" ) - parser.add_argument("--nclasses", type=int, default=2, help="number of classes") # data loading parser.add_argument( - "--num_workers", type=int, default=4, help="num workers for data loader" + "--num_workers", + type=int, + default=4, + help="num workers for data loader", ) # misc diff --git a/BDD_OIA/SENN/classifier.py b/BDD_OIA/SENN/classifier.py index 52e2cb2..c52d0d8 100644 --- a/BDD_OIA/SENN/classifier.py +++ b/BDD_OIA/SENN/classifier.py @@ -18,25 +18,29 @@ from __future__ import absolute_import, division, unicode_literals -import time -import numpy as np import copy +import time +import numpy as np import torch +import torch.optim as optim from torch import nn from torch.autograd import Variable -import torch.optim as optim -from .utils import AverageMeter from .models import SENN_FFFC +from .utils import AverageMeter 
def tv_reg_loss(model): LAMBDA = 1e-6 params = model.params.view(-1, 10, 28, 28) reg_loss = LAMBDA * ( - torch.sum(torch.abs(params[:, :, :, :-1] - params[:, :, :, 1:])) - + torch.sum(torch.abs(params[:, :, :-1, :] - params[:, :, 1:, :])) + torch.sum( + torch.abs(params[:, :, :, :-1] - params[:, :, :, 1:]) + ) + + torch.sum( + torch.abs(params[:, :, :-1, :] - params[:, :, 1:, :]) + ) ) return reg_loss @@ -82,7 +86,9 @@ def __init__( self.maxepoch = maxepoch self.print_freq = print_freq - def prepare_split(self, X, y, validation_data=None, validation_split=None): + def prepare_split( + self, X, y, validation_data=None, validation_split=None + ): # Preparing validation data assert validation_split or validation_data if validation_data is not None: @@ -166,7 +172,9 @@ def trainepoch(self, X, y, nepoches=1): all_costs = [] for i in range(0, len(X), self.batch_size): # forward - idx = torch.LongTensor(permutation[i : i + self.batch_size]) + idx = torch.LongTensor( + permutation[i : i + self.batch_size] + ) if isinstance(X, torch.cuda.FloatTensor): idx = idx.cuda() inputs = Variable(X.index_select(0, idx)) @@ -191,7 +199,9 @@ def trainepoch(self, X, y, nepoches=1): # Update parameters self.optimizer.step() # measure accuracy and record loss - prec1, prec5 = self.accuracy(outputs.data, targets.data, topk=(1, 5)) + prec1, prec5 = self.accuracy( + outputs.data, targets.data, topk=(1, 5) + ) losses.update(loss.data[0], inputs.size(0)) top1.update(prec1[0], inputs.size(0)) top5.update(prec5[0], inputs.size(0)) @@ -225,13 +235,18 @@ def score(self, devX, devy): self.model.eval() correct = 0 if self.cuda and ( - not isinstance(devX, torch.cuda.FloatTensor) or self.cudaEfficient + not isinstance(devX, torch.cuda.FloatTensor) + or self.cudaEfficient ): devX = torch.FloatTensor(devX).cuda() devy = torch.LongTensor(devy).cuda() for i in range(0, len(devX), self.batch_size): - Xbatch = Variable(devX[i : i + self.batch_size], volatile=True) - ybatch = Variable(devy[i : i + self.batch_size], volatile=True) + Xbatch = Variable( + devX[i : i + self.batch_size], volatile=True + ) + ybatch = Variable( + devy[i : i + self.batch_size], volatile=True + ) if self.cudaEfficient: Xbatch = Xbatch.cuda() ybatch = ybatch.cuda() @@ -247,9 +262,13 @@ def predict(self, devX): devX = torch.FloatTensor(devX).cuda() yhat = np.array([]) for i in range(0, len(devX), self.batch_size): - Xbatch = Variable(devX[i : i + self.batch_size], volatile=True) + Xbatch = Variable( + devX[i : i + self.batch_size], volatile=True + ) output = self.model(Xbatch) - yhat = np.append(yhat, output.data.max(1)[1].cpu().numpy()) + yhat = np.append( + yhat, output.data.max(1)[1].cpu().numpy() + ) yhat = np.vstack(yhat) return yhat @@ -257,12 +276,16 @@ def predict_proba(self, devX): self.model.eval() probas = [] for i in range(0, len(devX), self.batch_size): - Xbatch = Variable(devX[i : i + self.batch_size], volatile=True) + Xbatch = Variable( + devX[i : i + self.batch_size], volatile=True + ) if not probas: probas = self.model(Xbatch).data.cpu().numpy() else: probas = np.concatenate( - probas, self.model(Xbatch).data.cpu().numpy(), axis=0 + probas, + self.model(Xbatch).data.cpu().numpy(), + axis=0, ) return probas @@ -275,7 +298,9 @@ def accuracy(self, output, target, topk=(1,)): correct = pred.eq(target.view(1, -1).expand_as(pred)) res = [] for k in topk: - correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) + correct_k = ( + correct[:k].view(-1).float().sum(0, keepdim=True) + ) res.append(correct_k.mul_(100.0 / batch_size)) return res @@ 
-313,7 +338,13 @@ def __init__( cudaEfficient=False, ): super(self.__class__, self).__init__( - inputdim, nclasses, l2reg, batch_size, seed, cuda, cudaEfficient + inputdim, + nclasses, + l2reg, + batch_size, + seed, + cuda, + cudaEfficient, ) self.cuda = cuda self.model = nn.Sequential( @@ -323,7 +354,9 @@ def __init__( self.model.cuda() self.loss_fn = nn.CrossEntropyLoss().cuda() self.loss_fn.size_average = False - self.optimizer = optim.Adam(self.model.parameters(), weight_decay=self.l2reg) + self.optimizer = optim.Adam( + self.model.parameters(), weight_decay=self.l2reg + ) class MLP(PyTorchClassifier): @@ -376,7 +409,9 @@ def __init__( self.loss_fn = nn.CrossEntropyLoss().cuda() self.loss_fn.size_average = False - self.optimizer = optim.Adam(self.model.parameters(), weight_decay=self.l2reg) + self.optimizer = optim.Adam( + self.model.parameters(), weight_decay=self.l2reg + ) class SENN_MLP(PyTorchClassifier): @@ -419,7 +454,9 @@ def __init__( nn.Linear(self.hiddendim, self.nclasses), ) - self.model = SENN_FFFC(self.inputdim, self.hiddendim, self.nclasses) + self.model = SENN_FFFC( + self.inputdim, self.hiddendim, self.nclasses + ) self.cuda = cuda @@ -433,4 +470,6 @@ def __init__( elif regularization.lower() == "tv": self.reg_loss = tv_reg_loss - self.optimizer = optim.Adam(self.model.parameters(), weight_decay=self.l2reg) + self.optimizer = optim.Adam( + self.model.parameters(), weight_decay=self.l2reg + ) diff --git a/BDD_OIA/SENN/eval_utils.py b/BDD_OIA/SENN/eval_utils.py index f0521b9..e72089e 100644 --- a/BDD_OIA/SENN/eval_utils.py +++ b/BDD_OIA/SENN/eval_utils.py @@ -15,16 +15,17 @@ You should have received a copy of the GNU General Public License along with this program. If not, see . """ -import sys import pdb +import sys + import numpy as np -from tqdm import tqdm import torch -from torch.autograd import Variable -from torch.utils.data import TensorDataset, DataLoader import torch.nn as nn import torch.nn.functional as F import torch.optim as optim +from torch.autograd import Variable +from torch.utils.data import DataLoader, TensorDataset +from tqdm import tqdm # =============================================================================== # ============= CONTINUOUS SPACE VERSIONS =========================== @@ -80,7 +81,9 @@ def local_lipschitz_estimate( z = Variable(torch.randn(x.size()), requires_grad=True) if cuda: z = z.cuda() - progress_string = "\rStep: {:8}/{:8} L:{:5.2f} Improv.:{:6.2f}" + progress_string = ( + "\rStep: {:8}/{:8} L:{:5.2f} Improv.:{:6.2f}" + ) if mode == 1: # fx = f(x).detach() @@ -115,7 +118,9 @@ def local_lipschitz_estimate( fz = f.thetas dist_f = (fz - fx).norm() dist_x = (z - x).norm() - loss = dist_x / dist_f # Want to maximize d_f/d_x (reciprocal) + loss = ( + dist_x / dist_f + ) # Want to maximize d_f/d_x (reciprocal) else: _ = f(z) fz = f.thetas @@ -149,7 +154,14 @@ def local_lipschitz_estimate( if i % log_interval == 0: if eps is not None: - prog_list = [i, maxit, loss.data[0], lip, dist, improvements[-1]] + prog_list = [ + i, + maxit, + loss.data[0], + lip, + dist, + improvements[-1], + ] else: prog_list = [i, maxit, lip, improvements[-1]] @@ -179,7 +191,11 @@ def local_lipschitz_estimate( print("Estimated Lipschitz constant: {:8.2f}".format(lip)) if eps is not None and verbose: if mode == 1: - print("|| x - z || = {:8.2f} < {:8.2f}".format((z - x).norm().data[0], eps)) + print( + "|| x - z || = {:8.2f} < {:8.2f}".format( + (z - x).norm().data[0], eps + ) + ) else: print( "|| g(x) - g(z) || = {:8.2f} < {:8.2f}".format( @@ -217,7 +233,9 @@ 
def estimate_dataset_lipschitz( inputs = inputs.cuda() # print(inputs.size()) # print(asd.asd) - inputs = Variable(inputs) # targets = Variable(inputs), Variable(targets) + inputs = Variable( + inputs + ) # targets = Variable(inputs), Variable(targets) l, _ = local_lipschitz_estimate( model, inputs, @@ -320,7 +338,9 @@ def sample_local_lipschitz( if max_distance is not None: # Distances above threshold: make them inf # print((denom_dists > max_distance).size()) - nonzero = torch.nonzero((denom_dists > max_distance).data).size(0) + nonzero = torch.nonzero( + (denom_dists > max_distance).data + ).size(0) total = denom_dists.size(0) ** 2 print( "Number of zero denom distances: {} ({:4.2f}%)".format( @@ -333,7 +353,10 @@ def sample_local_lipschitz( ratios = (num_dists / denom_dists).data argmaxes = {k: [] for k in range(n)} vals, inds = ratios.topk(top_k, 1, True, True) - argmaxes = {i: [(j, v) for (j, v) in zip(inds[i, :], vals[i, :])] for i in range(n)} + argmaxes = { + i: [(j, v) for (j, v) in zip(inds[i, :], vals[i, :])] + for i in range(n) + } return vals[:, 0].numpy(), argmaxes # @@ -390,7 +413,9 @@ def lipschitz_ratio(model, x, y, Th_x=None, mode=1): return ratio, num, denom -def find_maximum_lipschitz_dataset(model, dataset, top_k=1, max_distance=None): +def find_maximum_lipschitz_dataset( + model, dataset, top_k=1, max_distance=None +): """ Find pair of points x and y in dataset that maximize relative variation of model diff --git a/BDD_OIA/SENN/utils.py b/BDD_OIA/SENN/utils.py index 5c28bbf..3a1b7ea 100644 --- a/BDD_OIA/SENN/utils.py +++ b/BDD_OIA/SENN/utils.py @@ -18,15 +18,14 @@ import os import pdb -import numpy as np - -import matplotlib.pyplot as plt import pprint # For feature explainer +import matplotlib.gridspec as gridspec +import matplotlib.pyplot as plt +import numpy as np import torch from torch.autograd import Variable from torchvision.utils import make_grid -import matplotlib.gridspec as gridspec pp = pprint.PrettyPrinter(indent=4) @@ -104,9 +103,17 @@ def animate_training(Steps, Cs, X_train, y_train): cm = plt.cm.RdBu cm_bright = ListedColormap(["#FF0000", "#0000FF"]) h = 0.02 # step size in the mesh - x_min, x_max = X_train[:, 0].min() - 0.5, X_train[:, 0].max() + 0.5 - y_min, y_max = X_train[:, 1].min() - 0.5, X_train[:, 1].max() + 0.5 - xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h)) + x_min, x_max = ( + X_train[:, 0].min() - 0.5, + X_train[:, 0].max() + 0.5, + ) + y_min, y_max = ( + X_train[:, 1].min() - 0.5, + X_train[:, 1].max() + 0.5, + ) + xx, yy = np.meshgrid( + np.arange(x_min, x_max, h), np.arange(y_min, y_max, h) + ) fig = plt.figure() ax = fig.add_subplot(111) @@ -114,7 +121,13 @@ def animate_training(Steps, Cs, X_train, y_train): (line2,) = ax.plot([], [], "--") ax.set_xlim(np.min(xx), np.max(xx)) ax.set_xlim(np.min(yy), np.max(yy)) - ax.scatter(X_train[:, 0], X_train[:, 1], c=y_train, cmap=cm_bright, edgecolors="k") + ax.scatter( + X_train[:, 0], + X_train[:, 1], + c=y_train, + cmap=cm_bright, + edgecolors="k", + ) # ax.contourf(xx, yy, Cs[0].reshape(xx.shape)) anim = animation.FuncAnimation( @@ -150,7 +163,9 @@ def make_meshgrid(x, y, h=0.02): """ x_min, x_max = x.min() - 1, x.max() + 1 y_min, y_max = y.min() - 1, y.max() + 1 - xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h)) + xx, yy = np.meshgrid( + np.arange(x_min, x_max, h), np.arange(y_min, y_max, h) + ) return xx, yy @@ -200,7 +215,9 @@ def plot_embedding(X, y, Xp, title=None): continue shown_images = np.r_[shown_images, [Xp[i]]] imagebox = 
offsetbox.AnnotationBbox( - offsetbox.OffsetImage(X[i, :].reshape(28, 28), cmap=plt.cm.gray_r), + offsetbox.OffsetImage( + X[i, :].reshape(28, 28), cmap=plt.cm.gray_r + ), Xp[i], ) ax.add_artist(imagebox) @@ -209,7 +226,9 @@ def plot_embedding(X, y, Xp, title=None): plt.title(title) -def _explain_class(model, x_raw, x, k, typ="pos", thresh=0.5, recompute=True): +def _explain_class( + model, x_raw, x, k, typ="pos", thresh=0.5, recompute=True +): """ Given an input x and class index k, explain f(x) by returning indices of features in x that have highest positive impact on predicting class k. @@ -218,9 +237,13 @@ def _explain_class(model, x_raw, x, k, typ="pos", thresh=0.5, recompute=True): y = model(x) # y = self.model(x) B_k = model.params[0, k, :].data.numpy() if typ == "pos": - Mask = (B_k > thresh).astype(np.int).reshape(x.size()).squeeze() + Mask = ( + (B_k > thresh).astype(np.int).reshape(x.size()).squeeze() + ) elif typ == "neg": - Mask = (B_k < -thresh).astype(np.int).reshape(x.size()).squeeze() + Mask = ( + (B_k < -thresh).astype(np.int).reshape(x.size()).squeeze() + ) else: # Return weights instead of mask return B_k.reshape(x.size()).squeeze() @@ -241,7 +264,10 @@ def explain_digit(model, x_raw, x, thresh=0.5, save_path=None): plt.yticks([]) if save_path: plt.savefig( - save_path + "_input.pdf", bbox_inches="tight", format="pdf", dpi=300 + save_path + "_input.pdf", + bbox_inches="tight", + format="pdf", + dpi=300, ) plt.show() y_pred = model(x) @@ -249,13 +275,21 @@ def explain_digit(model, x_raw, x, thresh=0.5, save_path=None): pred_class = np.argmax(y_pred.data.numpy()) print("Predicted: ", pred_class) - fig, ax = plt.subplots(3, model.dout, figsize=(1.5 * model.dout, 1.5 * 3)) + fig, ax = plt.subplots( + 3, model.dout, figsize=(1.5 * model.dout, 1.5 * 3) + ) for i in range(model.dout): # print('Class {}:'.format(i)) # Positive x_imask = _explain_class( - model, x_raw, x, i, typ="pos", recompute=False, thresh=thresh + model, + x_raw, + x, + i, + typ="pos", + recompute=False, + thresh=thresh, ) ax[0, i].imshow(x_imask) ax[0, i].set_xticks([]) @@ -264,14 +298,22 @@ def explain_digit(model, x_raw, x, thresh=0.5, save_path=None): # Negative x_imask = _explain_class( - model, x_raw, x, i, typ="neg", recompute=False, thresh=thresh + model, + x_raw, + x, + i, + typ="neg", + recompute=False, + thresh=thresh, ) ax[1, i].imshow(x_imask) ax[1, i].set_xticks([]) ax[1, i].set_yticks([]) # Combined - x_imask = _explain_class(model, x_raw, x, i, typ="both", recompute=False) + x_imask = _explain_class( + model, x_raw, x, i, typ="both", recompute=False + ) ax[2, i].imshow(x_imask, cmap=plt.cm.RdBu) ax[2, i].set_xticks([]) ax[2, i].set_yticks([]) @@ -284,7 +326,12 @@ def explain_digit(model, x_raw, x, thresh=0.5, save_path=None): ax[2, 0].set_ylabel("Combined") if save_path: - plt.savefig(save_path + "_expl.pdf", bbox_inches="tight", format="pdf", dpi=300) + plt.savefig( + save_path + "_expl.pdf", + bbox_inches="tight", + format="pdf", + dpi=300, + ) plt.show() @@ -344,7 +391,12 @@ def plot_text_explanation(words, values, n_cols=6, save_path=None): plt.axis("off") plt.grid("off") if save_path: - plt.savefig(save_path + "_expl.pdf", bbox_inches="tight", format="pdf", dpi=300) + plt.savefig( + save_path + "_expl.pdf", + bbox_inches="tight", + format="pdf", + dpi=300, + ) plt.show() @@ -358,7 +410,13 @@ class FeatureInput_Explainer: """ - def __init__(self, feature_names, binary=False, sort_rows=True, scale_values=True): + def __init__( + self, + feature_names, + binary=False, + sort_rows=True, + 
scale_values=True, + ): super(FeatureInput_Explainer, self).__init__() self.features = feature_names self.binary = binary # Whether it is a binary classif task @@ -380,8 +438,12 @@ def explain(self, model, x, thresh=0.5, save_path=None): # Get data-dependent params B = model.thetas[0, :, :].data.numpy() # class x feats - Pos_Mask = (B > thresh).astype(np.int) # .reshape(x.size()).squeeze() - Neg_Mask = (B < thresh).astype(np.int) # .reshape(x.size()).squeeze() + Pos_Mask = (B > thresh).astype( + np.int + ) # .reshape(x.size()).squeeze() + Neg_Mask = (B < thresh).astype( + np.int + ) # .reshape(x.size()).squeeze() title = r"Relevance Score $\theta(x)$" + ( " (Scaled)" if self.scale_values else "" @@ -391,7 +453,10 @@ def explain(self, model, x, thresh=0.5, save_path=None): zip(self.features, B[:, 0]) ) # Change to B[0,:] when B model is truly binary A = plot_dependencies( - d, title=title, scale_values=self.scale_values, sort_rows=self.sort_rows + d, + title=title, + scale_values=self.scale_values, + sort_rows=self.sort_rows, ) else: Pos_Feats = {} @@ -405,20 +470,33 @@ def explain(self, model, x, thresh=0.5, save_path=None): scale_values=self.scale_values, sort_rows=self.sort_rows, ) - Neg_Feats = list(compress(self.features, B[k, :] < -thresh)) - Pos_Feats = list(compress(self.features, B[k, :] > thresh)) + Neg_Feats = list( + compress(self.features, B[k, :] < -thresh) + ) + Pos_Feats = list( + compress(self.features, B[k, :] > thresh) + ) print( "Class:{:5} Neg: {}, Pos: {}".format( k, ",".join(Neg_Feats), ",".join(Pos_Feats) ) ) if save_path: - plt.savefig(save_path, bbox_inches="tight", format="pdf", dpi=300) + plt.savefig( + save_path, bbox_inches="tight", format="pdf", dpi=300 + ) plt.show() print("-" * 60) def _explain_class( - self, x_raw, x, k, typ="pos", feat_names=None, thresh=0.5, recompute=True + self, + x_raw, + x, + k, + typ="pos", + feat_names=None, + thresh=0.5, + recompute=True, ): """ Given an input x and class index k, explain f(x) by returning indices of @@ -435,9 +513,19 @@ def _explain_class( elif feat_names and typ == "neg": return list(compress(feat_names, B_k < thresh)) if typ == "pos": - Mask = (B_k > thresh).astype(np.int).reshape(x.size()).squeeze() + Mask = ( + (B_k > thresh) + .astype(np.int) + .reshape(x.size()) + .squeeze() + ) elif typ == "neg": - Mask = (B_k < -thresh).astype(np.int).reshape(x.size()).squeeze() + Mask = ( + (B_k < -thresh) + .astype(np.int) + .reshape(x.size()) + .squeeze() + ) else: # Return weights instead of mask return B_k.reshape(x.size()).squeeze() @@ -482,7 +570,9 @@ def plot_dependencies( # get maximum maximum_value = np.absolute(np.array(coefficient_values)).max() if scale_values: - coefficient_values = (np.array(coefficient_values) / maximum_value) * 100 + coefficient_values = ( + np.array(coefficient_values) / maximum_value + ) * 100 if sort_rows: index_sorted = np.argsort(np.array(coefficient_values)) @@ -490,7 +580,9 @@ def plot_dependencies( index_sorted = range(len(coefficient_values))[::-1] sorted_column_names = list(np.array(column_names)[index_sorted]) - sorted_column_values = list(np.array(coefficient_values)[index_sorted]) + sorted_column_values = list( + np.array(coefficient_values)[index_sorted] + ) pos = np.arange(len(sorted_column_values)) + 0.7 # rearrange this at some other point. @@ -504,7 +596,10 @@ def assign_colors_to_bars( # if you want the colors to be reversed for positive # and negative influences. 
if reverse: - pos_influence, negative_influence = (negative_influence, pos_influence) + pos_influence, negative_influence = ( + negative_influence, + pos_influence, + ) # could rewrite this as a lambda function # but I understand this better @@ -517,7 +612,9 @@ def map_x(x): bar_colors = list(map(map_x, array_values)) return bar_colors - bar_colors = assign_colors_to_bars(coefficient_values, reverse=True) + bar_colors = assign_colors_to_bars( + coefficient_values, reverse=True + ) bar_colors = list(np.array(bar_colors)[index_sorted]) # pdb.set_trace() @@ -528,7 +625,9 @@ def map_x(x): fig, axes = plt.subplots(1, 2, figsize=fig_size) ax_table, ax = axes - ax.barh(pos, sorted_column_values, align="center", color=bar_colors) + ax.barh( + pos, sorted_column_values, align="center", color=bar_colors + ) ax.set_yticks(pos) ax.set_yticklabels(sorted_column_names) if scale_values: @@ -570,7 +669,12 @@ def map_x(x): def plot_theta_stability( - model, input, pert_type="gauss", noise_level=0.5, samples=5, save_path=None + model, + input, + pert_type="gauss", + noise_level=0.5, + samples=5, + save_path=None, ): """Test stability of relevance scores theta for perturbations of an input. @@ -604,7 +708,9 @@ def gauss_perturbation(x, scale=1): for i in range(samples): inputs.append(gauss_perturbation(input, scale=noise_level)) - fig, ax = plt.subplots(2, len(inputs), figsize=(2 * len(inputs), 1.5 * 3)) + fig, ax = plt.subplots( + 2, len(inputs), figsize=(2 * len(inputs), 1.5 * 3) + ) # Map Them thetas = [] @@ -632,7 +738,9 @@ def gauss_perturbation(x, scale=1): thetas.append(deps) classes = ["C" + str(i) for i in range(theta.shape[0])] d = dict(zip(classes, deps)) - A = plot_dependencies(d, title="Dependencies", sort_rows=False, ax=ax[1, i]) + A = plot_dependencies( + d, title="Dependencies", sort_rows=False, ax=ax[1, i] + ) # ax[1,i].locator_params(axis = 'y', nbins=10) # max_yticks = 10 @@ -646,7 +754,9 @@ def gauss_perturbation(x, scale=1): plt.tight_layout() # print(dists.max()) if save_path: - plt.savefig(save_path, bbox_inches="tight", format="pdf", dpi=300) + plt.savefig( + save_path, bbox_inches="tight", format="pdf", dpi=300 + ) # plt.show(block=False) @@ -675,7 +785,9 @@ def concept_grid( num_concepts = model.parametrizer.nconcept concept_dim = model.parametrizer.dout - top_activations = {k: np.array(top_k * [-1000.00]) for k in range(num_concepts)} + top_activations = { + k: np.array(top_k * [-1000.00]) for k in range(num_concepts) + } top_examples = {k: top_k * [None] for k in range(num_concepts)} all_activs = [] for idx, (data, target, _) in enumerate(data_loader): @@ -687,7 +799,9 @@ def concept_grid( data, target = data.cuda(), target.cuda() data, target = Variable(data), Variable(target) """ - data, target = Variable(data).to(device), Variable(target).to(device) + data, target = Variable(data).to(device), Variable(target).to( + device + ) # data, target = Variable(data, volatile=True), Variable(target) pretrained_out = pretrained_model(data) @@ -709,8 +823,12 @@ def concept_grid( # break all_activs = torch.cat(all_activs) - top_activations, top_idxs = torch.topk(all_activs, int(top_k / 2), 0) - low_activations, low_idxs = torch.topk(-all_activs, int(top_k / 2), 0) + top_activations, top_idxs = torch.topk( + all_activs, int(top_k / 2), 0 + ) + low_activations, low_idxs = torch.topk( + -all_activs, int(top_k / 2), 0 + ) top_activations = top_activations.squeeze().t() low_activations = low_activations.squeeze().t() top_idxs = top_idxs.squeeze().t() @@ -719,20 +837,28 @@ def concept_grid( 
top_idxs = torch.cat([top_idxs, low_idxs], dim=1) for i in range(num_concepts): buf = data_loader.dataset[top_idxs[i][0]][0] - buf = buf.reshape([1, buf.shape[1], buf.shape[2], buf.shape[0]]) + buf = buf.reshape( + [1, buf.shape[1], buf.shape[2], buf.shape[0]] + ) for k in range(3): - buf[:, :, :, k] = data_loader.dataset[top_idxs[i][0]][0][k] - buf[:, :, :, k] = (buf[:, :, :, k] - buf[:, :, :, k].min()) / ( - buf[:, :, :, k].max() - buf[:, :, :, k].min() - ) + buf[:, :, :, k] = data_loader.dataset[top_idxs[i][0]][0][ + k + ] + buf[:, :, :, k] = ( + buf[:, :, :, k] - buf[:, :, :, k].min() + ) / (buf[:, :, :, k].max() - buf[:, :, :, k].min()) for j in range(1, top_k): buf2 = data_loader.dataset[top_idxs[i][j]][0] - buf2 = buf2.reshape([1, buf2.shape[1], buf2.shape[2], buf2.shape[0]]) + buf2 = buf2.reshape( + [1, buf2.shape[1], buf2.shape[2], buf2.shape[0]] + ) for k in range(3): - buf2[:, :, :, k] = data_loader.dataset[top_idxs[i][j]][0][k] - buf2[:, :, :, k] = (buf2[:, :, :, k] - buf2[:, :, :, k].min()) / ( - buf2[:, :, :, k].max() - buf2[:, :, :, k].min() - ) + buf2[:, :, :, k] = data_loader.dataset[ + top_idxs[i][j] + ][0][k] + buf2[:, :, :, k] = ( + buf2[:, :, :, k] - buf2[:, :, :, k].min() + ) / (buf2[:, :, :, k].max() - buf2[:, :, :, k].min()) buf = np.concatenate([buf, buf2], axis=0) top_examples[i] = buf @@ -766,7 +892,9 @@ def concept_grid( num_rows = top_k figsize = (1.4 * num_cols, num_rows) - fig, axes = plt.subplots(figsize=figsize, nrows=num_rows, ncols=num_cols) + fig, axes = plt.subplots( + figsize=figsize, nrows=num_rows, ncols=num_cols + ) for i in range(num_concepts): for j in range(top_k): @@ -776,11 +904,15 @@ def concept_grid( # print(i,j) # print(top_examples[i][j].shape) # axes[pos].imshow(top_examples[i][j], cmap='Greys', interpolation='nearest') - axes[pos].imshow(top_examples[i][j], interpolation="nearest") + axes[pos].imshow( + top_examples[i][j], interpolation="nearest" + ) if layout == "vertical": axes[pos].axis("off") if j == 0: - axes[pos].set_title("Cpt {}".format(i + 1), fontsize=24) + axes[pos].set_title( + "Cpt {}".format(i + 1), fontsize=24 + ) else: axes[pos].set_xticklabels([]) axes[pos].set_yticklabels([]) @@ -791,7 +923,9 @@ def concept_grid( if i == 0: axes[pos].set_title("Proto {}".format(j + 1)) if j == 0: - axes[pos].set_ylabel("Concept {}".format(i + 1), rotation=90) + axes[pos].set_ylabel( + "Concept {}".format(i + 1), rotation=90 + ) print("Done") @@ -811,7 +945,9 @@ def concept_grid( fig.subplots_adjust(wspace=0.1, hspace=0.01) if save_path is not None: - plt.savefig(save_path, bbox_inches="tight", format="pdf", dpi=300) + plt.savefig( + save_path, bbox_inches="tight", format="pdf", dpi=300 + ) plt.show() if return_fig: return fig, axes @@ -828,7 +964,9 @@ def plot_prob_drop(attribs, prob_drop, save_path=None): color1 = "#377eb8" ax1.bar(ind + width + 0.35, attribs, 0.45, color=color1) - ax1.set_ylabel(r"Feature Relevance $\theta(x)_i$", color=color1, fontsize=14) + ax1.set_ylabel( + r"Feature Relevance $\theta(x)_i$", color=color1, fontsize=14 + ) # ax1.set_ylim(-1,1) ax1.set_xlabel("Feature") ax1.tick_params(axis="y", colors=color1) @@ -836,7 +974,13 @@ def plot_prob_drop(attribs, prob_drop, save_path=None): color2 = "#ff7f00" ax2 = ax1.twinx() ax2.ticklabel_format(style="sci", scilimits=(-2, 2), axis="y") - ax2.plot(ind + width + 0.35, prob_drop, "bo", linestyle="dashed", color=color2) + ax2.plot( + ind + width + 0.35, + prob_drop, + "bo", + linestyle="dashed", + color=color2, + ) ax2.set_ylabel("Probability Drop", color=color2, 
fontsize=14) ax2.tick_params(axis="y", colors=color2) @@ -845,7 +989,9 @@ def plot_prob_drop(attribs, prob_drop, save_path=None): fig.tight_layout() if save_path: - plt.savefig(save_path, bbox_inches="tight", format="pdf", dpi=300) + plt.savefig( + save_path, bbox_inches="tight", format="pdf", dpi=300 + ) plt.show() diff --git a/BDD_OIA/aggregators.py b/BDD_OIA/aggregators.py index 79a84fd..9747865 100644 --- a/BDD_OIA/aggregators.py +++ b/BDD_OIA/aggregators.py @@ -3,8 +3,9 @@ Aggregator """ -import sys import pdb +import sys + import torch import torch.nn as nn import torch.nn.functional as F @@ -53,7 +54,9 @@ def forward: """ def forward(self, H, Th): - assert H.size(-1) == 1, "Concept h_i should be scalar, not vector sized" + assert ( + H.size(-1) == 1 + ), "Concept h_i should be scalar, not vector sized" buf = torch.reshape(H, (H.data.shape[0], H.data.shape[1])) # if self.binary is true, output activation uses sigmoid, otherwise log_softmax if self.binary: @@ -109,8 +112,12 @@ def forward: """ def forward(self, H, Th): - assert H.size(-2) == Th.size(-2), "Number of concepts in H and Th don't match" - assert H.size(-1) == 1, "Concept h_i should be scalar, not vector sized" + assert H.size(-2) == Th.size( + -2 + ), "Number of concepts in H and Th don't match" + assert ( + H.size(-1) == 1 + ), "Concept h_i should be scalar, not vector sized" assert Th.size(-1) == self.nclasses, "Wrong Theta size" combined = torch.bmm(Th.transpose(1, 2), H).squeeze(dim=-1) diff --git a/BDD_OIA/aggregators_BDD.py b/BDD_OIA/aggregators_BDD.py index 79a84fd..9747865 100644 --- a/BDD_OIA/aggregators_BDD.py +++ b/BDD_OIA/aggregators_BDD.py @@ -3,8 +3,9 @@ Aggregator """ -import sys import pdb +import sys + import torch import torch.nn as nn import torch.nn.functional as F @@ -53,7 +54,9 @@ def forward: """ def forward(self, H, Th): - assert H.size(-1) == 1, "Concept h_i should be scalar, not vector sized" + assert ( + H.size(-1) == 1 + ), "Concept h_i should be scalar, not vector sized" buf = torch.reshape(H, (H.data.shape[0], H.data.shape[1])) # if self.binary is true, output activation uses sigmoid, otherwise log_softmax if self.binary: @@ -109,8 +112,12 @@ def forward: """ def forward(self, H, Th): - assert H.size(-2) == Th.size(-2), "Number of concepts in H and Th don't match" - assert H.size(-1) == 1, "Concept h_i should be scalar, not vector sized" + assert H.size(-2) == Th.size( + -2 + ), "Number of concepts in H and Th don't match" + assert ( + H.size(-1) == 1 + ), "Concept h_i should be scalar, not vector sized" assert Th.size(-1) == self.nclasses, "Wrong Theta size" combined = torch.bmm(Th.transpose(1, 2), H).squeeze(dim=-1) diff --git a/BDD_OIA/conceptizers_BDD.py b/BDD_OIA/conceptizers_BDD.py index 06312b1..c5f9dfb 100644 --- a/BDD_OIA/conceptizers_BDD.py +++ b/BDD_OIA/conceptizers_BDD.py @@ -2,16 +2,14 @@ """ Conceptizers """ -import sys import pdb -import numpy as np +import sys +import numpy as np import torch import torch.nn as nn import torch.nn.functional as F - -from torch.autograd import Function -from torch.autograd import Variable +from torch.autograd import Function, Variable # =============================================================================== # ======================= MODELS FOR IMAGES ========================= @@ -98,7 +96,9 @@ def __init__: None """ - def __init__(self, din, nconcept, nconcept_labeled, cdim, sparsity, senn): + def __init__( + self, din, nconcept, nconcept_labeled, cdim, sparsity, senn + ): super(image_fcc_conceptizer, self).__init__() # set self 
hyperparameters @@ -110,7 +110,9 @@ def __init__(self, din, nconcept, nconcept_labeled, cdim, sparsity, senn): self.sparsity = sparsity # Number of kWTA self.nconcept = nconcept # Number of all concepts - self.nconcept_labeled = nconcept_labeled # Number of unknown concepts + self.nconcept_labeled = ( + nconcept_labeled # Number of unknown concepts + ) self.senn = senn # flag of senn @@ -126,14 +128,18 @@ def __init__(self, din, nconcept, nconcept_labeled, cdim, sparsity, senn): if senn == True: self.enc = nn.Linear(self.din, self.nconcept) else: - self.enc1_original = nn.Linear(self.din, self.nconcept_labeled) + self.enc1_original = nn.Linear( + self.din, self.nconcept_labeled + ) self.enc1 = nn.Linear(self.din, 512) self.enc_middle_f = nn.Linear(512, 3 * 2) self.enc_middle_s = nn.Linear(512, 6 * 2) self.enc_middle_l = nn.Linear(512, 6 * 2) self.enc_middle_r = nn.Linear(512, 6 * 2) # self.enc_final = nn.Linear(256, self.nconcept_labeled * 2) - self.enc2 = nn.Linear(self.din, self.nconcept - self.nconcept_labeled) + self.enc2 = nn.Linear( + self.din, self.nconcept - self.nconcept_labeled + ) self.relu = nn.ReLU() @@ -174,12 +180,16 @@ def encode(self, x): """ k = self.sparsity topval = encoded.topk(k, dim=1)[0][:, -1] - topval = topval.expand(encoded.shape[1], encoded.shape[0]).permute(1, 0) + topval = topval.expand( + encoded.shape[1], encoded.shape[0] + ).permute(1, 0) comp = (encoded >= topval).to(encoded) encoded = comp * encoded # reshape for the following process - encoded = encoded.reshape([encoded.shape[0], encoded.shape[1], 1]) + encoded = encoded.reshape( + [encoded.shape[0], encoded.shape[1], 1] + ) # aling the return's shape encoded_1 = encoded @@ -202,7 +212,8 @@ def encode(self, x): # concatenate them logits_c = torch.cat( - [logits_c_f, logits_c_s, logits_c_l, logits_c_r], dim=1 + [logits_c_f, logits_c_s, logits_c_l, logits_c_r], + dim=1, ) """ @@ -214,7 +225,9 @@ def encode(self, x): all_logits = [] for i in range(len(logits_split)): - logit1 = torch.softmax(logits_split[i], dim=-1)[:, 1].view(-1, 1) + logit1 = torch.softmax(logits_split[i], dim=-1)[ + :, 1 + ].view(-1, 1) all_logits.append(logit1) encoded_1 = torch.cat(all_logits, dim=-1) @@ -232,12 +245,16 @@ def encode(self, x): """ k = self.sparsity topval = encoded_2.topk(k, dim=1)[0][:, -1] - topval = topval.expand(encoded_2.shape[1], encoded_2.shape[0]).permute(1, 0) + topval = topval.expand( + encoded_2.shape[1], encoded_2.shape[0] + ).permute(1, 0) comp = (encoded_2 >= topval).to(encoded_2) encoded_2 = comp * encoded_2 # reshape for the following process - encoded_2 = encoded_2.reshape([encoded_2.shape[0], encoded_2.shape[1], 1]) + encoded_2 = encoded_2.reshape( + [encoded_2.shape[0], encoded_2.shape[1], 1] + ) return encoded_1, encoded_2 @@ -316,7 +333,13 @@ def __init__: """ def __init__( - self, din, nconcept, nconcept_labeled, cdim=None, nchannel=1, sparsity=1 + self, + din, + nconcept, + nconcept_labeled, + cdim=None, + nchannel=1, + sparsity=1, ): super(image_cnn_conceptizer, self).__init__() self.din = din # Input dimension @@ -347,14 +370,20 @@ def __init__( self.linear1 = nn.ModuleList() for i in range(self.nconcept_labeled): self.linear1.append(nn.Linear(self.dout**2, self.cdim)) - self.linear2 = nn.Linear(self.dout**2, self.cdim) # b, nconcepts, cdim + self.linear2 = nn.Linear( + self.dout**2, self.cdim + ) # b, nconcepts, cdim # Decoding - self.unlinear = nn.Linear(self.cdim, self.dout**2) # b, nconcepts, dout*2 + self.unlinear = nn.Linear( + self.cdim, self.dout**2 + ) # b, nconcepts, dout*2 
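The k-winners-take-all (kWTA) masking that encode() applies to the unlabeled concepts above, and that the CNN conceptizer below repeats verbatim, always follows the same three steps: find each row's k-th largest activation, build a mask of everything at or above it, and zero the rest. Factored out, the pattern is roughly the helper below; this is an illustrative sketch (with k playing the role of self.sparsity), not code from the patch.

import torch

def kwta(encoded: torch.Tensor, k: int) -> torch.Tensor:
    # k-th largest activation of each row, broadcast back to the full (batch, n) shape
    topval = encoded.topk(k, dim=1)[0][:, -1]
    topval = topval.expand(encoded.shape[1], encoded.shape[0]).permute(1, 0)
    # keep activations at or above the threshold (ties keep every tied entry), zero the rest
    mask = (encoded >= topval).to(encoded)
    return mask * encoded

# e.g. kwta(torch.randn(4, 30), k=5) keeps the 5 largest activations per row (plus ties) and zeroes the rest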
self.deconv3 = nn.ConvTranspose2d( nconcept, 16, 5, stride=2 ) # b, 16, (dout-1)*2 + 5, 5 - self.deconv2 = nn.ConvTranspose2d(16, 8, 5) # b, 8, (dout -1)*2 + 9 + self.deconv2 = nn.ConvTranspose2d( + 16, 8, 5 + ) # b, 8, (dout -1)*2 + 9 self.deconv1 = nn.ConvTranspose2d( 8, nchannel, 2, stride=2, padding=1 ) # b, nchannel, din, din @@ -380,15 +409,25 @@ def encode(self, x): encoded_1 = [] for fc in self.linear1: encoded_1.append( - fc(p_1.view(-1, self.nconcept_labeled, self.dout**2)[:, cnt]) + fc( + p_1.view( + -1, self.nconcept_labeled, self.dout**2 + )[:, cnt] + ) ) cnt = cnt + 1 # compute unknown concepts encoded_2 = self.linear2( - p_2.view(-1, self.nconcept - self.nconcept_labeled, self.dout**2) + p_2.view( + -1, + self.nconcept - self.nconcept_labeled, + self.dout**2, + ) + ) + encoded_2 = encoded_2.reshape( + [encoded_2.shape[0], encoded_2.shape[1]] ) - encoded_2 = encoded_2.reshape([encoded_2.shape[0], encoded_2.shape[1]]) """ kWTA: https://github.com/a554b554/kWTA-Activation/ @@ -401,10 +440,14 @@ def encode(self, x): """ k = self.sparsity topval = encoded_2.topk(k, dim=1)[0][:, -1] - topval = topval.expand(encoded_2.shape[1], encoded_2.shape[0]).permute(1, 0) + topval = topval.expand( + encoded_2.shape[1], encoded_2.shape[0] + ).permute(1, 0) comp = (encoded_2 >= topval).to(encoded_2) encoded_2 = comp * encoded_2 - encoded_2 = encoded_2.reshape([encoded_2.shape[0], encoded_2.shape[1], 1]) + encoded_2 = encoded_2.reshape( + [encoded_2.shape[0], encoded_2.shape[1], 1] + ) return encoded_1, encoded_2 @@ -427,7 +470,9 @@ def decode(self, z1_list, z2): z1 = torch.cat(z1_list, dim=1) z1 = z1.view(-1, self.nconcept_labeled, self.cdim) z = torch.cat((z1, z2), dim=1) - q = self.unlinear(z).view(-1, self.nconcept, self.dout, self.dout) + q = self.unlinear(z).view( + -1, self.nconcept, self.dout, self.dout + ) q = F.relu(self.deconv3(q)) q = F.relu(self.deconv2(q)) decoded = torch.tanh(self.deconv1(q)) @@ -435,7 +480,16 @@ def decode(self, z1_list, z2): class PCBMConceptizer(image_fcc_conceptizer): - def __init__(self, din, nconcept, nconcept_labeled, cdim, sparsity, senn, device): + def __init__( + self, + din, + nconcept, + nconcept_labeled, + cdim, + sparsity, + senn, + device, + ): super(PCBMConceptizer, self).__init__( din, nconcept, nconcept_labeled, cdim, sparsity, senn ) @@ -453,10 +507,12 @@ def __init__(self, din, nconcept, nconcept_labeled, cdim, sparsity, senn, device self.shift = torch.ones(0, device=device) self.dense_logvar = nn.Linear( - in_features=self.din, out_features=self.latent_dim * self.nconcept_labeled + in_features=self.din, + out_features=self.latent_dim * self.nconcept_labeled, ) self.dense_mu = nn.Linear( - in_features=self.din, out_features=self.latent_dim * self.nconcept_labeled + in_features=self.din, + out_features=self.latent_dim * self.nconcept_labeled, ) def encode(self, x): @@ -470,22 +526,32 @@ def encode(self, x): encoded_1 = torch.sigmoid(logits_enc) - mu, logvar = self.dense_mu(logits_c), self.dense_logvar(logits_c) + mu, logvar = self.dense_mu(logits_c), self.dense_logvar( + logits_c + ) - mu = torch.stack(torch.split(mu, self.latent_dim, dim=-1), dim=1) - logvar = torch.stack(torch.split(logvar, self.latent_dim, dim=-1), dim=1) + mu = torch.stack( + torch.split(mu, self.latent_dim, dim=-1), dim=1 + ) + logvar = torch.stack( + torch.split(logvar, self.latent_dim, dim=-1), dim=1 + ) # compute unknown concepts encoded_2 = self.enc2(p) k = self.sparsity topval = encoded_2.topk(k, dim=1)[0][:, -1] - topval = topval.expand(encoded_2.shape[1], 
encoded_2.shape[0]).permute(1, 0) + topval = topval.expand( + encoded_2.shape[1], encoded_2.shape[0] + ).permute(1, 0) comp = (encoded_2 >= topval).to(encoded_2) encoded_2 = comp * encoded_2 # reshape for the following process - encoded_2 = encoded_2.reshape([encoded_2.shape[0], encoded_2.shape[1], 1]) + encoded_2 = encoded_2.reshape( + [encoded_2.shape[0], encoded_2.shape[1], 1] + ) return encoded_1, encoded_2, mu, logvar diff --git a/BDD_OIA/experiments.py b/BDD_OIA/experiments.py index 4e019a4..76529d0 100644 --- a/BDD_OIA/experiments.py +++ b/BDD_OIA/experiments.py @@ -1,5 +1,5 @@ -import itertools import copy +import itertools def launch_bdd(args): @@ -51,7 +51,12 @@ def launch_bdd(args): for element in itertools.product(*hyperparameters): args1 = copy.copy(args) - args1.model_name, args1.h_labeled_param, args1.w_entropy, args1.seed = element + ( + args1.model_name, + args1.h_labeled_param, + args1.w_entropy, + args1.seed, + ) = element if args1.model_name == "dpl_auc": args1.h_labeled_param = 0.01 * args1.h_labeled_param diff --git a/BDD_OIA/main_bdd.py b/BDD_OIA/main_bdd.py index f79b006..bad866d 100644 --- a/BDD_OIA/main_bdd.py +++ b/BDD_OIA/main_bdd.py @@ -1,81 +1,78 @@ # -*- coding: utf-8 -*- # Standard Imports -import sys, os -import numpy as np -import pdb -import pickle import argparse +import math import operator +import os +import pdb +import pickle +import sys + import matplotlib import matplotlib.pyplot as plt +import numpy as np import pandas as pd -import math - -from sklearn.linear_model import LinearRegression -from sklearn.metrics import ( - mean_squared_error, - accuracy_score, - f1_score, - precision_recall_fscore_support, - multilabel_confusion_matrix, -) import seaborn as sns # Torch-related import torch -from torch.utils.data import TensorDataset -from torch.autograd import Variable -import torchvision -from torchvision import transforms -from torch.utils.data.sampler import SubsetRandomSampler import torch.utils.data.dataloader as dataloader - -from sklearn.metrics import accuracy_score, f1_score, precision_score - -# Local imports -from SENN.utils import ( - plot_theta_stability, - generate_dir_names, - noise_stability_plots, - concept_grid, -) -from SENN.eval_utils import estimate_dataset_lipschitz -from SENN.arglist import get_senn_parser - - -from BDD.dataset import load_data, find_class_imbalance +import torchvision +import wandb +from aggregators_BDD import CBM_aggregator, additive_scalar_aggregator from BDD.config import ( BASE_DIR, - N_CLASSES, + LR_DECAY_SIZE, + MIN_LR, N_ATTRIBUTES, + N_CLASSES, UPWEIGHT_RATIO, - MIN_LR, - LR_DECAY_SIZE, ) - -from models import GSENN +from BDD.dataset import find_class_imbalance, load_data +from conceptizers_BDD import ( + PCBMConceptizer, + image_cnn_conceptizer, + image_fcc_conceptizer, +) from DPL.dpl import DPL from DPL.dpl_auc import DPL_AUC from DPL.dpl_auc_pcbm import DPL_AUC_PCBM -from conceptizers_BDD import ( - image_fcc_conceptizer, - image_cnn_conceptizer, - PCBMConceptizer, +from models import GSENN +from parametrizers import dfc_parametrizer, image_parametrizer +from scipy.special import softmax +from SENN.arglist import get_senn_parser +from SENN.eval_utils import estimate_dataset_lipschitz + +# Local imports +from SENN.utils import ( + concept_grid, + generate_dir_names, + noise_stability_plots, + plot_theta_stability, +) +from sklearn.linear_model import LinearRegression +from sklearn.metrics import ( + accuracy_score, + f1_score, + mean_squared_error, + multilabel_confusion_matrix, + 
precision_recall_fscore_support, + precision_score, ) -from parametrizers import image_parametrizer, dfc_parametrizer -from aggregators_BDD import additive_scalar_aggregator, CBM_aggregator -from trainers_BDD import GradPenaltyTrainer from testers_BDD import ClassificationTesterFactory -import wandb -from scipy.special import softmax +from torch.autograd import Variable +from torch.utils.data import TensorDataset +from torch.utils.data.sampler import SubsetRandomSampler +from torchvision import transforms +from trainers_BDD import GradPenaltyTrainer from visualization import ( - produce_confusion_matrix, + create_output_folder, + plot_grouped_entropies, produce_alpha_matrix, - produce_calibration_curve, produce_bar_plot, + produce_calibration_curve, + produce_confusion_matrix, produce_scatter_multi_class, - plot_grouped_entropies, - create_output_folder, ) @@ -87,7 +84,10 @@ def convert_to_json_serializable(obj): elif isinstance(obj, (list, tuple)): return [convert_to_json_serializable(item) for item in obj] elif isinstance(obj, dict): - return {key: convert_to_json_serializable(value) for key, value in obj.items()} + return { + key: convert_to_json_serializable(value) + for key, value in obj.items() + } elif isinstance(obj, (int, float, bool, str, type(None))): return obj else: @@ -99,9 +99,9 @@ def entropy(probabilities, n_values: int): probabilities += 1e-5 probabilities /= 1 + (n_values * 1e-5) - entropy_values = -np.sum(probabilities * np.log(probabilities), axis=1) / np.log( - n_values - ) + entropy_values = -np.sum( + probabilities * np.log(probabilities), axis=1 + ) / np.log(n_values) return entropy_values @@ -196,7 +196,9 @@ def class_mean_entropy(probabilities, true_classes, n_classes: int): class_counts = np.zeros(n_classes) for i in range(num_samples): - sample_entropy = entropy(np.expand_dims(probabilities[i], axis=0), num_classes) + sample_entropy = entropy( + np.expand_dims(probabilities[i], axis=0), num_classes + ) class_mean_entropy_values[true_classes[i]] += sample_entropy class_counts[true_classes[i]] += 1 @@ -208,7 +210,9 @@ def class_mean_entropy(probabilities, true_classes, n_classes: int): return class_mean_entropy_values -def produce_confusion_matrices(p_true, p_pred, n_values: int, mode: str, suffix: str): +def produce_confusion_matrices( + p_true, p_pred, n_values: int, mode: str, suffix: str +): sklearn_concept_labels = [str(int(el)) for el in range(n_values)] print("--- Saving the RSs Confusion Matrix ---") @@ -229,12 +233,20 @@ def produce_confusion_matrices(p_true, p_pred, n_values: int, mode: str, suffix: def _bin_initializer(num_bins: int): # Builds the bin return { - i: {"COUNT": 0, "CONF": 0, "ACC": 0, "BIN_ACC": 0, "BIN_CONF": 0} + i: { + "COUNT": 0, + "CONF": 0, + "ACC": 0, + "BIN_ACC": 0, + "BIN_CONF": 0, + } for i in range(num_bins) } -def _populate_bins(confs, preds, labels, num_bins: int, multilabel=False): +def _populate_bins( + confs, preds, labels, num_bins: int, multilabel=False +): # initializes n bins (a bin contains probability from x to x + smth (where smth is greater than zero)) bin_dict = _bin_initializer(num_bins) @@ -244,7 +256,9 @@ def _populate_bins(confs, preds, labels, num_bins: int, multilabel=False): binn = int(math.ceil(num_bins * confidence[i] - 1)) bin_dict[binn]["COUNT"] += 1 bin_dict[binn]["CONF"] += confidence[i] - bin_dict[binn]["ACC"] += 1 if label[i] == prediction[i] else 0 + bin_dict[binn]["ACC"] += ( + 1 if label[i] == prediction[i] else 0 + ) else: binn = int(math.ceil(num_bins * confidence - 1)) bin_dict[binn]["COUNT"] 
+= 1 @@ -299,7 +313,9 @@ def produce_ece_curve( ece_data = list() for i in range(p.shape[1]): ece_data.append( - expected_calibration_error(p[:, i], pred[:, i], true[:, i])[0] + expected_calibration_error( + p[:, i], pred[:, i], true[:, i] + )[0] ) ece_data = np.mean(np.asarray(ece_data), axis=0) else: @@ -308,10 +324,16 @@ def produce_ece_curve( if ece_data: if multilabel: ece = ece_data - print(f"Expected Calibration Error (ECE) {exp_mode} on {purpose}", ece) + print( + f"Expected Calibration Error (ECE) {exp_mode} on {purpose}", + ece, + ) else: ece, ece_bins = ece_data - print(f"Expected Calibration Error (ECE) {exp_mode} on {purpose}", ece) + print( + f"Expected Calibration Error (ECE) {exp_mode} on {purpose}", + ece, + ) concept_flag = True if purpose != "labels" else False produce_calibration_curve( ece_bins, @@ -327,7 +349,9 @@ def print_distance( tester, test_loader, ): - dist_fs, dist_left, dist_right = tester.p_c_x_distance(test_loader) + dist_fs, dist_left, dist_right = tester.p_c_x_distance( + test_loader + ) print(f"Distance FS: {dist_fs}") print(f"Distance LEFT: {dist_left}") print(f"Distance RIGHT: {dist_right}") @@ -352,7 +376,12 @@ def plot_multilabel_confusion_matrices( for i, (matrix, ax) in enumerate(zip(conf_matrices, axes)): matrix_to_disp = matrix.astype("float") / matrix.sum() sns.heatmap( - matrix_to_disp, annot=True, fmt=".2%", cmap="viridis", cbar=False, ax=ax + matrix_to_disp, + annot=True, + fmt=".2%", + cmap="viridis", + cbar=False, + ax=ax, ) ax.set_title(f"Label {labels[i]}") @@ -362,7 +391,9 @@ def plot_multilabel_confusion_matrices( ax.set_xlabel("Predicted") plt.suptitle(f"{plot_title}", y=1.02) - plt.savefig(f"./plots/normalized_multilabel_confusion_{fig_title}.png") + plt.savefig( + f"./plots/normalized_multilabel_confusion_{fig_title}.png" + ) def plot_statistics_single_model( @@ -384,7 +415,9 @@ def plot_statistics_single_model( suffix, ): - for i, direction in zip([0, 1, 2], ["stop_forward", "left", "right"]): + for i, direction in zip( + [0, 1, 2], ["stop_forward", "left", "right"] + ): mean_h_c = print_metrics( y_true, y_predictions, @@ -430,7 +463,9 @@ def plot_statistics_single_model( if i > 0: continue - conf_matrix = multilabel_confusion_matrix(y_true, y_predictions) + conf_matrix = multilabel_confusion_matrix( + y_true, y_predictions + ) labels = [f"{i + 1}" for i in range(len(conf_matrix))] plot_multilabel_confusion_matrices( conf_matrix, @@ -463,7 +498,13 @@ def plot_statistics_single_model( ) produce_ece_curve( - pc_prob, pc_pred, c_true, bayes_method, "concepts", suffix, True + pc_prob, + pc_pred, + c_true, + bayes_method, + "concepts", + suffix, + True, ) @@ -485,8 +526,12 @@ def ova_entropy(p): c_fact_filtered = c_fact[:, c] c_fact_filtered = np.expand_dims(c_fact_filtered, axis=-1) - result = np.apply_along_axis(ova_entropy, axis=1, arr=c_fact_filtered) - conditional_entropies["c_ova_filtered"].append(np.mean(result)) + result = np.apply_along_axis( + ova_entropy, axis=1, arr=c_fact_filtered + ) + conditional_entropies["c_ova_filtered"].append( + np.mean(result) + ) return conditional_entropies @@ -524,7 +569,11 @@ def world_accuracy(world_pred, world_true): for i, direction, world_size in zip( [0, 1, 2], ["stop_forward", "left", "right"], - [int(math.pow(2, 9)), int(math.pow(2, 6)), int(math.pow(2, 6))], + [ + int(math.pow(2, 9)), + int(math.pow(2, 6)), + int(math.pow(2, 6)), + ], ): acc_list, counter_list = get_accuracy_and_counter( world_pred[i], world_true[i], world_size @@ -552,7 +601,12 @@ def single_concept_ece(bayes_method, 
labels, p, pred, true, suffix): for c in labels: ece_single_concept = produce_ece_curve( - p[:, c], pred[:, c], true[:, c], bayes_method, f"concept {c}", f"_{suffix}" + p[:, c], + pred[:, c], + true[:, c], + bayes_method, + f"concept {c}", + f"_{suffix}", ) single_concepts_ece.append(ece_single_concept) @@ -562,22 +616,39 @@ def single_concept_ece(bayes_method, labels, p, pred, true, suffix): def compute_mean_acc_f1(y_true, y_predictions, dim): f1_M_list, acc_list = list(), list() for i in range(dim): - f1_M_list.append(f1_score(y_true[:, i], y_predictions[:, i], average="macro")) - acc_list.append(accuracy_score(y_true[:, i], y_predictions[:, i])) + f1_M_list.append( + f1_score( + y_true[:, i], y_predictions[:, i], average="macro" + ) + ) + acc_list.append( + accuracy_score(y_true[:, i], y_predictions[:, i]) + ) f1_M = np.mean(np.asarray(f1_M_list), axis=0) acc = np.mean(np.asarray(acc_list), axis=0) return f1_M, acc def compute_acc_f1( - y_true, y_predictions, c_true, pc_pred, w_groundtruths, w_predictions + y_true, + y_predictions, + c_true, + pc_pred, + w_groundtruths, + w_predictions, ): - f1, accuracy = compute_mean_acc_f1(y_true, y_predictions, y_true.shape[1]) + f1, accuracy = compute_mean_acc_f1( + y_true, y_predictions, y_true.shape[1] + ) precision_per_class, recall_per_class, f1_score_per_class, _ = ( - precision_recall_fscore_support(y_true, y_predictions, average=None) + precision_recall_fscore_support( + y_true, y_predictions, average=None + ) ) - concept_f1, concept_accuracy = compute_mean_acc_f1(c_true, pc_pred, c_true.shape[1]) + concept_f1, concept_accuracy = compute_mean_acc_f1( + c_true, pc_pred, c_true.shape[1] + ) worlds_test_accuracies, worlds_test_f1 = worlds_f1_acc( w_groundtruths, w_predictions ) @@ -599,7 +670,9 @@ def worlds_f1_acc(w_groundtruths, w_predictions): worlds_test_f1 = [] for i in range(len(w_groundtruths)): accuracy = accuracy_score(w_groundtruths[i], w_predictions[i]) - f1 = f1_score(w_groundtruths[i], w_predictions[i], average="micro") + f1 = f1_score( + w_groundtruths[i], w_predictions[i], average="micro" + ) worlds_test_accuracies.append(accuracy) worlds_test_f1.append(f1) return worlds_test_accuracies, worlds_test_f1 @@ -673,9 +746,7 @@ def dump_dictionary( # If not, create it os.makedirs("dumps") - file_path = ( - f"dumps/dpl-seed_{args.seed}-nens_{args.n_models}-lambda_{args.lambda_h}.json" - ) + file_path = f"dumps/dpl-seed_{args.seed}-nens_{args.n_models}-lambda_{args.lambda_h}.json" if incomplete: print("Sono incompleto") @@ -722,7 +793,13 @@ def total_evaluation_stuff( single_concept_ece_list_train, worlds_size, ): - evals = ["frequentist", "laplace", "mcdropout", "biretta", "deep ensembles"] + evals = [ + "frequentist", + "laplace", + "mcdropout", + "biretta", + "deep ensembles", + ] categories = [i for i in range(21)] for direction in ["stop_forward", "left", "right"]: @@ -861,7 +938,10 @@ def parse_args(): help="ncalls for bayes opt gp method in Lipschitz estimation", ) parser.add_argument( - "--lip_eps", type=float, default=0.01, help="eps for Lipschitz estimation" + "--lip_eps", + type=float, + default=0.01, + help="eps for Lipschitz estimation", ) parser.add_argument( "--lip_points", @@ -870,11 +950,17 @@ def parse_args(): help="sample size for dataset Lipschitz estimation", ) parser.add_argument( - "--optim", type=str, default="gp", help="black-box optimization method" + "--optim", + type=str, + default="gp", + help="black-box optimization method", ) parser.add_argument( - "--model_name", type=str, default="dpl", help="Choose model 
to fit" + "--model_name", + type=str, + default="dpl", + help="Choose model to fit", ) parser.add_argument( @@ -891,12 +977,20 @@ def parse_args(): default=[-1], help="Which concepts explicitly supervise (-1 means all)", ) - parser.add_argument("--wandb", type=str, default=None, help="Activate wandb") parser.add_argument( - "--project", type=str, default="BDD-OIA", help="Select wandb project" + "--wandb", type=str, default=None, help="Activate wandb" ) parser.add_argument( - "--do-test", default=False, action="store_true", help="Test the model" + "--project", + type=str, + default="BDD-OIA", + help="Select wandb project", + ) + parser.add_argument( + "--do-test", + default=False, + action="store_true", + help="Test the model", ) parser.add_argument( @@ -906,7 +1000,10 @@ def parse_args(): help="Use KL to differentiate the models", ) parser.add_argument( - "--epsilon", type=float, default=0.01, help="Use KL to differentiate the models" + "--epsilon", + type=float, + default=0.01, + help="Use KL to differentiate the models", ) parser.add_argument( "--lambda_h", @@ -914,7 +1011,9 @@ def parse_args(): default=1.0, help="Lambda parameter used to weight the entropy loss", ) - parser.add_argument("--n-models", type=int, default=30, help="Number of runs") + parser.add_argument( + "--n-models", type=int, default=30, help="Number of runs" + ) parser.add_argument( "--knowledge_aware_kl", default=True, @@ -942,7 +1041,10 @@ def parse_args(): help="Lambda parameter used to weight the KL loss", ) parser.add_argument( - "--pcbm", action="store_true", default=False, help="KL for PCBM" + "--pcbm", + action="store_true", + default=False, + help="KL for PCBM", ) ##### @@ -995,7 +1097,9 @@ def main(args): # set which GPU uses # if args.cuda: - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + device = torch.device( + "cuda" if torch.cuda.is_available() else "cpu" + ) # else: # device = torch.device("cpu") @@ -1050,10 +1154,14 @@ def main(args): ) # get paths (see SENN/utils.py, lines 34-). This function is made by SENN's authors - model_path, log_path, results_path = generate_dir_names("bdd", args) + model_path, log_path, results_path = generate_dir_names( + "bdd", args + ) # Convert the arguments to a string representation - arg_string = "\n".join([f"{arg}={getattr(args, arg)}" for arg in vars(args)]) + arg_string = "\n".join( + [f"{arg}={getattr(args, arg)}" for arg in vars(args)] + ) file_path = "%s/args.txt" % (results_path) with open(file_path, "w") as f: f.write(arg_string) @@ -1084,7 +1192,14 @@ def main(args): sys.exit(1) parametrizer1 = dfc_parametrizer( - 2048, 1024, 512, 256, 128, args.nconcepts, args.theta_dim, layers=4 + 2048, + 1024, + 512, + 256, + 128, + args.nconcepts, + args.theta_dim, + layers=4, ) buf = 1 @@ -1098,7 +1213,9 @@ def main(args): args.concept_dim, args.nclasses, args.nconcepts_labeled ) else: - aggregator = additive_scalar_aggregator(args.concept_dim, args.nclasses) + aggregator = additive_scalar_aggregator( + args.concept_dim, args.nclasses + ) # you should set load_model as True. If you set, you can use inception v.3 as the encoder, otherwise end. 
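Since produce_ece_curve and the bin helpers earlier in this file drive much of the evaluation below, a compact reference implementation of the quantity they compute may help. The binning (binn = ceil(num_bins * conf) - 1, accumulating COUNT/CONF/ACC per bin) is taken from _populate_bins; the final weighted |accuracy - confidence| sum is the standard ECE definition, since expected_calibration_error's body is not shown in this hunk. A self-contained sketch, not the repository's implementation:

import numpy as np

def ece_sketch(confs, preds, labels, num_bins=10):
    # assign each prediction to a confidence bin, as in _populate_bins
    confs, preds, labels = map(np.asarray, (confs, preds, labels))
    bin_ids = np.clip(np.ceil(num_bins * confs).astype(int) - 1, 0, num_bins - 1)
    ece = 0.0
    for b in range(num_bins):
        in_bin = bin_ids == b
        if in_bin.any():
            acc = (preds[in_bin] == labels[in_bin]).mean()   # bin accuracy
            conf = confs[in_bin].mean()                      # bin confidence
            ece += in_bin.mean() * abs(acc - conf)           # weight by bin mass
    return ece

# e.g. ece_sketch([0.9, 0.6, 0.8], [1, 0, 1], [1, 1, 1]) ~= 0.30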
@@ -1110,11 +1227,21 @@ def main(args): # model = GSENN(conceptizer1, parametrizer1, aggregator, args.cbm, args.senn) if args.model_name == "dpl": model = DPL( - conceptizer1, parametrizer1, aggregator, args.cbm, args.senn, device + conceptizer1, + parametrizer1, + aggregator, + args.cbm, + args.senn, + device, ) elif args.model_name == "dpl_auc": model = DPL_AUC( - conceptizer1, parametrizer1, aggregator, args.cbm, args.senn, device + conceptizer1, + parametrizer1, + aggregator, + args.cbm, + args.senn, + device, ) elif args.model_name == "dpl_auc_pcbm": args.pcbm = True @@ -1128,13 +1255,20 @@ def main(args): device, ) model = DPL_AUC_PCBM( - conceptizer1, parametrizer1, aggregator, args.cbm, args.senn, device + conceptizer1, + parametrizer1, + aggregator, + args.cbm, + args.senn, + device, ) # send models to device you want to use model = model.to(device) print("Res path", results_path) - load_checkpoint(model, f"models/bdd/{args.model_name}-{args.seed}", args.seed) + load_checkpoint( + model, f"models/bdd/{args.model_name}-{args.seed}", args.seed + ) print("Model", model) # Test or train @@ -1201,7 +1335,10 @@ def main(args): tester.setup( train_loader, train_loader_no_shuffle, - [args.seed + seed + 1 for seed in range(args.n_models)], + [ + args.seed + seed + 1 + for seed in range(args.n_models) + ], valid_loader, epochs=args.epochs, save_path=model_path, @@ -1212,8 +1349,13 @@ def main(args): ) # plot the losses for deep ensembles only - if bayes_method == "deepensembles" or bayes_method == "resense": - tester.plot_losses(bayes_method, save_path=results_path) + if ( + bayes_method == "deepensembles" + or bayes_method == "resense" + ): + tester.plot_losses( + bayes_method, save_path=results_path + ) save_file_name = f"{results_path}/test_results_of_BDD_{args.seed}_{bayes_method}_{args.lambda_h}_{args.lambda_kl}.csv" fp = open(save_file_name, "w") @@ -1226,24 +1368,38 @@ def main(args): # evaluation by test dataset tester.test_and_save_csv( - test_loader, save_file_name, fold="test", pcbm=args.pcbm + test_loader, + save_file_name, + fold="test", + pcbm=args.pcbm, ) tester.test_and_save_csv( - train_loader, save_file_name_train, fold="train", pcbm=args.pcbm + train_loader, + save_file_name_train, + fold="train", + pcbm=args.pcbm, ) # for the ensemble method write everything as frequentist if bayes_method != "frequentist": if bayes_method == "laplace": - for i, (inputs, targets, concepts) in enumerate(test_loader): - ensemble = tester.get_ensemble_from_bayes(args.n_models, inputs) + for i, (inputs, targets, concepts) in enumerate( + test_loader + ): + ensemble = tester.get_ensemble_from_bayes( + args.n_models, inputs + ) break else: - ensemble = tester.get_ensemble_from_bayes(args.n_models) + ensemble = tester.get_ensemble_from_bayes( + args.n_models + ) for j in range(len(ensemble)): - frequentist_m_tester = ClassificationTesterFactory.get_model( - "frequentist", ensemble[j], args, device + frequentist_m_tester = ( + ClassificationTesterFactory.get_model( + "frequentist", ensemble[j], args, device + ) ) # initialize the csv file (cleaning before training) @@ -1379,7 +1535,9 @@ def main(args): fp.close() #### EM - save_file_name_train = "%s/train_results_of_BDD.csv" % (results_path) + save_file_name_train = "%s/train_results_of_BDD.csv" % ( + results_path + ) fp_train = open(save_file_name_train, "w") fp_train.close() @@ -1404,7 +1562,10 @@ def main(args): test_loader, save_file_name, fold="test", pcbm=args.pcbm ) trainer.test_and_save_csv( - train_loader, save_file_name_train, 
fold="train", pcbm=args.pcbm + train_loader, + save_file_name_train, + fold="train", + pcbm=args.pcbm, ) # send model result to cpu diff --git a/BDD_OIA/models.py b/BDD_OIA/models.py index 0857834..a6be9b8 100644 --- a/BDD_OIA/models.py +++ b/BDD_OIA/models.py @@ -3,14 +3,16 @@ Detail of forwarding of our model """ +import pdb + # -*- coding: utf-8 -*- import sys -import pdb + import numpy as np import torch +import torch.autograd as autograd import torch.nn as nn import torch.nn.functional as F -import torch.autograd as autograd from torch.autograd import Variable # if you set True, many print function is used to debug @@ -47,7 +49,9 @@ def __init__: None """ - def __init__(self, conceptizer, parametrizer, aggregator, cbm, senn): + def __init__( + self, conceptizer, parametrizer, aggregator, cbm, senn + ): super(GSENN, self).__init__() self.cbm = cbm self.senn = senn @@ -95,7 +99,9 @@ def forward(self, x): # Concepts are two-dimensional, so flatten h_x = h_x.view(h_x.size(0), h_x.size(1), -1) if len(h_x_labeled.size()) == 4: - h_x_labeled = h_x_labeled.view(h_x_labeled.size(0), h_x.size(1), -1) + h_x_labeled = h_x_labeled.view( + h_x_labeled.size(0), h_x.size(1), -1 + ) if not self.senn: # store (known+unknown) concepts diff --git a/BDD_OIA/parametrizers.py b/BDD_OIA/parametrizers.py index 0822699..91dd54f 100644 --- a/BDD_OIA/parametrizers.py +++ b/BDD_OIA/parametrizers.py @@ -5,15 +5,16 @@ # -*- coding: utf-8 -*- +import pdb + +import numpy as np + # Torch Imports import torch +import torch.autograd as autograd import torch.nn as nn import torch.nn.functional as F -import torch.autograd as autograd import torchvision -import pdb -import numpy as np - """ class dfc_parametrizer: @@ -41,7 +42,17 @@ def __init__: None """ - def __init__(self, din, hdim1, hdim2, hdim3, hdim4, nconcept, dout, layers=2): + def __init__( + self, + din, + hdim1, + hdim2, + hdim3, + hdim4, + nconcept, + dout, + layers=2, + ): super(dfc_parametrizer, self).__init__() print(din, hdim1, hdim2, hdim3, hdim4, nconcept, dout) @@ -63,7 +74,9 @@ def __init__(self, din, hdim1, hdim2, hdim3, hdim4, nconcept, dout, layers=2): self.layers = 4, define the parametrizer for the final layer of encoder """ if self.layers == 3: - self.linear3_2 = nn.Linear(hdim2, self.nconcept * self.dout) + self.linear3_2 = nn.Linear( + hdim2, self.nconcept * self.dout + ) else: self.linear3_1 = nn.Linear(hdim2, hdim3) self.bn3 = nn.BatchNorm1d(num_features=hdim3) @@ -129,18 +142,26 @@ def __init__: None """ - def __init__(self, din, nconcept, dout, nchannel=1, only_positive=False): + def __init__( + self, din, nconcept, dout, nchannel=1, only_positive=False + ): super(image_parametrizer, self).__init__() self.nconcept = nconcept self.dout = dout self.din = din - self.conv1 = nn.Conv2d(nchannel, 10, kernel_size=5) # b, 10, din - (k -1), same + self.conv1 = nn.Conv2d( + nchannel, 10, kernel_size=5 + ) # b, 10, din - (k -1), same # after ppol layer with stride=2: din/2 - (k -1)/2 - self.conv2 = nn.Conv2d(10, 20, kernel_size=5) # b, 20, din/2 - 3(k -1)/2, same + self.conv2 = nn.Conv2d( + 10, 20, kernel_size=5 + ) # b, 20, din/2 - 3(k -1)/2, same # after ppol layer with stride=2: din/4 - 3(k -1)/4 self.dout_conv = int(np.sqrt(din) // 4 - 3 * (5 - 1) // 4) self.conv2_drop = nn.Dropout2d() - self.fc1 = nn.Linear(20 * (self.dout_conv**2), nconcept * dout) + self.fc1 = nn.Linear( + 20 * (self.dout_conv**2), nconcept * dout + ) self.positive = only_positive """ @@ -157,10 +178,14 @@ def forward(self, x): p_1 = 
F.relu(F.max_pool2d(self.conv2_drop(self.conv2(p)), 2)) p_2 = p_1.view(-1, 20 * (self.dout_conv**2)) p_3 = self.fc1(p_2) - out = F.dropout(p_3, training=self.training).view(-1, self.nconcept, self.dout) + out = F.dropout(p_3, training=self.training).view( + -1, self.nconcept, self.dout + ) # if self.positive is true, output activation uses sigmoid, otherwise tanh if self.positive: - out_val = F.sigmoid(out) # For fixed outputdim, sum over concepts = 1 + out_val = F.sigmoid( + out + ) # For fixed outputdim, sum over concepts = 1 else: out_val = torch.tanh(out) return out_val diff --git a/BDD_OIA/server.py b/BDD_OIA/server.py index 88640fd..a83c70f 100644 --- a/BDD_OIA/server.py +++ b/BDD_OIA/server.py @@ -1,7 +1,9 @@ +import os +import sys + import submitit -from main_bdd import parse_args, main from experiments import launch_bdd -import os, sys +from main_bdd import main, parse_args conf_path = os.getcwd() + "." sys.path.append(conf_path) @@ -14,7 +16,9 @@ args.nclasses = 5 args.theta_dim = args.nclasses - executor = submitit.AutoExecutor(folder="./logs", slurm_max_num_timeout=30) + executor = submitit.AutoExecutor( + folder="./logs", slurm_max_num_timeout=30 + ) executor.update_parameters( mem_gb=4, gpus_per_node=1, diff --git a/BDD_OIA/testers_BDD.py b/BDD_OIA/testers_BDD.py index 26d13ac..7a91b17 100644 --- a/BDD_OIA/testers_BDD.py +++ b/BDD_OIA/testers_BDD.py @@ -4,48 +4,49 @@ We modified so as to fit the semi-supervised fashion. """ -# standard imports -import sys import builtins +import copy +import math import os -import tqdm -import time import pdb +import random import shutil -import torch + +# standard imports +import sys +import time + +import matplotlib.pyplot as plt import numpy as np -import copy -from numpy import ndarray -from torch.autograd import Variable +import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim -import math -import matplotlib.pyplot as plt +import tqdm import wandb -import random -from torch.utils.data import Dataset +from aggregators_BDD import CBM_aggregator, additive_scalar_aggregator from conceptizers_BDD import image_fcc_conceptizer -from parametrizers import image_parametrizer, dfc_parametrizer -from aggregators_BDD import additive_scalar_aggregator, CBM_aggregator -from scipy.special import softmax +from DPL.dpl import DPL +from DPL.dpl_auc import DPL_AUC from laplace import Laplace +from numpy import ndarray +from parametrizers import dfc_parametrizer, image_parametrizer +from scipy.special import softmax # Local imports from SENN.utils import AverageMeter -from DPL.dpl import DPL -from DPL.dpl_auc import DPL_AUC - +from torch.autograd import Variable +from torch.utils.data import Dataset from worlds_BDD import ( - compute_forward_stop_prob, - compute_left, - compute_right, compute_forward_prob, - compute_stop_prob, - compute_output_probability, compute_forward_stop_groundtruth, + compute_forward_stop_prob, + compute_left, compute_left_groundtruth, + compute_output_probability, + compute_right, compute_right_groundtruth, + compute_stop_prob, convert_np_array_to_binary, ) @@ -65,7 +66,9 @@ def early_stop(self, model, validation_loss): self.counter = 0 self.best_weights = model.state_dict() self.stuck = False - elif validation_loss > (self.min_validation_loss + self.min_delta): + elif validation_loss > ( + self.min_validation_loss + self.min_delta + ): self.counter += 1 if self.counter >= self.patience: model.load_state_dict(self.best_weights) @@ -205,7 +208,9 @@ def _vector_to_parameters(vec: torch.Tensor, 
parameters) -> None: # Ensure vec of type Tensor if not isinstance(vec, torch.Tensor): raise TypeError( - "expected torch.Tensor, but got: {}".format(torch.typename(vec)) + "expected torch.Tensor, but got: {}".format( + torch.typename(vec) + ) ) # Pointer for slicing the vector for each parameter @@ -214,7 +219,9 @@ def _vector_to_parameters(vec: torch.Tensor, parameters) -> None: # The length of the parameter num_param = param.numel() # Slice the vector, reshape it, and replace the old data of the parameter - param.data = vec[pointer : pointer + num_param].view_as(param).data + param.data = ( + vec[pointer : pointer + num_param].view_as(param).data + ) # Increment the pointer pointer += num_param @@ -233,15 +240,21 @@ class ClassificationTester: def __init__(self, model, args, device): # hyparparameters used in the loss function self.lambd = ( - args.theta_reg_lambda if ("theta_reg_lambda" in args) else 1e-6 + args.theta_reg_lambda + if ("theta_reg_lambda" in args) + else 1e-6 ) # for regularization strenght self.eta = ( - args.h_labeled_param if ("h_labeled_param" in args) else 0.0 + args.h_labeled_param + if ("h_labeled_param" in args) + else 0.0 ) # for wealky supervised self.gamma = ( args.info_hypara if ("info_hypara" in args) else 0.0 ) # for wealky supervised - self.w_entropy = args.w_entropy if ("w_entropy" in args) else 0.0 + self.w_entropy = ( + args.w_entropy if ("w_entropy" in args) else 0.0 + ) # set the seed self.seed = args.seed @@ -308,7 +321,9 @@ def __init__(self, model, args, device): self.model.parameters(), lr=args.lr, betas=optim_betas ) elif args.opt == "rmsprop": - self.optimizer = optim.RMSprop(self.model.parameters(), lr=args.lr) + self.optimizer = optim.RMSprop( + self.model.parameters(), lr=args.lr + ) elif args.opt == "sgd": self.optimizer = optim.SGD( self.model.parameters(), @@ -365,8 +380,15 @@ def validate(self, val_loader, epoch, fold=None, name=""): all_losses = {"prediction": pred_loss.cpu().data.numpy()} # compute loss of known concets and discriminator - h_loss, hh_labeled = self.concept_learning_loss_for_weak_supervision( - inputs, all_losses, concepts, self.args.cbm, self.args.senn, epoch + h_loss, hh_labeled = ( + self.concept_learning_loss_for_weak_supervision( + inputs, + all_losses, + concepts, + self.args.cbm, + self.args.senn, + epoch, + ) ) loss_h += self.entropy_loss( @@ -382,14 +404,22 @@ def validate(self, val_loader, epoch, fold=None, name=""): # measure accuracy and record loss if self.nclasses > 4: # mainly use this line (current) - prec1, _ = self.accuracy(output.data, targets, topk=(1, 5)) + prec1, _ = self.accuracy( + output.data, targets, topk=(1, 5) + ) elif self.nclasses == 3: - prec1, _ = self.accuracy(output.data, targets, topk=(1, 3)) + prec1, _ = self.accuracy( + output.data, targets, topk=(1, 3) + ) else: - prec1, _ = self.binary_accuracy(output.data, targets), [100] + prec1, _ = self.binary_accuracy( + output.data, targets + ), [100] # update each value of fprint's values - losses.update(pred_loss.data.cpu().numpy(), inputs.size(0)) + losses.update( + pred_loss.data.cpu().numpy(), inputs.size(0) + ) top1.update(prec1[0], inputs.size(0)) # measure accuracy of concepts @@ -411,7 +441,10 @@ def validate(self, val_loader, epoch, fold=None, name=""): "Val: [{0}/{1}]\t" "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t" "Loss {loss.val:.4f} ({loss.avg:.4f})\t".format( - i, len(val_loader), batch_time=batch_time, loss=losses + i, + len(val_loader), + batch_time=batch_time, + loss=losses, ) ) else: @@ -421,7 +454,10 @@ def 
validate(self, val_loader, epoch, fold=None, name=""): "Val: [{0}/{1}]\t" "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t" "Loss {loss.val:.4f} ({loss.avg:.4f})\t".format( - i, len(val_loader), batch_time=batch_time, loss=losses + i, + len(val_loader), + batch_time=batch_time, + loss=losses, ) ) val_loss_dict = {"iter": epoch, f"{name} prediction": loss_y} @@ -472,11 +508,17 @@ def test(self, test_loader, save_file_name, fold=None): # measure accuracy and record loss if self.nclasses > 4: # mainly use this line (current) - prec1, _ = self.accuracy(output.data, targets, topk=(1, 5)) + prec1, _ = self.accuracy( + output.data, targets, topk=(1, 5) + ) elif self.nclasses == 3: - prec1, _ = self.accuracy(output.data, targets, topk=(1, 3)) + prec1, _ = self.accuracy( + output.data, targets, topk=(1, 3) + ) else: - prec1, _ = self.binary_accuracy(output.data, targets), [100] + prec1, _ = self.binary_accuracy( + output.data, targets + ), [100] # update each value of fprint's values losses.update(loss.data.cpu().numpy(), inputs.size(0)) @@ -563,7 +605,10 @@ def test(self, test_loader, save_file_name, fold=None): "Test: [{0}/{1}]\t" "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t" "Loss {loss.val:.4f} ({loss.avg:.4f})\t".format( - i, len(test_loader), batch_time=batch_time, loss=losses + i, + len(test_loader), + batch_time=batch_time, + loss=losses, ) ) @@ -602,7 +647,9 @@ def compute_accuracy_f1(self, y_pred, y_true): # Modify each tensor in the list for i in range(len(y_preds_list)): - y_preds_list[i] = torch.argmax(y_preds_list[i], dim=1).unsqueeze(-1) + y_preds_list[i] = torch.argmax( + y_preds_list[i], dim=1 + ).unsqueeze(-1) all_true = torch.cat(y_trues[:4], dim=-1) all_pred = torch.cat(y_preds_list[:4], dim=-1) @@ -631,13 +678,17 @@ def mean_l2_distance(self, vectors): return mean_distance def concept_error(self, output, target): - err = torch.Tensor(1).fill_((output.round().eq(target)).float().mean() * 100) + err = torch.Tensor(1).fill_( + (output.round().eq(target)).float().mean() * 100 + ) err = (100.0 - err.data[0]) / 100 return err def binary_accuracy(self, output, target): """Computes the accuracy""" - return torch.Tensor(1).fill_((output.round().eq(target)).float().mean() * 100) + return torch.Tensor(1).fill_( + (output.round().eq(target)).float().mean() * 100 + ) def accuracy(self, output, target, topk=(1,), numpy=False): if numpy: @@ -648,12 +699,16 @@ def accuracy(self, output, target, topk=(1,), numpy=False): pred = np.argpartition(output, -maxk)[:, -maxk:] # Check if each predicted class matches the target class - correct = pred == target # (pred == target.reshape((-1, 1))) + correct = ( + pred == target + ) # (pred == target.reshape((-1, 1))) # If topk = (1,5), then, k=1 and k=5 res = [] for k in topk: - correct_k = np.sum(correct[:, :k].any(axis=1).astype(float)) + correct_k = np.sum( + correct[:, :k].any(axis=1).astype(float) + ) res.append(correct_k * 100.0 / batch_size) return res @@ -670,20 +725,26 @@ def accuracy(self, output, target, topk=(1,), numpy=False): # if topk = (1,5), then, k=1 and k=5 res = [] for k in topk: - correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) + correct_k = ( + correct[:k].view(-1).float().sum(0, keepdim=True) + ) res.append(correct_k.mul_(100.0 / batch_size)) return res # plot losses -> it just takes the values from loss_history and val_loss_history def plot_losses(self, name, save_path=None): - loss_types = [k for k in self.loss_history[0].keys() if k != "iter"] + loss_types = [ + k for k in self.loss_history[0].keys() if k 
!= "iter" + ] losses = {k: [] for k in loss_types} iters = [] for e in self.loss_history: iters.append(e["iter"]) for k in loss_types: losses[k].append(e[k]) - fig, ax = plt.subplots(1, len(loss_types), figsize=(4 * len(loss_types), 5)) + fig, ax = plt.subplots( + 1, len(loss_types), figsize=(4 * len(loss_types), 5) + ) if len(loss_types) == 1: ax = [ax] # Hacky, fix for i, k in enumerate(loss_types): @@ -704,14 +765,18 @@ def plot_losses(self, name, save_path=None): #### VALIDATION plt.close() - loss_types = [k for k in self.val_loss_history[0].keys() if k != "iter"] + loss_types = [ + k for k in self.val_loss_history[0].keys() if k != "iter" + ] losses = {k: [] for k in loss_types} iters = [] for e in self.val_loss_history: iters.append(e["iter"]) for k in loss_types: losses[k].append(e[k]) - fig, ax = plt.subplots(1, len(loss_types), figsize=(4 * len(loss_types), 5)) + fig, ax = plt.subplots( + 1, len(loss_types), figsize=(4 * len(loss_types), 5) + ) if len(loss_types) == 1: ax = [ax] # Hacky, fix for i, k in enumerate(loss_types): @@ -731,7 +796,9 @@ def BCE_forloop(self, tar, pred): loss = F.binary_cross_entropy(tar[0, :4], pred[0, :4]) for i in range(1, len(tar)): - loss = loss + F.binary_cross_entropy(tar[i, :4], pred[i, :4]) + loss = loss + F.binary_cross_entropy( + tar[i, :4], pred[i, :4] + ) return loss def CE_forloop(self, y_pred, y_true): @@ -778,9 +845,12 @@ def concept_learning_loss_for_weak_supervision( hh_labeled_list[0], concepts[0].to(self.device) ) for j in range(1, len(hh_labeled_list)): - labeled_loss = labeled_loss + F.binary_cross_entropy( - hh_labeled_list[j], - concepts[j].to(self.device), + labeled_loss = ( + labeled_loss + + F.binary_cross_entropy( + hh_labeled_list[j], + concepts[j].to(self.device), + ) ) # labeled_loss = labeled_loss + torch.nn.BCELoss() F.binary_cross_entropy( @@ -793,7 +863,8 @@ def concept_learning_loss_for_weak_supervision( if i in self.args.which_c: labeled_loss = ( F.binary_cross_entropy( - hh_labeled_list[0, i], concepts[0, i].to(self.device) + hh_labeled_list[0, i], + concepts[0, i].to(self.device), ) / L ) @@ -820,7 +891,9 @@ def concept_learning_loss_for_weak_supervision( if not senn: # save loss (only value) to the all_losses list - all_losses["labeled_h"] = self.eta * labeled_loss.data.cpu().numpy() + all_losses["labeled_h"] = ( + self.eta * labeled_loss.data.cpu().numpy() + ) # use in def train_batch (class GradPenaltyTrainer) return info_loss, hh_labeled_list @@ -831,14 +904,14 @@ def entropy_loss(self, pred_c, all_losses, epoch): # real uses the discriminator's loss avg_c = torch.mean(pred_c, dim=0) - total_ent = -avg_c[0] * torch.log(avg_c[0]) - (1 - avg_c[0]) * torch.log( + total_ent = -avg_c[0] * torch.log(avg_c[0]) - ( 1 - avg_c[0] - ) + ) * torch.log(1 - avg_c[0]) total_ent /= np.log(2) for i in range(1, 21): - ent_i = -avg_c[i] * torch.log(avg_c[i]) - (1 - avg_c[i]) * torch.log( + ent_i = -avg_c[i] * torch.log(avg_c[i]) - ( 1 - avg_c[i] - ) + ) * torch.log(1 - avg_c[i]) ent_i /= np.log(2) assert ent_i <= 1 and ent_i >= 0, (ent_i, avg_c[i]) @@ -856,7 +929,12 @@ def entropy_loss(self, pred_c, all_losses, epoch): # test the model (on the current model) and then save the results def test_and_save_csv( - self, test_loader, save_file_name, fold=None, dropout=False, pcbm=False + self, + test_loader, + save_file_name, + fold=None, + dropout=False, + pcbm=False, ): print("Saving ", save_file_name, "...") @@ -909,11 +987,17 @@ def _deactivate_dropout(): # measure accuracy and record loss if self.nclasses > 4: # mainly use this line 
(current) - prec1, _ = self.accuracy(output.data, targets, topk=(1, 5)) + prec1, _ = self.accuracy( + output.data, targets, topk=(1, 5) + ) elif self.nclasses == 3: - prec1, _ = self.accuracy(output.data, targets, topk=(1, 3)) + prec1, _ = self.accuracy( + output.data, targets, topk=(1, 3) + ) else: - prec1, _ = self.binary_accuracy(output.data, targets), [100] + prec1, _ = self.binary_accuracy( + output.data, targets + ), [100] # update each value of print's values losses.update(loss.data.cpu().numpy(), inputs.size(0)) @@ -1022,7 +1106,9 @@ def get_ensemble_from_bayes(self, n_ensemble): raise NotImplemented("Not implemented for this method") # get the concept probability (factorized, not actual) for the ensemble - def get_concept_probability_factorized_ensemble(self, ensemble, loader): + def get_concept_probability_factorized_ensemble( + self, ensemble, loader + ): ensemble_c_prb = [] for model in ensemble: @@ -1048,17 +1134,20 @@ def get_concept_probability_factorized_ensemble(self, ensemble, loader): c_true = concepts.detach().cpu().numpy() else: c_prb = np.concatenate( - [c_prb, concept.detach().cpu().numpy()], axis=0 + [c_prb, concept.detach().cpu().numpy()], + axis=0, ) c_true = np.concatenate( - [c_true, concepts.detach().cpu().numpy()], axis=0 + [c_true, concepts.detach().cpu().numpy()], + axis=0, ) # the groundtruth world gt_factorized_max = np.max(c_true, axis=-1) # it is basically a list of indexes where the maximum value (1) occours gt_factorized = [ - np.where(row == gt_factorized_max[i])[0] for i, row in enumerate(c_true) + np.where(row == gt_factorized_max[i])[0] + for i, row in enumerate(c_true) ] ensemble_c_prb.append(c_prb) @@ -1092,10 +1181,14 @@ def setup( def test_and_save(self, test_loader, save_file_name, fold=None): # just add frequentist at the end of the file name and then save it - current_model_name = add_previous_dot(save_file_name, "_frequentist") + current_model_name = add_previous_dot( + save_file_name, "_frequentist" + ) super().test_and_save(test_loader, current_model_name, fold) - def frequentist_batch_prediction(self, batch_samples, apply_softmax=False): + def frequentist_batch_prediction( + self, batch_samples, apply_softmax=False + ): # single prediction for the standard frequentist model self.model.eval() @@ -1127,7 +1220,9 @@ def frequentist_prediction(self, loader, apply_softmax=False): ) (label_prob, concept_prob) = ( # (256, 4) # (256, 21) - self.frequentist_batch_prediction(images, apply_softmax) + self.frequentist_batch_prediction( + images, apply_softmax + ) ) # Concatenate the output @@ -1137,12 +1232,16 @@ def frequentist_prediction(self, loader, apply_softmax=False): y_pred = label_prob pc_pred = concept_prob else: - y_true = np.concatenate([y_true, labels.detach().cpu().numpy()], axis=0) + y_true = np.concatenate( + [y_true, labels.detach().cpu().numpy()], axis=0 + ) c_true = np.concatenate( [c_true, concepts.detach().cpu().numpy()], axis=0 ) y_pred = np.concatenate([y_pred, label_prob], axis=0) - pc_pred = np.concatenate([pc_pred, concept_prob], axis=0) + pc_pred = np.concatenate( + [pc_pred, concept_prob], axis=0 + ) return y_true, c_true, y_pred, pc_pred @@ -1156,18 +1255,24 @@ def worlds_probability( ): # get the prediction - y_true, c_true, y_pred_org, pc_pred = self.frequentist_prediction( - loader, apply_softmax + y_true, c_true, y_pred_org, pc_pred = ( + self.frequentist_prediction(loader, apply_softmax) ) # unsqueeze it as it was an ensemble of a single prediction y_pred = np.expand_dims(y_pred_org, axis=0) pc_pred_ext = 
np.expand_dims(pc_pred, axis=0) - fstop_prob = compute_forward_stop_prob(pc_pred_ext) # data, possibleworlds + fstop_prob = compute_forward_stop_prob( + pc_pred_ext + ) # data, possibleworlds left_prob = compute_left(pc_pred_ext) # data, possibleworlds - right_prob = compute_right(pc_pred_ext) # data, possibleworlds - y_prob = compute_output_probability(y_pred) # data, possibleworlds + right_prob = compute_right( + pc_pred_ext + ) # data, possibleworlds + y_prob = compute_output_probability( + y_pred + ) # data, possibleworlds # put all the probabilities together w_probs = [fstop_prob, left_prob, right_prob] @@ -1177,10 +1282,14 @@ def worlds_probability( w_predictions_prob_value = [] for prob in w_probs: w_predictions.append(np.argmax(prob, axis=-1)) # data, 1 - w_predictions_prob_value.append(np.max(prob, axis=-1)) # data, 1 + w_predictions_prob_value.append( + np.max(prob, axis=-1) + ) # data, 1 # compute the ground_truth (which is simply the binary representation up of that slice) - fstop_ground = compute_forward_stop_groundtruth(c_true) # data, 1 + fstop_ground = compute_forward_stop_groundtruth( + c_true + ) # data, 1 left_ground = compute_left_groundtruth(c_true) # data, 1 right_ground = compute_right_groundtruth(c_true) # data, 1 @@ -1205,7 +1314,9 @@ def worlds_probability( y_true = np.concatenate(y_trues[:4], axis=-1) y_predictions = np.concatenate(y_preds_list[:4], axis=-1) - y_predictions_prob = np.concatenate(y_preds_prob_list[:4], axis=-1) + y_predictions_prob = np.concatenate( + y_preds_prob_list[:4], axis=-1 + ) pc_prob = pc_pred pc_pred = (pc_prob > 0.5).astype(float) @@ -1268,8 +1379,12 @@ def test_and_save(self, test_loader, save_file_name, fold=None): self._activate_dropout() # then call the save method with the different kind of evaluation procedures for i in range(self.num_mc_samples): - current_model_name = add_previous_dot(save_file_name, f"_montecarlo_{i}") - super().test_and_save(test_loader, current_model_name, fold) + current_model_name = add_previous_dot( + save_file_name, f"_montecarlo_{i}" + ) + super().test_and_save( + test_loader, current_model_name, fold + ) def _montecarlo_dropout_single_batch( self, @@ -1288,14 +1403,19 @@ def _montecarlo_dropout_single_batch( # activate the double return self.model.return_both_concept_out_prob = True - output_list = [self.model(batch_samples) for _ in range(num_mc_samples)] # 30 + output_list = [ + self.model(batch_samples) for _ in range(num_mc_samples) + ] # 30 # deactivate the double return self.model.return_both_concept_out_prob = False self._deactivate_dropout() - label_prob = [lab.detach().cpu().numpy() for lab, _ in output_list] # 30 + label_prob = [ + lab.detach().cpu().numpy() for lab, _ in output_list + ] # 30 concept_prob = [ - concept.detach().cpu().numpy() for _, concept in output_list + concept.detach().cpu().numpy() + for _, concept in output_list ] # 30 label_prob = np.stack(label_prob, axis=0) @@ -1322,10 +1442,11 @@ def mc_dropout_predictions( ) # Call MC Dropout - (label_prob_ens, concept_prob_ens) = ( # (nmod, 256, 4) # (nmod, 256, 21) - self._montecarlo_dropout_single_batch( - images, num_mc_samples, apply_softmax - ) + ( + label_prob_ens, + concept_prob_ens, + ) = self._montecarlo_dropout_single_batch( # (nmod, 256, 4) # (nmod, 256, 21) + images, num_mc_samples, apply_softmax ) # Concatenate the output @@ -1335,12 +1456,18 @@ def mc_dropout_predictions( y_pred = label_prob_ens pc_pred = concept_prob_ens else: - y_true = np.concatenate([y_true, labels.detach().cpu().numpy()], axis=0) + y_true = 
np.concatenate( + [y_true, labels.detach().cpu().numpy()], axis=0 + ) c_true = np.concatenate( [c_true, concepts.detach().cpu().numpy()], axis=0 ) - y_pred = np.concatenate([y_pred, label_prob_ens], axis=1) - pc_pred = np.concatenate([pc_pred, concept_prob_ens], axis=1) + y_pred = np.concatenate( + [y_pred, label_prob_ens], axis=1 + ) + pc_pred = np.concatenate( + [pc_pred, concept_prob_ens], axis=1 + ) return y_true, c_true, y_pred, pc_pred @@ -1360,10 +1487,14 @@ def worlds_probability( # no need to expand the dimensions as here we are still considering the ensemble, hence # we have the first dimension valid - fstop_prob = compute_forward_stop_prob(pc_pred) # data, possibleworlds + fstop_prob = compute_forward_stop_prob( + pc_pred + ) # data, possibleworlds left_prob = compute_left(pc_pred) # data, possibleworlds right_prob = compute_right(pc_pred) # data, possibleworlds - y_prob = compute_output_probability(y_pred) # data, possibleworlds + y_prob = compute_output_probability( + y_pred + ) # data, possibleworlds w_probs = [fstop_prob, left_prob, right_prob] @@ -1371,9 +1502,13 @@ def worlds_probability( w_predictions_prob_value = [] for prob in w_probs: w_predictions.append(np.argmax(prob, axis=-1)) # data, 1 - w_predictions_prob_value.append(np.max(prob, axis=-1)) # data, 1 + w_predictions_prob_value.append( + np.max(prob, axis=-1) + ) # data, 1 - fstop_ground = compute_forward_stop_groundtruth(c_true) # data, 1 + fstop_ground = compute_forward_stop_groundtruth( + c_true + ) # data, 1 left_ground = compute_left_groundtruth(c_true) # data, 1 right_ground = compute_right_groundtruth(c_true) # data, 1 @@ -1400,7 +1535,9 @@ def worlds_probability( y_true = np.concatenate(y_trues[:4], axis=-1) y_predictions = np.concatenate(y_preds_list[:4], axis=-1) - y_predictions_prob = np.concatenate(y_preds_prob_list[:4], axis=-1) + y_predictions_prob = np.concatenate( + y_preds_prob_list[:4], axis=-1 + ) pc_pred_mean = np.mean(pc_pred, axis=0) pc_prob = pc_pred_mean @@ -1422,7 +1559,9 @@ def worlds_probability( ) # MC DROPOUT - def test_and_save_csv(self, test_loader, save_file_name, fold=None, pcbm=False): + def test_and_save_csv( + self, test_loader, save_file_name, fold=None, pcbm=False + ): print("Saving ", save_file_name, "...") @@ -1450,15 +1589,20 @@ def test_and_save_csv(self, test_loader, save_file_name, fold=None, pcbm=False): ) # compute output - (label_prob_ens, concept_prob_ens) = ( # (nmod, 256, 4) # (nmod, 256, 21) - self._montecarlo_dropout_single_batch(inputs, 30, False) + ( + label_prob_ens, + concept_prob_ens, + ) = self._montecarlo_dropout_single_batch( # (nmod, 256, 4) # (nmod, 256, 21) + inputs, 30, False ) output = torch.tensor(np.mean(label_prob_ens, axis=0)) pc_prob = torch.tensor(np.mean(concept_prob_ens, axis=0)) if self.cuda: - output, pc_prob = output.cuda(self.device), pc_prob.cuda(self.device) + output, pc_prob = output.cuda( + self.device + ), pc_prob.cuda(self.device) # prediction_criterion is defined in __init__ of "class GradPenaltyTrainer" loss = self.prediction_criterion(output, targets) @@ -1466,11 +1610,17 @@ def test_and_save_csv(self, test_loader, save_file_name, fold=None, pcbm=False): # measure accuracy and record loss if self.nclasses > 4: # mainly use this line (current) - prec1, _ = self.accuracy(output.data, targets, topk=(1, 5)) + prec1, _ = self.accuracy( + output.data, targets, topk=(1, 5) + ) elif self.nclasses == 3: - prec1, _ = self.accuracy(output.data, targets, topk=(1, 3)) + prec1, _ = self.accuracy( + output.data, targets, topk=(1, 3) + ) else: - 
prec1, _ = self.binary_accuracy(output.data, targets), [100] + prec1, _ = self.binary_accuracy( + output.data, targets + ), [100] # update each value of print's values losses.update(loss.data.cpu().numpy(), inputs.size(0)) @@ -1558,7 +1708,10 @@ def test_and_save_csv(self, test_loader, save_file_name, fold=None, pcbm=False): "Test: [{0}/{1}]\t" "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t" "Loss {loss.val:.4f} ({loss.avg:.4f})\t".format( - i, len(test_loader), batch_time=batch_time, loss=losses + i, + len(test_loader), + batch_time=batch_time, + loss=losses, ) ) @@ -1600,10 +1753,14 @@ def p_c_x_distance( if i == 0: c_prb = concept_prob else: - c_prb = np.concatenate([c_prb, concept_prob], axis=1) + c_prb = np.concatenate( + [c_prb, concept_prob], axis=1 + ) # get the worlds probabilities by computing the matmultiplication - fstop_prob = compute_forward_stop_prob(c_prb) # data, possibleworlds + fstop_prob = compute_forward_stop_prob( + c_prb + ) # data, possibleworlds left_prob = compute_left(c_prb) # data, possibleworlds right_prob = compute_right(c_prb) # data, possibleworlds @@ -1645,7 +1802,9 @@ def setup( # Laplace model is given in the setup, which performs the laplace approximation # same stuff as always, compute the hessian around the MAP point on the loss function, as use it # as a multivariate gaussian distribution to sample new model weigths - self.laplace_model = self.laplace_approximation(train_loader, val_loader) + self.laplace_model = self.laplace_approximation( + train_loader, val_loader + ) # get the ensemble, and run the test on that specific ensemble! def test_and_save(self, test_loader, save_file_name, fold=None): @@ -1653,14 +1812,18 @@ def test_and_save(self, test_loader, save_file_name, fold=None): self.laplace_model, self.n_ensembles ) for i in range(self.num_mc_samples): - current_model_name = add_previous_dot(save_file_name, f"_laplace_{i}") + current_model_name = add_previous_dot( + save_file_name, f"_laplace_{i}" + ) self.model = ensemble[i] # override the model - super().test_and_save(test_loader, current_model_name, fold) + super().test_and_save( + test_loader, current_model_name, fold + ) def laplace_approximation(self, train_loader, val_loader): # usual laplace approximation hook - from torch.utils.data import DataLoader from laplace.curvature import AsdlGGN + from torch.utils.data import DataLoader def new_model_copy(to_be_copied): # only "fcc" conceptizer use, otherwise cannot use (not modifile so as to fit this task...) 
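# --- Illustrative sketch, not part of the patch above ---
# The MC-dropout path in test_and_save_csv draws label/concept probabilities from
# _montecarlo_dropout_single_batch and then averages them with np.mean(..., axis=0).
# The minimal standalone version below shows that predictive-mean step under the
# assumption of a hypothetical `model` that keeps its dropout layers active in
# train() mode and returns a (label_logits, concept_logits) pair; the function
# name and interface are illustrative, not the repository's actual API.
import numpy as np
import torch


def mc_dropout_mean(model, inputs, n_samples=30):
    model.train()  # leave dropout stochastic so each forward pass is a sample
    label_probs, concept_probs = [], []
    with torch.no_grad():
        for _ in range(n_samples):
            y_logits, c_logits = model(inputs)
            label_probs.append(torch.softmax(y_logits, dim=-1).cpu().numpy())
            concept_probs.append(torch.sigmoid(c_logits).cpu().numpy())
    label_probs = np.stack(label_probs, axis=0)      # (n_samples, batch, n_classes)
    concept_probs = np.stack(concept_probs, axis=0)  # (n_samples, batch, n_concepts)
    # Averaging over the sample axis mirrors np.mean(label_prob_ens, axis=0)
    # used in the trainer before computing loss and accuracy.
    return label_probs.mean(axis=0), concept_probs.mean(axis=0)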
@@ -1681,15 +1844,26 @@ def new_model_copy(to_be_copied): sys.exit(1) parametrizer1 = dfc_parametrizer( - 2048, 1024, 512, 256, 128, self.nconcepts, self.theta_dim, layers=4 + 2048, + 1024, + 512, + 256, + 128, + self.nconcepts, + self.theta_dim, + layers=4, ) if self.cbm == True: aggregator = CBM_aggregator( - self.concept_dim, self.nclasses, self.nconcepts_labeled + self.concept_dim, + self.nclasses, + self.nconcepts_labeled, ) else: - aggregator = additive_scalar_aggregator(self.concept_dim, self.nclasses) + aggregator = additive_scalar_aggregator( + self.concept_dim, self.nclasses + ) if self.model_name == "dpl": model = DPL( @@ -1712,7 +1886,9 @@ def new_model_copy(to_be_copied): # send models to device you want to use model = model.to(self.device) - model.load_state_dict(copy.deepcopy(to_be_copied.state_dict())) + model.load_state_dict( + copy.deepcopy(to_be_copied.state_dict()) + ) return model # Wrapper DataLoader @@ -1724,7 +1900,9 @@ def __init__(self, original_dataloader, **kwargs): def __iter__(self): # Get the iterator from the original DataLoader - original_iterator = super(WrapperDataLoader, self).__iter__() + original_iterator = super( + WrapperDataLoader, self + ).__iter__() for original_batch in original_iterator: modified_batch = [ @@ -1735,7 +1913,9 @@ def __iter__(self): # Wrapper Model class WrapperModel(nn.Module): - def __init__(self, original_model, device, output_all=False): + def __init__( + self, original_model, device, output_all=False + ): super(WrapperModel, self).__init__() self.original_model = original_model self.original_model.to(device) @@ -1748,14 +1928,18 @@ def forward(self, input_batch): # Call the forward method of the model original_output = self.original_model(input_batch) - concept_p_x, h_x, _ = self.original_model.conceptizer(input_batch) + concept_p_x, h_x, _ = self.original_model.conceptizer( + input_batch + ) # torch.Size([batch, 19]) torch.Size([batch, 2, 10]) torch.Size([batch, 2, 10]) if not self.output_all: return original_output # I want to flat all the tensors in this way: - return torch.cat((original_output, concept_p_x), dim=1) + return torch.cat( + (original_output, concept_p_x), dim=1 + ) def get_ensembles(self, la_model, n_models): @@ -1765,16 +1949,21 @@ def get_ensembles(self, la_model, n_models): np.set_printoptions(threshold=sys.maxsize) for i, mp in enumerate(self.model_possibilities): - _vector_to_parameters(mp, la_model.model.last_layer.parameters()) + _vector_to_parameters( + mp, la_model.model.last_layer.parameters() + ) ensembles.append( - new_model_copy(la_model.model.model.original_model) + new_model_copy( + la_model.model.model.original_model + ) ) if i == n_models - 1: break # restore original model _vector_to_parameters( - la_model.mean, la_model.model.last_layer.parameters() + la_model.mean, + la_model.model.last_layer.parameters(), ) # return an ensembles of models return ensembles @@ -1784,7 +1973,9 @@ def get_ensembles(self, la_model, n_models): la_val_loader = WrapperDataLoader(val_loader) # wrap the model - la_model = WrapperModel(new_model_copy(self.model), self.model.device) + la_model = WrapperModel( + new_model_copy(self.model), self.model.device + ) la_model.to(self.model.device) la = Laplace( @@ -1793,21 +1984,27 @@ def get_ensembles(self, la_model, n_models): subset_of_weights="last_layer", # subset_of_weights='subnetwork', hessian_structure="diag", # hessian_structure='full', # hessian_structure='diag', # hessian_structure='kron', backend=AsdlGGN, - backend_kwargs={'boia':True} + backend_kwargs={"boia": 
True}, ) - return self._fit_la_model(la, la_training_loader, la_val_loader) + return self._fit_la_model( + la, la_training_loader, la_val_loader + ) # fit the laplace model def _fit_la_model(self, la, la_training_loader, la_val_loader): fprint("Doing Laplace fit...") la.fit(la_training_loader) - la.optimize_prior_precision(method="marglik", val_loader=la_val_loader) + la.optimize_prior_precision( + method="marglik", val_loader=la_val_loader + ) # Enabling last layer output all la.model.model.output_all = True return la - def test_and_save_csv(self, test_loader, save_file_name, fold=None, pcbm=False): + def test_and_save_csv( + self, test_loader, save_file_name, fold=None, pcbm=False + ): # initialization of print's values batch_time = AverageMeter() @@ -1830,7 +2027,9 @@ def test_and_save_csv(self, test_loader, save_file_name, fold=None, pcbm=False): concepts.cuda(self.device), ) - _ = self.laplace_single_prediction(self.laplace_model, inputs, 5, 21, False) + _ = self.laplace_single_prediction( + self.laplace_model, inputs, 5, 21, False + ) # Call Laplace ensembles ensemble = self.laplace_model.model.model.get_ensembles( @@ -1838,15 +2037,19 @@ def test_and_save_csv(self, test_loader, save_file_name, fold=None, pcbm=False): ) # Call Ensemble predict (same as ensemble) - (label_prob_ens, concept_prob_ens) = self.ensemble_single_la_predict( - ensemble, inputs, False + (label_prob_ens, concept_prob_ens) = ( + self.ensemble_single_la_predict( + ensemble, inputs, False + ) ) output = torch.tensor(np.mean(label_prob_ens, axis=0)) pc_prob = torch.tensor(np.mean(concept_prob_ens, axis=0)) if self.cuda: - output, pc_prob = output.cuda(self.device), pc_prob.cuda(self.device) + output, pc_prob = output.cuda( + self.device + ), pc_prob.cuda(self.device) # prediction_criterion is defined in __init__ of "class GradPenaltyTrainer" loss = self.prediction_criterion(output, targets) @@ -1854,11 +2057,17 @@ def test_and_save_csv(self, test_loader, save_file_name, fold=None, pcbm=False): # measure accuracy and record loss if self.nclasses > 4: # mainly use this line (current) - prec1, _ = self.accuracy(output.data, targets, topk=(1, 5)) + prec1, _ = self.accuracy( + output.data, targets, topk=(1, 5) + ) elif self.nclasses == 3: - prec1, _ = self.accuracy(output.data, targets, topk=(1, 3)) + prec1, _ = self.accuracy( + output.data, targets, topk=(1, 3) + ) else: - prec1, _ = self.binary_accuracy(output.data, targets), [100] + prec1, _ = self.binary_accuracy( + output.data, targets + ), [100] # update each value of print's values losses.update(loss.data.cpu().numpy(), inputs.size(0)) @@ -1912,7 +2121,10 @@ def test_and_save_csv(self, test_loader, save_file_name, fold=None, pcbm=False): "Test on " + fold + ": [{0}/{1}]\t" "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t" "Loss {loss.val:.4f} ({loss.avg:.4f})\t".format( - i, len(test_loader), batch_time=batch_time, loss=losses + i, + len(test_loader), + batch_time=batch_time, + loss=losses, ) ) @@ -1948,7 +2160,10 @@ def test_and_save_csv(self, test_loader, save_file_name, fold=None, pcbm=False): "Test: [{0}/{1}]\t" "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t" "Loss {loss.val:.4f} ({loss.avg:.4f})\t".format( - i, len(test_loader), batch_time=batch_time, loss=losses + i, + len(test_loader), + batch_time=batch_time, + loss=losses, ) ) @@ -1978,7 +2193,11 @@ def laplace_prediction( # prediction (this one is only needed for making laplace store the models inside the WrapperModel) _ = self.laplace_single_prediction( - laplace_model, images, output_classes, 
num_concepts, apply_softmax + laplace_model, + images, + output_classes, + num_concepts, + apply_softmax, ) # Call Laplace ensembles @@ -1987,8 +2206,10 @@ def laplace_prediction( ) # Call Ensemble predict (same as ensemble) - (label_prob_ens, concept_prob_ens) = self.ensemble_single_la_predict( - ensemble, images, apply_softmax + (label_prob_ens, concept_prob_ens) = ( + self.ensemble_single_la_predict( + ensemble, images, apply_softmax + ) ) # Concatenate the output @@ -1998,12 +2219,18 @@ def laplace_prediction( y_pred = label_prob_ens pc_pred = concept_prob_ens else: - y_true = np.concatenate([y_true, labels.detach().cpu().numpy()], axis=0) + y_true = np.concatenate( + [y_true, labels.detach().cpu().numpy()], axis=0 + ) c_true = np.concatenate( [c_true, concepts.detach().cpu().numpy()], axis=0 ) - y_pred = np.concatenate([y_pred, label_prob_ens], axis=1) - pc_pred = np.concatenate([pc_pred, concept_prob_ens], axis=1) + y_pred = np.concatenate( + [y_pred, label_prob_ens], axis=1 + ) + pc_pred = np.concatenate( + [pc_pred, concept_prob_ens], axis=1 + ) return y_true, c_true, y_pred, pc_pred @@ -2018,7 +2245,11 @@ def laplace_single_prediction( ): pred = la(sample_batch, pred_type="nn", link_approx="mc") recovered_pred = self.recover_predictions_from_laplace( - pred, sample_batch.shape[0], output_classes, num_concepts, apply_softmax + pred, + sample_batch.shape[0], + output_classes, + num_concepts, + apply_softmax, ) return recovered_pred @@ -2032,7 +2263,9 @@ def recover_predictions_from_laplace( apply_softmax=False, ): # Recovering shape - ys = la_prediction[:, :output_classes] # take all until output_classes + ys = la_prediction[ + :, :output_classes + ] # take all until output_classes pCS = la_prediction[ :, output_classes: ] # take all from output_classes until the end @@ -2061,9 +2294,12 @@ def ensemble_single_la_predict( model.return_both_concept_out_prob = False # get out the different output - label_prob = [lab.detach().cpu().numpy() for lab, _ in output_list] # 30 + label_prob = [ + lab.detach().cpu().numpy() for lab, _ in output_list + ] # 30 concept_prob = [ - concept.detach().cpu().numpy() for _, concept in output_list + concept.detach().cpu().numpy() + for _, concept in output_list ] # 30 label_prob = np.stack(label_prob, axis=0) @@ -2093,10 +2329,14 @@ def worlds_probability( apply_softmax, ) - fstop_prob = compute_forward_stop_prob(pc_pred) # data, possibleworlds + fstop_prob = compute_forward_stop_prob( + pc_pred + ) # data, possibleworlds left_prob = compute_left(pc_pred) # data, possibleworlds right_prob = compute_right(pc_pred) # data, possibleworlds - y_prob = compute_output_probability(y_pred) # data, possibleworlds + y_prob = compute_output_probability( + y_pred + ) # data, possibleworlds w_probs = [fstop_prob, left_prob, right_prob] @@ -2104,9 +2344,13 @@ def worlds_probability( w_predictions_prob_value = [] for prob in w_probs: w_predictions.append(np.argmax(prob, axis=-1)) # data, 1 - w_predictions_prob_value.append(np.max(prob, axis=-1)) # data, 1 + w_predictions_prob_value.append( + np.max(prob, axis=-1) + ) # data, 1 - fstop_ground = compute_forward_stop_groundtruth(c_true) # data, 1 + fstop_ground = compute_forward_stop_groundtruth( + c_true + ) # data, 1 left_ground = compute_left_groundtruth(c_true) # data, 1 right_ground = compute_right_groundtruth(c_true) # data, 1 @@ -2133,7 +2377,9 @@ def worlds_probability( y_true = np.concatenate(y_trues[:4], axis=-1) y_predictions = np.concatenate(y_preds_list[:4], axis=-1) - y_predictions_prob = 
np.concatenate(y_preds_prob_list[:4], axis=-1) + y_predictions_prob = np.concatenate( + y_preds_prob_list[:4], axis=-1 + ) pc_pred_mean = np.mean(pc_pred, axis=0) pc_prob = pc_pred_mean @@ -2191,9 +2437,13 @@ def p_c_x_distance( if i == 0: c_prb = concept_prob else: - c_prb = np.concatenate([c_prb, concept_prob], axis=1) + c_prb = np.concatenate( + [c_prb, concept_prob], axis=1 + ) - fstop_prob = compute_forward_stop_prob(c_prb) # data, possibleworlds + fstop_prob = compute_forward_stop_prob( + c_prb + ) # data, possibleworlds left_prob = compute_left(c_prb) # data, possibleworlds right_prob = compute_right(c_prb) # data, possibleworlds @@ -2209,7 +2459,9 @@ def p_c_x_distance( ) def get_ensemble_from_bayes(self, n_ensemble, inputs): - _ = self.laplace_single_prediction(self.laplace_model, inputs, 5, 21, False) + _ = self.laplace_single_prediction( + self.laplace_model, inputs, 5, 21, False + ) ensemble = self.laplace_model.model.model.get_ensembles( self.laplace_model, n_ensemble @@ -2231,11 +2483,17 @@ def __init__(self, model, args, device, name): # as previously, override the current model and then deepens! def test_and_save(self, test_loader, save_file_name, fold=None): for i in range(len(self.ensemble)): - current_model_name = add_previous_dot(save_file_name, f"_{self.name}_{i}") + current_model_name = add_previous_dot( + save_file_name, f"_{self.name}_{i}" + ) self.model = self.ensemble[i] # override the model - super().test_and_save(test_loader, current_model_name, fold) + super().test_and_save( + test_loader, current_model_name, fold + ) - def test_and_save_csv(self, test_loader, save_file_name, fold=None, pcbm=False): + def test_and_save_csv( + self, test_loader, save_file_name, fold=None, pcbm=False + ): print("Saving...", save_file_name) @@ -2263,27 +2521,39 @@ def test_and_save_csv(self, test_loader, save_file_name, fold=None, pcbm=False): ) # compute output - (label_prob_ens, concept_prob_ens) = self._ensemble_single_predict( - self.ensemble, inputs, False + (label_prob_ens, concept_prob_ens) = ( + self._ensemble_single_predict( + self.ensemble, inputs, False + ) ) output = torch.tensor(np.mean(label_prob_ens, axis=0)) pc_prob = torch.tensor(np.mean(concept_prob_ens, axis=0)) if self.cuda: - output, pc_prob = output.cuda(self.device), pc_prob.cuda(self.device) + output, pc_prob = output.cuda( + self.device + ), pc_prob.cuda(self.device) # prediction_criterion is defined in __init__ of "class GradPenaltyTrainer" - loss = self.prediction_criterion(torch.tensor(output), targets) + loss = self.prediction_criterion( + torch.tensor(output), targets + ) # measure accuracy and record loss if self.nclasses > 4: # mainly use this line (current) - prec1, _ = self.accuracy(output.data, targets, topk=(1, 5)) + prec1, _ = self.accuracy( + output.data, targets, topk=(1, 5) + ) elif self.nclasses == 3: - prec1, _ = self.accuracy(output.data, targets, topk=(1, 3)) + prec1, _ = self.accuracy( + output.data, targets, topk=(1, 3) + ) else: - prec1, _ = self.binary_accuracy(output.data, targets), [100] + prec1, _ = self.binary_accuracy( + output.data, targets + ), [100] # update each value of print's values losses.update(loss.data.cpu().numpy(), inputs.size(0)) @@ -2337,7 +2607,10 @@ def test_and_save_csv(self, test_loader, save_file_name, fold=None, pcbm=False): "Test on " + fold + ": [{0}/{1}]\t" "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t" "Loss {loss.val:.4f} ({loss.avg:.4f})\t".format( - i, len(test_loader), batch_time=batch_time, loss=losses + i, + len(test_loader), + 
batch_time=batch_time, + loss=losses, ) ) @@ -2373,7 +2646,10 @@ def test_and_save_csv(self, test_loader, save_file_name, fold=None, pcbm=False): "Test: [{0}/{1}]\t" "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t" "Loss {loss.val:.4f} ({loss.avg:.4f})\t".format( - i, len(test_loader), batch_time=batch_time, loss=losses + i, + len(test_loader), + batch_time=batch_time, + loss=losses, ) ) @@ -2409,7 +2685,9 @@ def _populate_pcx_dataset(self, model, pcx_loader, batch_size): return dataset - def save_model_params_all(self, save_path, separate_from_others, lambda_h): + def save_model_params_all( + self, save_path, separate_from_others, lambda_h + ): for i, model in enumerate(self.ensemble): file_name = ( f"dens-{i}-seed-{self.seed}-lambda-h-{lambda_h}-real_kl.pth" @@ -2418,7 +2696,9 @@ def save_model_params_all(self, save_path, separate_from_others, lambda_h): ) super().save_model_params(model, save_path, file_name) - def _populate_pcx_dataset_knowledge_aware(self, model, pcx_loader, n_facts): + def _populate_pcx_dataset_knowledge_aware( + self, model, pcx_loader, n_facts + ): print("Initializing PWX database...") from DPL.utils_problog import ( @@ -2427,7 +2707,9 @@ def _populate_pcx_dataset_knowledge_aware(self, model, pcx_loader, n_facts): build_world_queries_matrix_R, ) - FS_w_q = build_world_queries_matrix_complete_FS().to(self.model.device) + FS_w_q = build_world_queries_matrix_complete_FS().to( + self.model.device + ) L_w_q = build_world_queries_matrix_L().to(self.model.device) R_w_q = build_world_queries_matrix_R().to(self.model.device) @@ -2453,8 +2735,12 @@ def _populate_pcx_dataset_knowledge_aware(self, model, pcx_loader, n_facts): fstop_prob = compute_forward_stop_prob( p_c_x_ext, numpy=False ) # data, possibleworlds - left_prob = compute_left(p_c_x_ext, numpy=False) # data, possibleworlds - right_prob = compute_right(p_c_x_ext, numpy=False) # data, possibleworlds + left_prob = compute_left( + p_c_x_ext, numpy=False + ) # data, possibleworlds + right_prob = compute_right( + p_c_x_ext, numpy=False + ) # data, possibleworlds # wfstop list w_fstop_tmp = [] @@ -2468,9 +2754,15 @@ def _populate_pcx_dataset_knowledge_aware(self, model, pcx_loader, n_facts): R_w_q_transposed = R_w_q.t() # Extracting probabilities using indexing and performing element-wise multiplication - fstop_prob = fstop_prob * FS_w_q_transposed[label_indices[:, 0], :] - left_prob = left_prob * L_w_q_transposed[label_indices[:, 1], :] - right_prob = right_prob * R_w_q_transposed[label_indices[:, 2], :] + fstop_prob = ( + fstop_prob * FS_w_q_transposed[label_indices[:, 0], :] + ) + left_prob = ( + left_prob * L_w_q_transposed[label_indices[:, 1], :] + ) + right_prob = ( + right_prob * R_w_q_transposed[label_indices[:, 2], :] + ) # Adding a small constant for normalization fstop_prob += 1e-5 @@ -2489,13 +2781,22 @@ def _populate_pcx_dataset_knowledge_aware(self, model, pcx_loader, n_facts): # Reshaping and appending to temporary lists w_fstop_tmp = list( - map(lambda x: x.squeeze(0), torch.split(fstop_prob, 1, dim=0)) + map( + lambda x: x.squeeze(0), + torch.split(fstop_prob, 1, dim=0), + ) ) w_left_tmp = list( - map(lambda x: x.squeeze(0), torch.split(left_prob, 1, dim=0)) + map( + lambda x: x.squeeze(0), + torch.split(left_prob, 1, dim=0), + ) ) w_right_tmp = list( - map(lambda x: x.squeeze(0), torch.split(right_prob, 1, dim=0)) + map( + lambda x: x.squeeze(0), + torch.split(right_prob, 1, dim=0), + ) ) # Append the new tensor @@ -2527,7 +2828,9 @@ def _populate_pcx_dataset_knowledge_aware(self, model, pcx_loader, 
n_facts): return dataset # update the dataset for the deep diversification - def _update_pcx_dataset(self, model, dataset, pcx_loader, batch_size): + def _update_pcx_dataset( + self, model, dataset, pcx_loader, batch_size + ): indexes = 0 for _, data in enumerate(pcx_loader): @@ -2551,7 +2854,9 @@ def _update_pcx_dataset(self, model, dataset, pcx_loader, batch_size): return dataset - def _update_pcx_dataset_knowledge_aware(self, model, dataset, pcx_loader, n_facts): + def _update_pcx_dataset_knowledge_aware( + self, model, dataset, pcx_loader, n_facts + ): print("Updating PWX database...") indexes = 0 @@ -2571,8 +2876,12 @@ def _update_pcx_dataset_knowledge_aware(self, model, dataset, pcx_loader, n_fact fstop_prob = compute_forward_stop_prob( p_c_x_ext, numpy=False ) # data, possibleworlds - left_prob = compute_left(p_c_x_ext, numpy=False) # data, possibleworlds - right_prob = compute_right(p_c_x_ext, numpy=False) # data, possibleworlds + left_prob = compute_left( + p_c_x_ext, numpy=False + ) # data, possibleworlds + right_prob = compute_right( + p_c_x_ext, numpy=False + ) # data, possibleworlds # wfstop list w_fstop_tmp = [] @@ -2586,9 +2895,15 @@ def _update_pcx_dataset_knowledge_aware(self, model, dataset, pcx_loader, n_fact R_w_q_transposed = dataset.R_w_q.t() # Extracting probabilities using indexing and performing element-wise multiplication - fstop_prob = fstop_prob * FS_w_q_transposed[label_indices[:, 0], :] - left_prob = left_prob * L_w_q_transposed[label_indices[:, 1], :] - right_prob = right_prob * R_w_q_transposed[label_indices[:, 2], :] + fstop_prob = ( + fstop_prob * FS_w_q_transposed[label_indices[:, 0], :] + ) + left_prob = ( + left_prob * L_w_q_transposed[label_indices[:, 1], :] + ) + right_prob = ( + right_prob * R_w_q_transposed[label_indices[:, 2], :] + ) # Adding a small constant for normalization fstop_prob += 1e-5 @@ -2607,13 +2922,22 @@ def _update_pcx_dataset_knowledge_aware(self, model, dataset, pcx_loader, n_fact # Reshaping and appending to temporary lists w_fstop_tmp = list( - map(lambda x: x.squeeze(0), torch.split(fstop_prob, 1, dim=0)) + map( + lambda x: x.squeeze(0), + torch.split(fstop_prob, 1, dim=0), + ) ) w_left_tmp = list( - map(lambda x: x.squeeze(0), torch.split(left_prob, 1, dim=0)) + map( + lambda x: x.squeeze(0), + torch.split(left_prob, 1, dim=0), + ) ) w_right_tmp = list( - map(lambda x: x.squeeze(0), torch.split(right_prob, 1, dim=0)) + map( + lambda x: x.squeeze(0), + torch.split(right_prob, 1, dim=0), + ) ) j = 0 @@ -2630,18 +2954,30 @@ def give_full_worlds(self, model_itself_pc_x): p_c_x_ext = torch.unsqueeze(model_itself_pc_x, dim=0) - f_prob = compute_forward_prob(p_c_x_ext, numpy=False) # data, possibleworlds - stop_prob = compute_forward_prob(p_c_x_ext, numpy=False) # data, possibleworlds - left_prob = compute_left(p_c_x_ext, numpy=False) # data, possibleworlds - right_prob = compute_right(p_c_x_ext, numpy=False) # data, possibleworlds + f_prob = compute_forward_prob( + p_c_x_ext, numpy=False + ) # data, possibleworlds + stop_prob = compute_forward_prob( + p_c_x_ext, numpy=False + ) # data, possibleworlds + left_prob = compute_left( + p_c_x_ext, numpy=False + ) # data, possibleworlds + right_prob = compute_right( + p_c_x_ext, numpy=False + ) # data, possibleworlds return f_prob, stop_prob, left_prob, right_prob - def compute_pw_knowledge_filter(self, model_itself_pc_x, labels, wfs, wl, wR): + def compute_pw_knowledge_filter( + self, model_itself_pc_x, labels, wfs, wl, wR + ): p_c_x_ext = torch.unsqueeze(model_itself_pc_x, dim=0) - 
fstop_prob = compute_forward_stop_prob(p_c_x_ext, numpy=False).to( + fstop_prob = compute_forward_stop_prob( + p_c_x_ext, numpy=False + ).to( self.model.device ) # data, possibleworlds left_prob = compute_left(p_c_x_ext, numpy=False).to( @@ -2659,9 +2995,15 @@ def compute_pw_knowledge_filter(self, model_itself_pc_x, labels, wfs, wl, wR): # Extracting probabilities using indexing and performing element-wise multiplication print(FS_w_q_transposed.device, fstop_prob.device) - fstop_prob = fstop_prob * FS_w_q_transposed[label_indices[:, 0], :] - left_prob = left_prob * L_w_q_transposed[label_indices[:, 1], :] - right_prob = right_prob * R_w_q_transposed[label_indices[:, 2], :] + fstop_prob = ( + fstop_prob * FS_w_q_transposed[label_indices[:, 0], :] + ) + left_prob = ( + left_prob * L_w_q_transposed[label_indices[:, 1], :] + ) + right_prob = ( + right_prob * R_w_q_transposed[label_indices[:, 2], :] + ) # Adding a small constant for normalization fstop_prob += 1e-5 @@ -2680,11 +3022,22 @@ def compute_pw_knowledge_filter(self, model_itself_pc_x, labels, wfs, wl, wR): # Reshaping and appending to temporary lists w_fstop_tmp = list( - map(lambda x: x.squeeze(0), torch.split(fstop_prob, 1, dim=0)) + map( + lambda x: x.squeeze(0), + torch.split(fstop_prob, 1, dim=0), + ) + ) + w_left_tmp = list( + map( + lambda x: x.squeeze(0), + torch.split(left_prob, 1, dim=0), + ) ) - w_left_tmp = list(map(lambda x: x.squeeze(0), torch.split(left_prob, 1, dim=0))) w_right_tmp = list( - map(lambda x: x.squeeze(0), torch.split(right_prob, 1, dim=0)) + map( + lambda x: x.squeeze(0), + torch.split(right_prob, 1, dim=0), + ) ) w_fstop_tmp = torch.stack(w_fstop_tmp, dim=0) @@ -2716,9 +3069,15 @@ def setup( ) if separate_from_others: - self.FS_w_q = build_world_queries_matrix_complete_FS().to(self.model.device) - self.L_w_q = build_world_queries_matrix_L().to(self.model.device) - self.R_w_q = build_world_queries_matrix_R().to(self.model.device) + self.FS_w_q = build_world_queries_matrix_complete_FS().to( + self.model.device + ) + self.L_w_q = build_world_queries_matrix_L().to( + self.model.device + ) + self.R_w_q = build_world_queries_matrix_R().to( + self.model.device + ) self.train_ensembles( train_loader, train_loader_no_shuffle, @@ -2823,7 +3182,10 @@ def train_ensembles( """ self.ensemble.append(self.model) - fprint("Done!\nTotal length of the ensemble: ", len(self.ensemble)) + fprint( + "Done!\nTotal length of the ensemble: ", + len(self.ensemble), + ) # train the single model def train_single_model( @@ -2849,7 +3211,9 @@ def train_single_model( if self.args.wandb is not None: wandb.log({f"start-lr-model-{model_idx}": self.args.lr}) - early_stopper = EarlyStopper(patience=5, min_delta=0.001) # prev 0.01 + early_stopper = EarlyStopper( + patience=5, min_delta=0.001 + ) # prev 0.01 for epoch in range(epochs): # go to train_epoch function @@ -2867,7 +3231,11 @@ def train_single_model( if self.args.wandb is not None: wandb.log( - {f"lr-model-{model_idx}": float(self.scheduler.get_last_lr()[0])} + { + f"lr-model-{model_idx}": float( + self.scheduler.get_last_lr()[0] + ) + } ) # # validate evaluation @@ -2917,7 +3285,9 @@ def kl_paper(p_model: torch.tensor, p_rest: torch.tensor, k): p_rest = p_rest / (1 + (p_rest.shape[1] * 1e-5)) ratio = torch.div(p_rest, p_model) - kl_ew = torch.sum(p_model * torch.log(1 + (k - 1) * ratio), dim=1) + kl_ew = torch.sum( + p_model * torch.log(1 + (k - 1) * ratio), dim=1 + ) return torch.mean(kl_ew, dim=0) @@ -2937,14 +3307,18 @@ def kl_paper(p_model: torch.tensor, p_rest: torch.tensor, 
k): pred_loss = self.prediction_criterion(pred, targets) # save loss (only value) to the all_losses list - all_losses = {f"model {model_idx} prediction": pred_loss.cpu().data.numpy()} + all_losses = { + f"model {model_idx} prediction": pred_loss.cpu().data.numpy() + } # compute loss of known concets and discriminator fprint(f"\tCE Loss: {pred_loss}") def entropy(probabilities: torch.tensor, concept=False): def concept_entropy(p): - return -torch.sum(p * torch.log2(p)) / (len(p) * math.log(2)) + return -torch.sum(p * torch.log2(p)) / ( + len(p) * math.log(2) + ) def single_concept_entropy(p): positive = p * torch.log2(p) @@ -2953,18 +3327,29 @@ def single_concept_entropy(p): return entropy_per_element.mean() probabilities = probabilities + 1e-5 - probabilities = probabilities / (1 + (probabilities.shape[1] * 1e-5)) + probabilities = probabilities / ( + 1 + (probabilities.shape[1] * 1e-5) + ) from functorch import vmap if concept: entropies = vmap(concept_entropy)(probabilities) else: - entropies = vmap(single_concept_entropy)(probabilities) + entropies = vmap(single_concept_entropy)( + probabilities + ) return torch.mean(entropies, dim=0) - h_loss, hh_labeled = self.concept_learning_loss_for_weak_supervision( - inputs, all_losses, concepts, self.args.cbm, self.args.senn, epoch + h_loss, hh_labeled = ( + self.concept_learning_loss_for_weak_supervision( + inputs, + all_losses, + concepts, + self.args.cbm, + self.args.senn, + epoch, + ) ) # add entropy on concepts @@ -2985,7 +3370,9 @@ def single_concept_entropy(p): # Compute adversarial loss out_dict_adversarial = self.model(adversarial_batch) - loss_adversarial = self.prediction_criterion(out_dict_adversarial, targets) + loss_adversarial = self.prediction_criterion( + out_dict_adversarial, targets + ) # Minimize the combined loss l(θm, xbatch, ybatch) + l(θm, advbatch, advbatch) w.r.t. 
θm loss_adversarial.backward() @@ -3001,26 +3388,37 @@ def single_concept_entropy(p): # Adding entropy if model_idx == 0: - loss = loss + (1 - entropy(model_itself_pc_x)) * lambda_h - print("\tEntropy loss:", (1 - entropy(model_itself_pc_x)).item()) + loss = ( + loss + (1 - entropy(model_itself_pc_x)) * lambda_h + ) + print( + "\tEntropy loss:", + (1 - entropy(model_itself_pc_x)).item(), + ) all_losses.update( - {f"model {model_idx} all loss": loss.cpu().data.numpy()} + { + f"model {model_idx} all loss": loss.cpu().data.numpy() + } ) all_losses.update( - {f"entropy {model_idx} loss": e_pc.cpu().data.numpy()} + { + f"entropy {model_idx} loss": e_pc.cpu().data.numpy() + } ) # usual update for the deep diversification thing if model_idx > 0: if self.knowledge_aware_kl: - model_itself_pfs_x, model_itself_pleft_x, model_itself_right_x = ( - self.compute_pw_knowledge_filter( - model_itself_pc_x=model_itself_pc_x, - labels=targets, - wfs=self.FS_w_q, - wl=self.L_w_q, - wR=self.R_w_q, - ) + ( + model_itself_pfs_x, + model_itself_pleft_x, + model_itself_right_x, + ) = self.compute_pw_knowledge_filter( + model_itself_pc_x=model_itself_pc_x, + labels=targets, + wfs=self.FS_w_q, + wl=self.L_w_q, + wR=self.R_w_q, ) ( @@ -3063,21 +3461,29 @@ def single_concept_entropy(p): # mean forward stop pf_list_ensemble = torch.stack(pf_list_ensemble) - other_pf_mean = torch.mean(pf_list_ensemble, dim=0) # .unsqueeze(0) + other_pf_mean = torch.mean( + pf_list_ensemble, dim=0 + ) # .unsqueeze(0) - pstop_list_ensemble = torch.stack(pstop_list_ensemble) + pstop_list_ensemble = torch.stack( + pstop_list_ensemble + ) other_pstop_mean = torch.mean( pstop_list_ensemble, dim=0 ) # .unsqueeze(0) # mean left - pleftx_list_ensemble = torch.stack(pleftx_list_ensemble) + pleftx_list_ensemble = torch.stack( + pleftx_list_ensemble + ) other_pleftx_mean = torch.mean( pleftx_list_ensemble, dim=0 ) # .unsqueeze(0) # mean right - prightx_list_ensemble = torch.stack(prightx_list_ensemble) + prightx_list_ensemble = torch.stack( + prightx_list_ensemble + ) other_prightx_mean = torch.mean( prightx_list_ensemble, dim=0 ) # .unsqueeze(0) @@ -3108,11 +3514,19 @@ def single_concept_entropy(p): False, ) kl = kl_forward + kl_stop + kl_left + kl_right - entropy_pc_x = entropy(model_itself_pc_x, True) + entropy_pc_x = entropy( + model_itself_pc_x, True + ) entropy_f = entropy(model_itself_pf_x, True) - entropy_s = entropy(model_itself_pstop_x, True) - entropy_l = entropy(model_itself_pleft_x, True) - entropy_r = entropy(model_itself_right_x, True) + entropy_s = entropy( + model_itself_pstop_x, True + ) + entropy_l = entropy( + model_itself_pleft_x, True + ) + entropy_r = entropy( + model_itself_right_x, True + ) distance = ( kl_paper( @@ -3152,7 +3566,10 @@ def single_concept_entropy(p): ) distance = distance / (math.log(2) * 21) + 1 distance = lambda_kl * distance - distance = distance + (1 - entropy(model_itself_pc_x)) * lambda_h + distance = ( + distance + + (1 - entropy(model_itself_pc_x)) * lambda_h + ) print("kl", kl.item()) print("entropy p(c|x)", entropy_pc_x.item()) @@ -3162,16 +3579,24 @@ def single_concept_entropy(p): print("entropy r", entropy_r.item()) all_losses.update( - {f"model {model_idx} kl_f (v)": kl_forward.cpu().data.numpy()} + { + f"model {model_idx} kl_f (v)": kl_forward.cpu().data.numpy() + } ) all_losses.update( - {f"model {model_idx} kl_s (v)": kl_stop.cpu().data.numpy()} + { + f"model {model_idx} kl_s (v)": kl_stop.cpu().data.numpy() + } ) all_losses.update( - {f"model {model_idx} kl_left (v)": 
kl_left.cpu().data.numpy()} + { + f"model {model_idx} kl_left (v)": kl_left.cpu().data.numpy() + } ) all_losses.update( - {f"model {model_idx} kl_right (v)": kl_right.cpu().data.numpy()} + { + f"model {model_idx} kl_right (v)": kl_right.cpu().data.numpy() + } ) all_losses.update( { @@ -3243,8 +3668,12 @@ def single_concept_entropy(p): pcx_list_ensemble.append(other_model_pc_x) - pcx_list_ensemble = torch.stack(pcx_list_ensemble, dim=0) - pcx_list_ensemble = torch.mean(pcx_list_ensemble, dim=0) + pcx_list_ensemble = torch.stack( + pcx_list_ensemble, dim=0 + ) + pcx_list_ensemble = torch.mean( + pcx_list_ensemble, dim=0 + ) """ distance = - lambda_h * torch.mean( @@ -3259,9 +3688,14 @@ def single_concept_entropy(p): model_itself_pc_x, pcx_list_ensemble ) """ - print(model_itself_pc_x.shape, pcx_list_ensemble.shape) + print( + model_itself_pc_x.shape, + pcx_list_ensemble.shape, + ) distance = lambda_h * kl_paper( - model_itself_pc_x, pcx_list_ensemble, len(self.ensemble) + 1 + model_itself_pc_x, + pcx_list_ensemble, + len(self.ensemble) + 1, ) loss += distance @@ -3300,7 +3734,9 @@ def train_epoch_single_model( self.model.train() end = time.time() - for i, (inputs, targets, concepts) in enumerate(train_loader, 0): + for i, (inputs, targets, concepts) in enumerate( + train_loader, 0 + ): # measure data loading time data_time.update(time.time() - end) @@ -3338,16 +3774,24 @@ def train_epoch_single_model( # measure accuracy and record loss if self.nclasses > 4: # mainly use this line (current) - prec1, _ = self.accuracy(outputs.data, targets.data, topk=(1, 5)) + prec1, _ = self.accuracy( + outputs.data, targets.data, topk=(1, 5) + ) elif self.nclasses in [3, 4]: prec1, _ = self.accuracy( - outputs.data, targets.data, topk=(1, self.nclasses) + outputs.data, + targets.data, + topk=(1, self.nclasses), ) else: - prec1, _ = self.binary_accuracy(outputs.data, targets.data), [100] + prec1, _ = self.binary_accuracy( + outputs.data, targets.data + ), [100] # update each value of fprint's values - losses.update(loss.data.cpu().numpy(), pretrained_out.size(0)) + losses.update( + loss.data.cpu().numpy(), pretrained_out.size(0) + ) top1.update(prec1[0], pretrained_out.size(0)) if not self.args.senn: @@ -3414,15 +3858,26 @@ def _prepare_model(self): sys.exit(1) parametrizer1 = dfc_parametrizer( - 2048, 1024, 512, 256, 128, self.nconcepts, self.theta_dim, layers=4 + 2048, + 1024, + 512, + 256, + 128, + self.nconcepts, + self.theta_dim, + layers=4, ) if self.cbm == True: aggregator = CBM_aggregator( - self.concept_dim, self.nclasses, self.nconcepts_labeled + self.concept_dim, + self.nclasses, + self.nconcepts_labeled, ) else: - aggregator = additive_scalar_aggregator(self.concept_dim, self.nclasses) + aggregator = additive_scalar_aggregator( + self.concept_dim, self.nclasses + ) if self.model_name == "dpl": model = DPL( @@ -3453,7 +3908,9 @@ def _prepare_model(self): self.model.parameters(), lr=self.lr, betas=optim_betas ) elif self.opt == "rmsprop": - self.optimizer = optim.RMSprop(self.model.parameters(), lr=self.lr) + self.optimizer = optim.RMSprop( + self.model.parameters(), lr=self.lr + ) elif self.opt == "sgd": self.optimizer = optim.SGD( self.model.parameters(), @@ -3486,9 +3943,12 @@ def _ensemble_single_predict( model.return_both_concept_out_prob = False # get out the different output - label_prob = [lab.detach().cpu().numpy() for lab, _ in output_list] # 30 + label_prob = [ + lab.detach().cpu().numpy() for lab, _ in output_list + ] # 30 concept_prob = [ - concept.detach().cpu().numpy() for _, concept 
in output_list + concept.detach().cpu().numpy() + for _, concept in output_list ] # 30 label_prob = np.stack(label_prob, axis=0) @@ -3513,8 +3973,10 @@ def ensemble_predict(self, loader, apply_softmax=False): ) # Call Ensemble predict - (label_prob_ens, concept_prob_ens) = self._ensemble_single_predict( - self.ensemble, images, apply_softmax + (label_prob_ens, concept_prob_ens) = ( + self._ensemble_single_predict( + self.ensemble, images, apply_softmax + ) ) # Concatenate the output @@ -3524,12 +3986,18 @@ def ensemble_predict(self, loader, apply_softmax=False): y_pred = label_prob_ens pc_pred = concept_prob_ens else: - y_true = np.concatenate([y_true, labels.detach().cpu().numpy()], axis=0) + y_true = np.concatenate( + [y_true, labels.detach().cpu().numpy()], axis=0 + ) c_true = np.concatenate( [c_true, concepts.detach().cpu().numpy()], axis=0 ) - y_pred = np.concatenate([y_pred, label_prob_ens], axis=1) - pc_pred = np.concatenate([pc_pred, concept_prob_ens], axis=1) + y_pred = np.concatenate( + [y_pred, label_prob_ens], axis=1 + ) + pc_pred = np.concatenate( + [pc_pred, concept_prob_ens], axis=1 + ) return y_true, c_true, y_pred, pc_pred @@ -3543,12 +4011,18 @@ def worlds_probability( apply_softmax=False, ): - y_true, c_true, y_pred, pc_pred = self.ensemble_predict(loader, apply_softmax) + y_true, c_true, y_pred, pc_pred = self.ensemble_predict( + loader, apply_softmax + ) - fstop_prob = compute_forward_stop_prob(pc_pred) # data, possibleworlds + fstop_prob = compute_forward_stop_prob( + pc_pred + ) # data, possibleworlds left_prob = compute_left(pc_pred) # data, possibleworlds right_prob = compute_right(pc_pred) # data, possibleworlds - y_prob = compute_output_probability(y_pred) # data, possibleworlds + y_prob = compute_output_probability( + y_pred + ) # data, possibleworlds w_probs = [fstop_prob, left_prob, right_prob] @@ -3556,9 +4030,13 @@ def worlds_probability( w_predictions_prob_value = [] for prob in w_probs: w_predictions.append(np.argmax(prob, axis=-1)) # data, 1 - w_predictions_prob_value.append(np.max(prob, axis=-1)) # data, 1 + w_predictions_prob_value.append( + np.max(prob, axis=-1) + ) # data, 1 - fstop_ground = compute_forward_stop_groundtruth(c_true) # data, 1 + fstop_ground = compute_forward_stop_groundtruth( + c_true + ) # data, 1 left_ground = compute_left_groundtruth(c_true) # data, 1 right_ground = compute_right_groundtruth(c_true) # data, 1 @@ -3585,7 +4063,9 @@ def worlds_probability( y_true = np.concatenate(y_trues[:4], axis=-1) y_predictions = np.concatenate(y_preds_list[:4], axis=-1) - y_predictions_prob = np.concatenate(y_preds_prob_list[:4], axis=-1) + y_predictions_prob = np.concatenate( + y_preds_prob_list[:4], axis=-1 + ) pc_pred_mean = np.mean(pc_pred, axis=0) pc_prob = pc_pred_mean @@ -3649,9 +4129,13 @@ def p_c_x_distance( if i == 0: c_prb = concept_prob else: - c_prb = np.concatenate([c_prb, concept_prob], axis=1) + c_prb = np.concatenate( + [c_prb, concept_prob], axis=1 + ) - fstop_prob = compute_forward_stop_prob(c_prb) # data, possibleworlds + fstop_prob = compute_forward_stop_prob( + c_prb + ) # data, possibleworlds left_prob = compute_left(c_prb) # data, possibleworlds right_prob = compute_right(c_prb) # data, possibleworlds @@ -3675,7 +4159,9 @@ def get_ensemble_from_bayes(self, n_ensemble): # Mein Freund ClassificationTesterFactory class ClassificationTesterFactory: @staticmethod - def get_model(name: str, model, args, device) -> ClassificationTester: + def get_model( + name: str, model, args, device + ) -> ClassificationTester: if name == 
"frequentist": return Frequentist(model, args, device) elif name == "mcdropout": @@ -3687,4 +4173,6 @@ def get_model(name: str, model, args, device) -> ClassificationTester: elif name == "resense": return DeepEnsembles(model, args, device, name) else: - raise ValueError("The chosen model is not valid: chosen", name) + raise ValueError( + "The chosen model is not valid: chosen", name + ) diff --git a/BDD_OIA/track_stuff.py b/BDD_OIA/track_stuff.py index 753d65f..b62d3c7 100644 --- a/BDD_OIA/track_stuff.py +++ b/BDD_OIA/track_stuff.py @@ -1,18 +1,17 @@ -import numpy as np -import pandas as pd - -from sklearn.metrics import f1_score -from sklearn.metrics import confusion_matrix - -import matplotlib.pyplot as plt import itertools -import torch import math +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import torch from scipy import stats +from sklearn.metrics import confusion_matrix, f1_score -def get_dfs_merged(seeds, category, lambda_h, lambda_kl, train="train", pcbm=False): +def get_dfs_merged( + seeds, category, lambda_h, lambda_kl, train="train", pcbm=False +): dfs = [] for seed in seeds: if pcbm: @@ -35,7 +34,9 @@ def get_dfs_merged(seeds, category, lambda_h, lambda_kl, train="train", pcbm=Fal return dfs -def get_dfs_single(seed, n_model, category, lambda_h, lambda_kl, train="train"): +def get_dfs_single( + seed, n_model, category, lambda_h, lambda_kl, train="train" +): df = pd.read_csv( f"out/bdd/dpl_auc-{seed}/{train}_results_of_BDD_n_mod_{n_model}_{seed}_{category}_{lambda_h}_{lambda_kl}_real_kl.csv", header=None, @@ -68,11 +69,19 @@ def compute_f1(df): all_y_true.append(y_true[:, i]) all_y_pred.append(y_pred[:, i]) preds.append(y_pred[:, i].reshape(-1, 1)) - f1_value = f1_score(y_true[:, i], y_pred[:, i], average="macro") + f1_value = f1_score( + y_true[:, i], y_pred[:, i], average="macro" + ) f1_y.append(f1_value) - y_cfs.append(confusion_matrix(y_true[:, i], y_pred[:, i], normalize="true")) + y_cfs.append( + confusion_matrix( + y_true[:, i], y_pred[:, i], normalize="true" + ) + ) to_rtn[f"F1 of Label {i}"] = f1_value - ece_value = produce_ece_curve(y_prob[:, i], y_pred[:, i], y_true[:, i]) + ece_value = produce_ece_curve( + y_prob[:, i], y_pred[:, i], y_true[:, i] + ) ece_y.append(ece_value) to_rtn[f"ECE of label {i}"] = ece_value @@ -84,14 +93,22 @@ def compute_f1(df): for i in range(21): all_c_true.append(c_true[:, i]) all_c_pred.append(np.round(prob_C[:, i])) - f1_value = f1_score(c_true[:, i], np.round(prob_C[:, i]), average="macro") + f1_value = f1_score( + c_true[:, i], np.round(prob_C[:, i]), average="macro" + ) f1_c.append(f1_value) - c_pred.append(np.round(prob_C[:, i]).astype(int).reshape(-1, 1)) + c_pred.append( + np.round(prob_C[:, i]).astype(int).reshape(-1, 1) + ) cfs.append( - confusion_matrix(c_true[:, i], np.round(prob_C[:, i]), normalize="true") + confusion_matrix( + c_true[:, i], np.round(prob_C[:, i]), normalize="true" + ) ) ece_value = produce_ece_curve( - prob_C[:, i], np.round(prob_C[:, i]).astype(int), c_true[:, i].astype(float) + prob_C[:, i], + np.round(prob_C[:, i]).astype(int), + c_true[:, i].astype(float), ) ece_c.append(ece_value) to_rtn[f"ECE of concept {i}"] = ece_value @@ -107,8 +124,12 @@ def compute_f1(df): all_c_true = np.concatenate(all_c_true, axis=0) all_c_pred = np.concatenate(all_c_pred, axis=0) - to_rtn["F1 all labels"] = f1_score(all_y_true, all_y_pred, average="macro") - to_rtn["F1 all concepts"] = f1_score(all_c_true, all_c_pred, average="macro") + to_rtn["F1 all labels"] = f1_score( + all_y_true, all_y_pred, 
average="macro" + ) + to_rtn["F1 all concepts"] = f1_score( + all_c_true, all_c_pred, average="macro" + ) return to_rtn @@ -145,7 +166,13 @@ def merge_dict(d1, d2): def _bin_initializer(num_bins: int): # Builds the bin return { - i: {"COUNT": 0, "CONF": 0, "ACC": 0, "BIN_ACC": 0, "BIN_CONF": 0} + i: { + "COUNT": 0, + "CONF": 0, + "ACC": 0, + "BIN_ACC": 0, + "BIN_CONF": 0, + } for i in range(num_bins) } @@ -154,7 +181,9 @@ def _populate_bins(confs, preds, labels, num_bins: int): # initializes n bins (a bin contains probability from x to x + smth (where smth is greater than zero)) bin_dict = _bin_initializer(num_bins) - for i, (confidence, prediction, label) in enumerate(zip(confs, preds, labels)): + for i, (confidence, prediction, label) in enumerate( + zip(confs, preds, labels) + ): try: binn = int(math.ceil(num_bins * confidence - 1)) except: @@ -177,8 +206,12 @@ def _populate_bins(confs, preds, labels, num_bins: int): return bin_dict -def expected_calibration_error(confs, preds, labels, num_bins: int = 10): - bin_dict = _populate_bins(confs, preds, labels, num_bins) # populate the bins +def expected_calibration_error( + confs, preds, labels, num_bins: int = 10 +): + bin_dict = _populate_bins( + confs, preds, labels, num_bins + ) # populate the bins num_samples = len(labels) # number of samples (n) ece = sum( (bin_info["BIN_ACC"] - bin_info["BIN_CONF"]).__abs__() @@ -196,7 +229,9 @@ def produce_ece_curve(p, pred, true, multilabel: bool = False): ece_data = list() for i in range(p.shape[1]): ece_data.append( - expected_calibration_error(p[:, i], pred[:, i], true[:, i])[0] + expected_calibration_error( + p[:, i], pred[:, i], true[:, i] + )[0] ) return np.mean(np.asarray(ece_data), axis=0) else: @@ -223,7 +258,10 @@ def convert_to_json_serializable(obj): elif isinstance(obj, (list, tuple)): return [convert_to_json_serializable(item) for item in obj] elif isinstance(obj, dict): - return {key: convert_to_json_serializable(value) for key, value in obj.items()} + return { + key: convert_to_json_serializable(value) + for key, value in obj.items() + } elif isinstance(obj, (int, float, bool, str, type(None))): return obj else: @@ -275,14 +313,20 @@ def get_stat( for i in range(21): keys.append(f"ECE of concept {i}") - df_bir_full = get_dfs_merged([seed], category, lambda_h, lambda_kl, train=train) + df_bir_full = get_dfs_merged( + [seed], category, lambda_h, lambda_kl, train=train + ) df_bir_list = [get_f1_per_dict(df_bir_full[0], is_list=False)] - bir_dict = filter_dict(to_dict(pd.concat(df_bir_list, ignore_index=True)), keys) + bir_dict = filter_dict( + to_dict(pd.concat(df_bir_list, ignore_index=True)), keys + ) bir_full_to_log = {} for key, value in bir_dict.items(): bir_full_to_log[f"Factorized_on_{train}_{key}"] = value - bir_full_to_log = from_list_to_value(convert_to_json_serializable(bir_full_to_log)) + bir_full_to_log = from_list_to_value( + convert_to_json_serializable(bir_full_to_log) + ) print(bir_full_to_log) if set_wandb: @@ -291,15 +335,20 @@ def get_stat( try: for i in range(n_model): - df_bir_f = get_dfs_single(seed, i, category, lambda_h, lambda_kl, train) + df_bir_f = get_dfs_single( + seed, i, category, lambda_h, lambda_kl, train + ) df_bir_list = [get_f1_per_dict(df_bir_f, is_list=False)] bir_dict = filter_dict( - to_dict(pd.concat(df_bir_list, ignore_index=True)), keys + to_dict(pd.concat(df_bir_list, ignore_index=True)), + keys, ) bir_full_to_log = {} for key, value in bir_dict.items(): - bir_full_to_log[f"Single_n_{i}_on_{train}_{key}"] = value + 
bir_full_to_log[f"Single_n_{i}_on_{train}_{key}"] = ( + value + ) bir_full_to_log = from_list_to_value( convert_to_json_serializable(bir_full_to_log) ) diff --git a/BDD_OIA/trainers_BDD.py b/BDD_OIA/trainers_BDD.py index ed90cc5..1928b39 100644 --- a/BDD_OIA/trainers_BDD.py +++ b/BDD_OIA/trainers_BDD.py @@ -4,24 +4,26 @@ We modified so as to fit the semi-supervised fashion. """ -# standard imports -import sys import os -import tqdm -import time import pdb import shutil -import torch + +# standard imports +import sys +import time + +import matplotlib.pyplot as plt import numpy as np -from torch.autograd import Variable +import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim -import matplotlib.pyplot as plt +import tqdm import wandb # Local imports from SENN.utils import AverageMeter +from torch.autograd import Variable # =============================================================================== # ==================== REGULARIZER UTILITIES ============================ @@ -62,7 +64,10 @@ def save_checkpoint(state, is_best, outpath, seed): filename = os.path.join(outpath, f"checkpoint-{seed}.pth.tar") torch.save(state, filename) if is_best: - shutil.copyfile(filename, os.path.join(outpath, f"model_best-{seed}.pth.tar")) + shutil.copyfile( + filename, + os.path.join(outpath, f"model_best-{seed}.pth.tar"), + ) """ @@ -148,12 +153,16 @@ def train( self.train_epoch(epoch, train_loader, val_loader, pcbm) if self.args.wandb is not None: - wandb.log({"lr": float(self.scheduler.get_last_lr()[0])}) + wandb.log( + {"lr": float(self.scheduler.get_last_lr()[0])} + ) # # validate evaluation if val_loader is not None: val_prec1 = 1 - val_prec1, last_loss = self.validate(val_loader, epoch + 1, pcbm) + val_prec1, last_loss = self.validate( + val_loader, epoch + 1, pcbm + ) # if self.args.wandb is not None: # wandb.log( {'val-loss': last_loss} ) @@ -221,7 +230,9 @@ def concept_learning_loss_for_weak_supervision( hh_labeled_list = self.model(inputs) self.model.ignore_prob_log = False else: - hh_labeled_list, h_x, real = self.model.conceptizer(inputs) + hh_labeled_list, h_x, real = self.model.conceptizer( + inputs + ) concepts = concepts.to(self.device) if not senn: @@ -242,9 +253,12 @@ def concept_learning_loss_for_weak_supervision( hh_labeled_list[0], concepts[0].to(self.device) ) for j in range(1, len(hh_labeled_list)): - labeled_loss = labeled_loss + F.binary_cross_entropy( - hh_labeled_list[j], - concepts[j].to(self.device), + labeled_loss = ( + labeled_loss + + F.binary_cross_entropy( + hh_labeled_list[j], + concepts[j].to(self.device), + ) ) # labeled_loss = labeled_loss + torch.nn.BCELoss() F.binary_cross_entropy( @@ -257,7 +271,8 @@ def concept_learning_loss_for_weak_supervision( if i in self.args.which_c: labeled_loss = ( F.binary_cross_entropy( - hh_labeled_list[0, i], concepts[0, i].to(self.device) + hh_labeled_list[0, i], + concepts[0, i].to(self.device), ) / L ) @@ -284,7 +299,9 @@ def concept_learning_loss_for_weak_supervision( if not senn: # save loss (only value) to the all_losses list - all_losses["labeled_h"] = labeled_loss.data.cpu().numpy() * self.eta + all_losses["labeled_h"] = ( + labeled_loss.data.cpu().numpy() * self.eta + ) # use in def train_batch (class GradPenaltyTrainer) return info_loss, hh_labeled_list @@ -295,14 +312,14 @@ def entropy_loss(self, pred_c, all_losses, epoch): # real uses the discriminator's loss avg_c = torch.mean(pred_c, dim=0) - total_ent = -avg_c[0] * torch.log(avg_c[0]) - (1 - avg_c[0]) * torch.log( + total_ent = 
-avg_c[0] * torch.log(avg_c[0]) - ( 1 - avg_c[0] - ) + ) * torch.log(1 - avg_c[0]) total_ent /= np.log(2) for i in range(1, 21): - ent_i = -avg_c[i] * torch.log(avg_c[i]) - (1 - avg_c[i]) * torch.log( + ent_i = -avg_c[i] * torch.log(avg_c[i]) - ( 1 - avg_c[i] - ) + ) * torch.log(1 - avg_c[i]) ent_i /= np.log(2) assert ent_i <= 1 and ent_i >= 0, (ent_i, avg_c[i]) @@ -330,7 +347,9 @@ def train_epoch: print errors, losses of each epoch """ - def train_epoch(self, epoch, train_loader, val_loader=None, pcbm=False): + def train_epoch( + self, epoch, train_loader, val_loader=None, pcbm=False + ): # initialization of print's values batch_time = AverageMeter() @@ -344,7 +363,9 @@ def train_epoch(self, epoch, train_loader, val_loader=None, pcbm=False): end = time.time() - for i, (inputs, targets, concepts) in enumerate(train_loader, 0): + for i, (inputs, targets, concepts) in enumerate( + train_loader, 0 + ): # measure data loading time data_time.update(time.time() - end) @@ -356,8 +377,10 @@ def train_epoch(self, epoch, train_loader, val_loader=None, pcbm=False): targets = targets.cuda(self.device) # go to def train_batch (class GradPenaltyTrainer) - outputs, loss, loss_dict, hh_labeled, pretrained_out = self.train_batch( - inputs, targets, concepts, epoch, pcbm + outputs, loss, loss_dict, hh_labeled, pretrained_out = ( + self.train_batch( + inputs, targets, concepts, epoch, pcbm + ) ) if self.args.wandb is not None: @@ -371,16 +394,24 @@ def train_epoch(self, epoch, train_loader, val_loader=None, pcbm=False): # measure accuracy and record loss if self.nclasses > 4: # mainly use this line (current) - prec1, _ = self.accuracy(outputs.data, targets.data, topk=(1, 5)) + prec1, _ = self.accuracy( + outputs.data, targets.data, topk=(1, 5) + ) elif self.nclasses in [3, 4]: prec1, _ = self.accuracy( - outputs.data, targets.data, topk=(1, self.nclasses) + outputs.data, + targets.data, + topk=(1, self.nclasses), ) else: - prec1, _ = self.binary_accuracy(outputs.data, targets.data), [100] + prec1, _ = self.binary_accuracy( + outputs.data, targets.data + ), [100] # update each value of print's values - losses.update(loss.data.cpu().numpy(), pretrained_out.size(0)) + losses.update( + loss.data.cpu().numpy(), pretrained_out.size(0) + ) top1.update(prec1[0], pretrained_out.size(0)) if not self.args.senn: @@ -497,8 +528,16 @@ def validate(self, val_loader, epoch, fold=None, pcbm=False): all_losses = {"prediction": pred_loss.cpu().data.numpy()} # compute loss of known concets and discriminator - h_loss, hh_labeled = self.concept_learning_loss_for_weak_supervision( - inputs, all_losses, concepts, self.args.cbm, self.args.senn, epoch, pcbm + h_loss, hh_labeled = ( + self.concept_learning_loss_for_weak_supervision( + inputs, + all_losses, + concepts, + self.args.cbm, + self.args.senn, + epoch, + pcbm, + ) ) loss_h += self.entropy_loss( @@ -515,14 +554,22 @@ def validate(self, val_loader, epoch, fold=None, pcbm=False): # measure accuracy and record loss if self.nclasses > 4: # mainly use this line (current) - prec1, _ = self.accuracy(output.data, targets, topk=(1, 5)) + prec1, _ = self.accuracy( + output.data, targets, topk=(1, 5) + ) elif self.nclasses == 3: - prec1, _ = self.accuracy(output.data, targets, topk=(1, 3)) + prec1, _ = self.accuracy( + output.data, targets, topk=(1, 3) + ) else: - prec1, _ = self.binary_accuracy(output.data, targets), [100] + prec1, _ = self.binary_accuracy( + output.data, targets + ), [100] # update each value of print's values - losses.update(pred_loss.data.cpu().numpy(), 
inputs.size(0)) + losses.update( + pred_loss.data.cpu().numpy(), inputs.size(0) + ) top1.update(prec1[0], inputs.size(0)) # measure accuracy of concepts @@ -546,7 +593,10 @@ def validate(self, val_loader, epoch, fold=None, pcbm=False): "Val: [{0}/{1}]\t" "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t" "Loss {loss.val:.4f} ({loss.avg:.4f})\t".format( - i, len(val_loader), batch_time=batch_time, loss=losses + i, + len(val_loader), + batch_time=batch_time, + loss=losses, ) ) else: @@ -556,7 +606,10 @@ def validate(self, val_loader, epoch, fold=None, pcbm=False): "Val: [{0}/{1}]\t" "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t" "Loss {loss.val:.4f} ({loss.avg:.4f})\t".format( - i, len(val_loader), batch_time=batch_time, loss=losses + i, + len(val_loader), + batch_time=batch_time, + loss=losses, ) ) val_loss_dict = { @@ -590,7 +643,9 @@ def test_and_save: NOTE: many code is the same to def train_epoch """ - def test_and_save_csv(self, test_loader, save_file_name, fold=None, pcbm=False): + def test_and_save_csv( + self, test_loader, save_file_name, fold=None, pcbm=False + ): print("Saving CSV..", save_file_name, "am I pcbm?", pcbm) @@ -624,11 +679,17 @@ def test_and_save_csv(self, test_loader, save_file_name, fold=None, pcbm=False): # measure accuracy and record loss if self.nclasses > 4: # mainly use this line (current) - prec1, _ = self.accuracy(output.data, targets, topk=(1, 5)) + prec1, _ = self.accuracy( + output.data, targets, topk=(1, 5) + ) elif self.nclasses == 3: - prec1, _ = self.accuracy(output.data, targets, topk=(1, 3)) + prec1, _ = self.accuracy( + output.data, targets, topk=(1, 3) + ) else: - prec1, _ = self.binary_accuracy(output.data, targets), [100] + prec1, _ = self.binary_accuracy( + output.data, targets + ), [100] # update each value of print's values losses.update(loss.data.cpu().numpy(), inputs.size(0)) @@ -689,7 +750,10 @@ def test_and_save_csv(self, test_loader, save_file_name, fold=None, pcbm=False): "Test on " + fold + ": [{0}/{1}]\t" "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t" "Loss {loss.val:.4f} ({loss.avg:.4f})\t".format( - i, len(test_loader), batch_time=batch_time, loss=losses + i, + len(test_loader), + batch_time=batch_time, + loss=losses, ) ) @@ -725,7 +789,10 @@ def test_and_save_csv(self, test_loader, save_file_name, fold=None, pcbm=False): "Test: [{0}/{1}]\t" "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t" "Loss {loss.val:.4f} ({loss.avg:.4f})\t".format( - i, len(test_loader), batch_time=batch_time, loss=losses + i, + len(test_loader), + batch_time=batch_time, + loss=losses, ) ) @@ -744,7 +811,9 @@ def concept_error: (added by Sawada) """ def concept_error(self, output, target): - err = torch.Tensor(1).fill_((output.round().eq(target)).float().mean() * 100) + err = torch.Tensor(1).fill_( + (output.round().eq(target)).float().mean() * 100 + ) err = (100.0 - err.data[0]) / 100 return err @@ -761,7 +830,9 @@ def binary_accuracy: def binary_accuracy(self, output, target): """Computes the accuracy""" - return torch.Tensor(1).fill_((output.round().eq(target)).float().mean() * 100) + return torch.Tensor(1).fill_( + (output.round().eq(target)).float().mean() * 100 + ) """ def accuracy: @@ -788,7 +859,9 @@ def accuracy(self, output, target, topk=(1,)): # if topk = (1,5), then, k=1 and k=5 res = [] for k in topk: - correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) + correct_k = ( + correct[:k].view(-1).float().sum(0, keepdim=True) + ) res.append(correct_k.mul_(100.0 / batch_size)) return res @@ -805,14 +878,18 @@ def plot_losses: 
def plot_losses(self, save_path=None): print("Saving this...", save_path) - loss_types = [k for k in self.loss_history[0].keys() if k != "iter"] + loss_types = [ + k for k in self.loss_history[0].keys() if k != "iter" + ] losses = {k: [] for k in loss_types} iters = [] for e in self.loss_history: iters.append(e["iter"]) for k in loss_types: losses[k].append(e[k]) - fig, ax = plt.subplots(1, len(loss_types), figsize=(4 * len(loss_types), 5)) + fig, ax = plt.subplots( + 1, len(loss_types), figsize=(4 * len(loss_types), 5) + ) if len(loss_types) == 1: ax = [ax] # Hacky, fix for i, k in enumerate(loss_types): @@ -830,14 +907,18 @@ def plot_losses(self, save_path=None): #### VALIDATION - loss_types = [k for k in self.val_loss_history[0].keys() if k != "iter"] + loss_types = [ + k for k in self.val_loss_history[0].keys() if k != "iter" + ] losses = {k: [] for k in loss_types} iters = [] for e in self.val_loss_history: iters.append(e["iter"]) for k in loss_types: losses[k].append(e[k]) - fig, ax = plt.subplots(1, len(loss_types), figsize=(4 * len(loss_types), 5)) + fig, ax = plt.subplots( + 1, len(loss_types), figsize=(4 * len(loss_types), 5) + ) if len(loss_types) == 1: ax = [ax] # Hacky, fix for i, k in enumerate(loss_types): @@ -890,15 +971,21 @@ def __init__(self, model, args, device): # hyparparameters used in the loss function self.lambd = ( - args.theta_reg_lambda if ("theta_reg_lambda" in args) else 1e-6 + args.theta_reg_lambda + if ("theta_reg_lambda" in args) + else 1e-6 ) # for regularization strenght self.eta = ( - args.h_labeled_param if ("h_labeled_param" in args) else 0.0 + args.h_labeled_param + if ("h_labeled_param" in args) + else 0.0 ) # for wealky supervised self.gamma = ( args.info_hypara if ("info_hypara" in args) else 0.0 ) # for wealky supervised - self.w_entropy = args.w_entropy if ("w_entropy" in args) else 0.0 + self.w_entropy = ( + args.w_entropy if ("w_entropy" in args) else 0.0 + ) print("self.eta:", self.eta) print("self.w_entropy", self.w_entropy) @@ -961,7 +1048,9 @@ def __init__(self, model, args, device): self.model.parameters(), lr=args.lr, betas=optim_betas ) elif args.opt == "rmsprop": - self.optimizer = optim.RMSprop(self.model.parameters(), lr=args.lr) + self.optimizer = optim.RMSprop( + self.model.parameters(), lr=args.lr + ) elif args.opt == "sgd": self.optimizer = optim.SGD( self.model.parameters(), @@ -1012,8 +1101,16 @@ def train_batch(self, inputs, targets, concepts, epoch, pcbm): all_losses = {"prediction": pred_loss.cpu().data.numpy()} # compute loss of known concets and discriminator - h_loss, hh_labeled = self.concept_learning_loss_for_weak_supervision( - inputs, all_losses, concepts, self.args.cbm, self.args.senn, epoch, pcbm + h_loss, hh_labeled = ( + self.concept_learning_loss_for_weak_supervision( + inputs, + all_losses, + concepts, + self.args.cbm, + self.args.senn, + epoch, + pcbm, + ) ) # add entropy on concepts @@ -1025,7 +1122,9 @@ def train_batch(self, inputs, targets, concepts, epoch, pcbm): if pcbm: def kl_divergence(mu, logsigma, reduction="sum"): - kl = -0.5 * (1 + logsigma - mu.pow(2) - logsigma.exp()) + kl = -0.5 * ( + 1 + logsigma - mu.pow(2) - logsigma.exp() + ) return kl.mean() self.model.gaussian_vars = True @@ -1071,7 +1170,9 @@ def BCE_forloop(self, tar, pred): loss = F.binary_cross_entropy(tar[0, :4], pred[0, :4]) for i in range(1, len(tar)): - loss = loss + F.binary_cross_entropy(tar[i, :4], pred[i, :4]) + loss = loss + F.binary_cross_entropy( + tar[i, :4], pred[i, :4] + ) return loss # return loss /4 diff --git 
a/BDD_OIA/visualization.py b/BDD_OIA/visualization.py index 1fc55c9..3ab33c8 100644 --- a/BDD_OIA/visualization.py +++ b/BDD_OIA/visualization.py @@ -1,9 +1,9 @@ -from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay -from numpy import ndarray -from typing import List, Tuple, Dict +from typing import Dict, List, Tuple -import numpy as np import matplotlib.pyplot as plt +import numpy as np +from numpy import ndarray +from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix OUTPUT_FOLDER = "./plots" @@ -17,9 +17,13 @@ def create_output_folder() -> None: def convert_concepts_to_indexes(values, reference_array): # Create a mapping dictionary from values to indices in the second array - value_to_index = {value: index for index, value in enumerate(reference_array)} + value_to_index = { + value: index for index, value in enumerate(reference_array) + } # Map the elements of the first array to indices from the second array - mapped_indices = np.array([value_to_index[value] for value in values]) + mapped_indices = np.array( + [value_to_index[value] for value in values] + ) return mapped_indices @@ -42,7 +46,9 @@ def produce_confusion_matrix( ground_truth = y_true.astype(str) predictions = y_pred.astype(str) - all_labels = np.union1d(np.unique(ground_truth), np.unique(predictions)) + all_labels = np.union1d( + np.unique(ground_truth), np.unique(predictions) + ) label_encoder.fit(all_labels) ground_truth = label_encoder.transform(ground_truth) @@ -74,12 +80,20 @@ def produce_confusion_matrix( x_labels = ax.get_xticklabels() tick_positions = np.arange(len(x_labels)) - tick_labels = [x_labels[i] if i % ntimes == 0 else "" for i in range(len(x_labels))] - ax.set_xticks(tick_positions, tick_labels, rotation=90, fontsize=10) + tick_labels = [ + x_labels[i] if i % ntimes == 0 else "" + for i in range(len(x_labels)) + ] + ax.set_xticks( + tick_positions, tick_labels, rotation=90, fontsize=10 + ) y_labels = ax.get_yticklabels() tick_positions = np.arange(len(y_labels)) - tick_labels = [y_labels[i] if i % ntimes == 0 else "" for i in range(len(y_labels))] + tick_labels = [ + y_labels[i] if i % ntimes == 0 else "" + for i in range(len(y_labels)) + ] ax.set_yticks(tick_positions, tick_labels, fontsize=10) # Set title and color bar label @@ -111,7 +125,9 @@ def produce_world_probability_table( ax.axis("off") # Create the table with blue color - table_data = [[key, f"{value:.2f}"] for key, value in zip(keys, values)] + table_data = [ + [key, f"{value:.2f}"] for key, value in zip(keys, values) + ] table = ax.table( cellText=table_data, colLabels=[key_string, key_value], @@ -125,7 +141,9 @@ def produce_world_probability_table( table.set_fontsize(10) table.scale(1.5, 1.5) - ax.set_title(title, fontweight="bold", fontsize=16, color="#3366cc") + ax.set_title( + title, fontweight="bold", fontsize=16, color="#3366cc" + ) # Specify the file path and name where you want to save the image file_path = f"{OUTPUT_FOLDER}/mean_probability_table.png" @@ -164,7 +182,12 @@ def produce_alpha_matrix( # tick_positions = np.arange(len(concept_labels)) # tick_labels = [concept_labels[i] if i % ntimes == 0 else '' for i in range(len(concept_labels))] # plt.xticks(tick_positions, tick_labels, rotation=90, fontsize=10) - plt.xticks(np.arange(len(concept_labels)), concept_labels, rotation=90, fontsize=10) + plt.xticks( + np.arange(len(concept_labels)), + concept_labels, + rotation=90, + fontsize=10, + ) # Set y-axis ticks and labels # tick_positions = np.arange(len(keys)) @@ -188,18 +211,39 @@ def 
produce_alpha_matrix( def produce_scatter_multi_class( - x_values_list, y_values_list, labels, dataset, suffix, colors=None, markers=None + x_values_list, + y_values_list, + labels, + dataset, + suffix, + colors=None, + markers=None, ): if colors is None: - colors = ["blue", "red", "green", "orange", "purple", "brown", "pink", "gray"] + colors = [ + "blue", + "red", + "green", + "orange", + "purple", + "brown", + "pink", + "gray", + ] if markers is None: markers = ["o", "s", "^", "D", "v", ">", "<", "p"] # Create a scatter plot for each class - for i, (x_values, y_values) in enumerate(zip(x_values_list, y_values_list)): + for i, (x_values, y_values) in enumerate( + zip(x_values_list, y_values_list) + ): plt.scatter( - [x_values], [y_values], color=colors[i], marker=markers[i], label=labels[i] + [x_values], + [y_values], + color=colors[i], + marker=markers[i], + label=labels[i], ) max_x = max(x_values_list) @@ -215,7 +259,9 @@ def produce_scatter_multi_class( plt.legend() # Save or display the plot - file_path = f"{OUTPUT_FOLDER}/{dataset}_hc_ece_scatter_plot{suffix}.png" + file_path = ( + f"{OUTPUT_FOLDER}/{dataset}_hc_ece_scatter_plot{suffix}.png" + ) plt.savefig(file_path, dpi=150) # Close @@ -242,7 +288,9 @@ def plot_grouped_entropies( index = np.arange(num_categories, dtype=float) # Adjust the index positions to create separation between groups - total_group_width = num_categories * bar_width + (num_categories - 1) * group_gap + total_group_width = ( + num_categories * bar_width + (num_categories - 1) * group_gap + ) linspace_values = np.linspace( -total_group_width / 2, total_group_width / 2, num_categories ) @@ -284,7 +332,9 @@ def plot_grouped_entropies( ax.legend() if save: - file_path = f"{OUTPUT_FOLDER}/{dataset}_{prefix}_hc_bar_plot.png" + file_path = ( + f"{OUTPUT_FOLDER}/{dataset}_{prefix}_hc_bar_plot.png" + ) fig.tight_layout() plt.savefig(file_path, dpi=150) @@ -300,15 +350,21 @@ def produce_calibration_curve( num_bins = len(bin_info) # Extract relevant information from bin_info - bin_confidence = [bin_info[i]["BIN_CONF"] for i in range(num_bins)] + bin_confidence = [ + bin_info[i]["BIN_CONF"] for i in range(num_bins) + ] bin_accuracy = [bin_info[i]["BIN_ACC"] for i in range(num_bins)] bin_counts = [bin_info[i]["COUNT"] for i in range(num_bins)] # Calculate the center of each bin - bin_centers = np.linspace(1 / (2 * num_bins), 1 - 1 / (2 * num_bins), num_bins) + bin_centers = np.linspace( + 1 / (2 * num_bins), 1 - 1 / (2 * num_bins), num_bins + ) # Create a subplot with two plots (2 rows, 1 column) - fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 10), sharex=True) + fig, (ax1, ax2) = plt.subplots( + 2, 1, figsize=(10, 10), sharex=True + ) txt = "Concept" if concept else "Label" fig.suptitle(f"{txt} Calibration Curve ECE: {ece:.2f}") @@ -326,13 +382,23 @@ def produce_calibration_curve( ) # Plot grey dashed vertical lines for weighted average confidence and accuracy - avg_confidence = np.sum(np.array(bin_confidence) * bin_counts) / np.sum(bin_counts) - avg_accuracy = np.sum(np.array(bin_accuracy) * bin_counts) / np.sum(bin_counts) + avg_confidence = np.sum( + np.array(bin_confidence) * bin_counts + ) / np.sum(bin_counts) + avg_accuracy = np.sum( + np.array(bin_accuracy) * bin_counts + ) / np.sum(bin_counts) ax1.axvline( - x=avg_confidence, color="red", linestyle="--", label="Weighted Avg. Confidence" + x=avg_confidence, + color="red", + linestyle="--", + label="Weighted Avg. Confidence", ) ax1.axvline( - x=avg_accuracy, color="black", linestyle="--", label="Weighted Avg. 
Accuracy" + x=avg_accuracy, + color="black", + linestyle="--", + label="Weighted Avg. Accuracy", ) # Customize the second plot @@ -366,7 +432,13 @@ def produce_calibration_curve( ) # Plot the ideal line (diagonal) - ax2.plot([0, 1], [0, 1], linestyle="--", color="gray", label="Perfect Calibration") + ax2.plot( + [0, 1], + [0, 1], + linestyle="--", + color="gray", + label="Perfect Calibration", + ) # Customize the first plot ax2.set_xlabel("Mean Predicted Probability (Confidence)") @@ -391,7 +463,9 @@ def produce_bar_plot( indices = np.arange(len(data)) # Create the bar plot with improved styling - plt.bar(indices, data, color="blue", edgecolor="black", linewidth=1.2) + plt.bar( + indices, data, color="blue", edgecolor="black", linewidth=1.2 + ) # Adding labels and title plt.xlabel(xlabel, fontsize=12) diff --git a/BDD_OIA/worlds_BDD.py b/BDD_OIA/worlds_BDD.py index 6b9102c..5f21f2c 100644 --- a/BDD_OIA/worlds_BDD.py +++ b/BDD_OIA/worlds_BDD.py @@ -1,6 +1,7 @@ -import numpy as np from functools import reduce from itertools import chain + +import numpy as np import torch @@ -17,13 +18,19 @@ def squeeze_last_dimensions(arr, numpy=True): def convert_np_array_to_binary(array): return np.apply_along_axis( - lambda row: int("".join(map(str, row[::-1])), 2), axis=1, arr=array.astype(int) + lambda row: int("".join(map(str, row[::-1])), 2), + axis=1, + arr=array.astype(int), ) -def compute_world_probability(concepts_list, numpy=True, clip_value=1e-5): +def compute_world_probability( + concepts_list, numpy=True, clip_value=1e-5 +): if not numpy: - world_matrix = squeeze_last_dimensions(concepts_list[0], numpy) + world_matrix = squeeze_last_dimensions( + concepts_list[0], numpy + ) for tensor in concepts_list[1:]: world_matrix = world_matrix.unsqueeze(-1) world_matrix = torch.matmul( @@ -32,7 +39,9 @@ def compute_world_probability(concepts_list, numpy=True, clip_value=1e-5): collapsed_dim = np.prod(world_matrix.shape[2:]) world_matrix = world_matrix.view( - world_matrix.shape[0], world_matrix.shape[1], collapsed_dim + world_matrix.shape[0], + world_matrix.shape[1], + collapsed_dim, ) world_matrix = torch.mean(world_matrix, dim=0) @@ -54,7 +63,12 @@ def compute_world_probability(concepts_list, numpy=True, clip_value=1e-5): collapsed_dim = np.prod(world_matrix.shape[2:]) world_matrix = np.reshape( - world_matrix, (world_matrix.shape[0], world_matrix.shape[1], collapsed_dim) + world_matrix, + ( + world_matrix.shape[0], + world_matrix.shape[1], + collapsed_dim, + ), ) # mean across the num_models @@ -67,10 +81,25 @@ def create_concepts_array(concepts_list, num_ones, i=2, numpy=True): n_models = concepts_list.shape[0] batch_size = concepts_list.shape[1] - shape = (n_models, batch_size) + ((1,) * i) + (2,) + ((1,) * (num_ones - i)) + shape = ( + (n_models, batch_size) + + ((1,) * i) + + (2,) + + ((1,) * (num_ones - i)) + ) - slice_1 = tuple([slice(None), slice(None)] + [0] * i + [0] + [0] * (num_ones - i)) - slice_2 = tuple([slice(None), slice(None)] + [0] * i + [1] + [0] * (num_ones - i)) + slice_1 = tuple( + [slice(None), slice(None)] + + [0] * i + + [0] + + [0] * (num_ones - i) + ) + slice_2 = tuple( + [slice(None), slice(None)] + + [0] * i + + [1] + + [0] * (num_ones - i) + ) c_array = np.zeros(shape) diff --git a/XOR_MNIST/backbones/__init__.py b/XOR_MNIST/backbones/__init__.py index 8f0df81..2f9b412 100644 --- a/XOR_MNIST/backbones/__init__.py +++ b/XOR_MNIST/backbones/__init__.py @@ -1,4 +1,5 @@ import math + import torch import torch.nn as nn diff --git a/XOR_MNIST/backbones/addmnist_joint.py 
b/XOR_MNIST/backbones/addmnist_joint.py index ffa12e2..2616fc1 100644 --- a/XOR_MNIST/backbones/addmnist_joint.py +++ b/XOR_MNIST/backbones/addmnist_joint.py @@ -1,12 +1,16 @@ import torch.nn -from torch import nn - from backbones.base.ops import * +from torch import nn class MNISTPairsEncoder(nn.Module): def __init__( - self, img_channels=1, hidden_channels=32, c_dim=20, latent_dim=20, dropout=0.5 + self, + img_channels=1, + hidden_channels=32, + c_dim=20, + latent_dim=20, + dropout=0.5, ): super(MNISTPairsEncoder, self).__init__() @@ -90,12 +94,22 @@ def forward(self, x): # print(x.size()) - c, mu, logvar = self.dense_c(x), self.dense_mu(x), self.dense_logvar(x) + c, mu, logvar = ( + self.dense_c(x), + self.dense_mu(x), + self.dense_logvar(x), + ) # return encodings for each object involved - c = torch.stack(torch.split(c, self.c_dim // 2, dim=-1), dim=1) - mu = torch.stack(torch.split(mu, self.latent_dim // 2, dim=-1), dim=1) - logvar = torch.stack(torch.split(logvar, self.latent_dim // 2, dim=-1), dim=1) + c = torch.stack( + torch.split(c, self.c_dim // 2, dim=-1), dim=1 + ) + mu = torch.stack( + torch.split(mu, self.latent_dim // 2, dim=-1), dim=1 + ) + logvar = torch.stack( + torch.split(logvar, self.latent_dim // 2, dim=-1), dim=1 + ) return c, mu, logvar @@ -159,7 +173,9 @@ def forward(self, x: torch.Tensor): # Unflatten Input x = self.dense(x) - x = self.unflatten(x, self.hidden_channels * 4, self.unflatten_dim) + x = self.unflatten( + x, self.hidden_channels * 4, self.unflatten_dim + ) # MNISTPairsDecoder block 1 x = self.dec_block_1(x) diff --git a/XOR_MNIST/backbones/addmnist_repeated.py b/XOR_MNIST/backbones/addmnist_repeated.py index 38cab16..9bdb94b 100644 --- a/XOR_MNIST/backbones/addmnist_repeated.py +++ b/XOR_MNIST/backbones/addmnist_repeated.py @@ -1,12 +1,16 @@ import torch.nn -from torch import nn - from backbones.base.ops import * +from torch import nn class MNISTRepeatedEncoder(nn.Module): def __init__( - self, img_channels=1, hidden_channels=32, c_dim=10, latent_dim=10, dropout=0.5 + self, + img_channels=1, + hidden_channels=32, + c_dim=10, + latent_dim=10, + dropout=0.5, ): super(MNISTRepeatedEncoder, self).__init__() @@ -156,7 +160,11 @@ def forward1(self, x): # print(x.size()) - c, mu, logvar = self.dense_c1(x), self.dense_mu1(x), self.dense_logvar1(x) + c, mu, logvar = ( + self.dense_c1(x), + self.dense_mu1(x), + self.dense_logvar1(x), + ) return c, mu, logvar @@ -182,7 +190,11 @@ def forward2(self, x): # print(x.size()) - c, mu, logvar = self.dense_c2(x), self.dense_mu2(x), self.dense_logvar2(x) + c, mu, logvar = ( + self.dense_c2(x), + self.dense_mu2(x), + self.dense_logvar2(x), + ) return c, mu, logvar @@ -260,7 +272,9 @@ def forward(self, x: torch.Tensor): # Unflatten Input x = self.dense(x) - x = self.unflatten(x, self.hidden_channels * 4, self.unflatten_dim) + x = self.unflatten( + x, self.hidden_channels * 4, self.unflatten_dim + ) # MNISTPairsDecoder block 1 x = self.dec_block_1(x) diff --git a/XOR_MNIST/backbones/addmnist_single.py b/XOR_MNIST/backbones/addmnist_single.py index 9dfbd4e..0d3fce4 100644 --- a/XOR_MNIST/backbones/addmnist_single.py +++ b/XOR_MNIST/backbones/addmnist_single.py @@ -1,11 +1,16 @@ import torch.nn -from torch import nn from backbones.base.ops import * +from torch import nn class MNISTSingleEncoder(nn.Module): def __init__( - self, img_channels=1, hidden_channels=32, c_dim=10, latent_dim=16, dropout=0.5 + self, + img_channels=1, + hidden_channels=32, + c_dim=10, + latent_dim=16, + dropout=0.5, ): super(MNISTSingleEncoder, 
self).__init__() @@ -97,11 +102,19 @@ def forward(self, x): x = self.flatten( x ) # batch_size, dim1, dim2, dim3 -> batch_size, dim1*dim2*dim3 - c, mu, logvar = self.dense_c(x), self.dense_mu(x), self.dense_logvar(x) + c, mu, logvar = ( + self.dense_c(x), + self.dense_mu(x), + self.dense_logvar(x), + ) # return encodings for each object involved c = torch.stack(torch.split(c, self.c_dim, dim=-1), dim=1) - mu = torch.stack(torch.split(mu, self.latent_dim, dim=-1), dim=1) - logvar = torch.stack(torch.split(logvar, self.latent_dim, dim=-1), dim=1) + mu = torch.stack( + torch.split(mu, self.latent_dim, dim=-1), dim=1 + ) + logvar = torch.stack( + torch.split(logvar, self.latent_dim, dim=-1), dim=1 + ) return c, mu, logvar diff --git a/XOR_MNIST/backbones/base/base_decoder.py b/XOR_MNIST/backbones/base/base_decoder.py index 05e863b..7e0fc43 100644 --- a/XOR_MNIST/backbones/base/base_decoder.py +++ b/XOR_MNIST/backbones/base/base_decoder.py @@ -13,7 +13,9 @@ def __init__(self, latent_dim, num_channels, image_size): def init_layers(self): for block in self._modules: for m in self._modules[block]: - if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): + if isinstance(m, nn.Conv2d) or isinstance( + m, nn.ConvTranspose2d + ): init.xavier_normal_(m.weight.data) if isinstance(m, nn.Linear): init.kaiming_normal_(m.weight.data) diff --git a/XOR_MNIST/backbones/base/ops.py b/XOR_MNIST/backbones/base/ops.py index f6ab87d..e87e20f 100644 --- a/XOR_MNIST/backbones/base/ops.py +++ b/XOR_MNIST/backbones/base/ops.py @@ -1,5 +1,5 @@ -from torch import nn import torch +from torch import nn class Flatten3D(nn.Module): @@ -34,4 +34,6 @@ def forward(self, input): class UnFlatten(nn.Module): def forward(self, input, hidden_channels, dim): - return input.reshape(input.size(0), hidden_channels, dim[0], dim[1]) + return input.reshape( + input.size(0), hidden_channels, dim[0], dim[1] + ) diff --git a/XOR_MNIST/backbones/disent_encoder_decoder.py b/XOR_MNIST/backbones/disent_encoder_decoder.py index 1e57b75..f8b254b 100644 --- a/XOR_MNIST/backbones/disent_encoder_decoder.py +++ b/XOR_MNIST/backbones/disent_encoder_decoder.py @@ -1,8 +1,5 @@ import torch - - -from torch import Tensor -from torch import nn +from torch import Tensor, nn class EncoderConv64(nn.Module): @@ -17,7 +14,10 @@ class EncoderConv64(nn.Module): def __init__(self, x_shape=(3, 64, 64), z_size=6, z_multiplier=1): # checks (C, H, W) = x_shape - assert (H, W) == (64, 64), "This model only works with image size 64x64." + assert (H, W) == ( + 64, + 64, + ), "This model only works with image size 64x64." 
super().__init__() self.z_size = z_size @@ -25,21 +25,37 @@ def __init__(self, x_shape=(3, 64, 64), z_size=6, z_multiplier=1): self.model = nn.Sequential( nn.Conv2d( - in_channels=C, out_channels=32, kernel_size=4, stride=2, padding=2 + in_channels=C, + out_channels=32, + kernel_size=4, + stride=2, + padding=2, ), nn.ReLU(inplace=True), nn.Conv2d( - in_channels=32, out_channels=32, kernel_size=4, stride=2, padding=2 + in_channels=32, + out_channels=32, + kernel_size=4, + stride=2, + padding=2, ), nn.ReLU(inplace=True), nn.Conv2d( - in_channels=32, out_channels=64, kernel_size=4, stride=2, padding=2 + in_channels=32, + out_channels=64, + kernel_size=4, + stride=2, + padding=2, ), nn.ReLU( inplace=True ), # This was reverted to kernel size 4x4 from 2x2, to match beta-vae paper nn.Conv2d( - in_channels=64, out_channels=64, kernel_size=4, stride=2, padding=2 + in_channels=64, + out_channels=64, + kernel_size=4, + stride=2, + padding=2, ), nn.ReLU( inplace=True @@ -71,7 +87,10 @@ def __init__( z_multiplier=1, ): (C, H, W) = x_shape - assert (H, W) == (64, 64), "This model only works with image size 64x64." + assert (H, W) == ( + 64, + 64, + ), "This model only works with image size 64x64." super().__init__() self.z_size = z_size @@ -83,19 +102,35 @@ def __init__( nn.ReLU(inplace=True), nn.Unflatten(dim=1, unflattened_size=[64, 4, 4]), nn.ConvTranspose2d( - in_channels=64, out_channels=64, kernel_size=4, stride=2, padding=1 + in_channels=64, + out_channels=64, + kernel_size=4, + stride=2, + padding=1, ), nn.ReLU(inplace=True), nn.ConvTranspose2d( - in_channels=64, out_channels=32, kernel_size=4, stride=2, padding=1 + in_channels=64, + out_channels=32, + kernel_size=4, + stride=2, + padding=1, ), nn.ReLU(inplace=True), nn.ConvTranspose2d( - in_channels=32, out_channels=32, kernel_size=4, stride=2, padding=1 + in_channels=32, + out_channels=32, + kernel_size=4, + stride=2, + padding=1, ), nn.ReLU(inplace=True), nn.ConvTranspose2d( - in_channels=32, out_channels=C, kernel_size=4, stride=2, padding=1 + in_channels=32, + out_channels=C, + kernel_size=4, + stride=2, + padding=1, ), ) diff --git a/XOR_MNIST/backbones/kand_encoder.py b/XOR_MNIST/backbones/kand_encoder.py index 379e958..b317e3f 100644 --- a/XOR_MNIST/backbones/kand_encoder.py +++ b/XOR_MNIST/backbones/kand_encoder.py @@ -1,4 +1,5 @@ import copy + import torch.nn from torch import nn @@ -10,7 +11,9 @@ def forward(self, input): class UnFlatten(nn.Module): def forward(self, input, hidden_channels, dim): - return input.reshape(input.size(0), hidden_channels, dim[0], dim[1]) + return input.reshape( + input.size(0), hidden_channels, dim[0], dim[1] + ) class TripleCNNEncoder(nn.Module): @@ -160,7 +163,9 @@ def forward(self, x): return logits for i in range(3): - vars = torch.stack(torch.split(self.backbone(xs[i]), 3, dim=-1)) + vars = torch.stack( + torch.split(self.backbone(xs[i]), 3, dim=-1) + ) logits.append(vars) logits = torch.stack(logits) diff --git a/XOR_MNIST/backbones/resnet.py b/XOR_MNIST/backbones/resnet.py index c58b055..8369af5 100644 --- a/XOR_MNIST/backbones/resnet.py +++ b/XOR_MNIST/backbones/resnet.py @@ -1,15 +1,15 @@ +import torch +import torch.nn as nn import torchvision +from torch import Tensor from torchvision.models import ( + ResNet18_Weights, + ResNet50_Weights, + ResNet101_Weights, resnet18, resnet50, resnet101, - ResNet101_Weights, - ResNet50_Weights, - ResNet18_Weights, ) -import torch.nn as nn -import torch -from torch import Tensor class ResNetEncoder(nn.Module): @@ -24,7 +24,9 @@ class ResNetEncoder(nn.Module): 
def __init__(self, z_dim, z_multiplier=1): super().__init__() - self.pretrained = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1) + self.pretrained = resnet18( + weights=ResNet18_Weights.IMAGENET1K_V1 + ) features = nn.ModuleList(self.pretrained.children())[:-1] # set the ResNet18 backbone as feature extractor @@ -50,7 +52,10 @@ def __init__( z_multiplier=1, ): (C, H, W) = x_shape - assert (H, W) == (64, 64), "This model only works with image size 64x64." + assert (H, W) == ( + 64, + 64, + ), "This model only works with image size 64x64." super().__init__() self.z_size = z_size @@ -62,19 +67,35 @@ def __init__( nn.ReLU(inplace=True), nn.Unflatten(dim=1, unflattened_size=[64, 4, 4]), nn.ConvTranspose2d( - in_channels=64, out_channels=64, kernel_size=4, stride=2, padding=1 + in_channels=64, + out_channels=64, + kernel_size=4, + stride=2, + padding=1, ), nn.ReLU(inplace=True), nn.ConvTranspose2d( - in_channels=64, out_channels=32, kernel_size=4, stride=2, padding=1 + in_channels=64, + out_channels=32, + kernel_size=4, + stride=2, + padding=1, ), nn.ReLU(inplace=True), nn.ConvTranspose2d( - in_channels=32, out_channels=32, kernel_size=4, stride=2, padding=1 + in_channels=32, + out_channels=32, + kernel_size=4, + stride=2, + padding=1, ), nn.ReLU(inplace=True), nn.ConvTranspose2d( - in_channels=32, out_channels=C, kernel_size=4, stride=2, padding=1 + in_channels=32, + out_channels=C, + kernel_size=4, + stride=2, + padding=1, ), ) diff --git a/XOR_MNIST/backbones/simple_encoder.py b/XOR_MNIST/backbones/simple_encoder.py index 0573f8b..6325194 100644 --- a/XOR_MNIST/backbones/simple_encoder.py +++ b/XOR_MNIST/backbones/simple_encoder.py @@ -1,8 +1,5 @@ import torch - - -from torch import Tensor -from torch import nn +from torch import Tensor, nn class SimpleMLP(nn.Module): diff --git a/XOR_MNIST/datasets/__init__.py b/XOR_MNIST/datasets/__init__.py index 8906c00..0adcfcc 100644 --- a/XOR_MNIST/datasets/__init__.py +++ b/XOR_MNIST/datasets/__init__.py @@ -1,6 +1,6 @@ -import os -import inspect import importlib +import inspect +import os from argparse import Namespace diff --git a/XOR_MNIST/datasets/addmnist.py b/XOR_MNIST/datasets/addmnist.py index 99200c1..17c4159 100644 --- a/XOR_MNIST/datasets/addmnist.py +++ b/XOR_MNIST/datasets/addmnist.py @@ -1,8 +1,11 @@ -from datasets.utils.base_dataset import BaseDataset, get_loader -from datasets.utils.mnist_creation import load_2MNIST -from backbones.addmnist_joint import MNISTPairsEncoder, MNISTPairsDecoder +from backbones.addmnist_joint import ( + MNISTPairsDecoder, + MNISTPairsEncoder, +) from backbones.addmnist_repeated import MNISTRepeatedEncoder from backbones.addmnist_single import MNISTSingleEncoder +from datasets.utils.base_dataset import BaseDataset, get_loader +from datasets.utils.mnist_creation import load_2MNIST class ADDMNIST(BaseDataset): @@ -11,7 +14,9 @@ class ADDMNIST(BaseDataset): def get_data_loaders(self): dataset_train, dataset_val, dataset_test = load_2MNIST( - c_sup=self.args.c_sup, which_c=self.args.which_c, args=self.args + c_sup=self.args.c_sup, + which_c=self.args.which_c, + args=self.args, ) self.dataset_train = dataset_train @@ -21,8 +26,12 @@ def get_data_loaders(self): self.train_loader = get_loader( dataset_train, self.args.batch_size, val_test=False ) - self.val_loader = get_loader(dataset_val, self.args.batch_size, val_test=True) - self.test_loader = get_loader(dataset_test, self.args.batch_size, val_test=True) + self.val_loader = get_loader( + dataset_val, self.args.batch_size, val_test=True + ) + self.test_loader = 
get_loader( + dataset_test, self.args.batch_size, val_test=True + ) return self.train_loader, self.val_loader, self.test_loader diff --git a/XOR_MNIST/datasets/halfmnist.py b/XOR_MNIST/datasets/halfmnist.py index dde5bed..be8043b 100644 --- a/XOR_MNIST/datasets/halfmnist.py +++ b/XOR_MNIST/datasets/halfmnist.py @@ -1,9 +1,13 @@ +from copy import deepcopy + +import numpy as np +from backbones.addmnist_joint import ( + MNISTPairsDecoder, + MNISTPairsEncoder, +) +from backbones.addmnist_single import MNISTSingleEncoder from datasets.utils.base_dataset import BaseDataset, get_loader from datasets.utils.mnist_creation import load_2MNIST -from backbones.addmnist_joint import MNISTPairsEncoder, MNISTPairsDecoder -from backbones.addmnist_single import MNISTSingleEncoder -import numpy as np -from copy import deepcopy class HALFMNIST(BaseDataset): @@ -12,7 +16,9 @@ class HALFMNIST(BaseDataset): def get_data_loaders(self): dataset_train, dataset_val, dataset_test = load_2MNIST( - c_sup=self.args.c_sup, which_c=self.args.which_c, args=self.args + c_sup=self.args.c_sup, + which_c=self.args.which_c, + args=self.args, ) ood_test = self.get_ood_test(dataset_test) @@ -29,9 +35,15 @@ def get_data_loaders(self): self.train_loader = get_loader( dataset_train, self.args.batch_size, val_test=False ) - self.val_loader = get_loader(dataset_val, self.args.batch_size, val_test=True) - self.test_loader = get_loader(dataset_test, self.args.batch_size, val_test=True) - self.ood_loader = get_loader(ood_test, self.args.batch_size, val_test=True) + self.val_loader = get_loader( + dataset_val, self.args.batch_size, val_test=True + ) + self.test_loader = get_loader( + dataset_test, self.args.batch_size, val_test=True + ) + self.ood_loader = get_loader( + ood_test, self.args.batch_size, val_test=True + ) return self.train_loader, self.val_loader, self.test_loader @@ -172,17 +184,25 @@ def filtrate(self, train_dataset, val_dataset, test_dataset): test_mask = np.logical_or(test_c_mask1, test_c_mask2) - train_dataset.data = train_dataset.data[train_mask] # [:2000, :, :] + train_dataset.data = train_dataset.data[ + train_mask + ] # [:2000, :, :] val_dataset.data = val_dataset.data[val_mask] test_dataset.data = test_dataset.data[test_mask] - train_dataset.concepts = train_dataset.concepts[train_mask] # [:2000, :] + train_dataset.concepts = train_dataset.concepts[ + train_mask + ] # [:2000, :] val_dataset.concepts = val_dataset.concepts[val_mask] test_dataset.concepts = test_dataset.concepts[test_mask] - train_dataset.targets = np.array(train_dataset.targets)[train_mask] # [:2000] + train_dataset.targets = np.array(train_dataset.targets)[ + train_mask + ] # [:2000] val_dataset.targets = np.array(val_dataset.targets)[val_mask] - test_dataset.targets = np.array(test_dataset.targets)[test_mask] + test_dataset.targets = np.array(test_dataset.targets)[ + test_mask + ] return train_dataset, val_dataset, test_dataset @@ -237,7 +257,9 @@ def get_ood_test(self, test_dataset): test_mask_in_range = np.logical_and(mask_col0, mask_col1) test_mask_value = np.logical_and(~test_c_mask1, ~test_c_mask2) - test_mask = np.logical_and(test_mask_in_range, test_mask_value) + test_mask = np.logical_and( + test_mask_in_range, test_mask_value + ) ood_test.data = ood_test.data[test_mask] ood_test.concepts = ood_test.concepts[test_mask] diff --git a/XOR_MNIST/datasets/kandinsky.py b/XOR_MNIST/datasets/kandinsky.py index cbe060a..d3d266f 100644 --- a/XOR_MNIST/datasets/kandinsky.py +++ b/XOR_MNIST/datasets/kandinsky.py @@ -1,8 +1,12 @@ +import time + +from 
backbones.disent_encoder_decoder import ( + DecoderConv64, + EncoderConv64, +) +from backbones.resnet import ResNetEncoder from datasets.utils.base_dataset import BaseDataset, KAND_get_loader from datasets.utils.kand_creation import KAND_Dataset -from backbones.disent_encoder_decoder import DecoderConv64, EncoderConv64 -from backbones.resnet import ResNetEncoder -import time class Kandinsky(BaseDataset): @@ -33,19 +37,36 @@ def get_data_loaders(self): print(f"Loaded datasets in {time.time()-start} s.") - print("Len loaders: \n train:", len(dataset_train), "\n val:", len(dataset_val)) - print(" len test:", len(dataset_test)) # , '\n len ood', len(dataset_ood)) + print( + "Len loaders: \n train:", + len(dataset_train), + "\n val:", + len(dataset_val), + ) + print( + " len test:", len(dataset_test) + ) # , '\n len ood', len(dataset_ood)) if not self.args.preprocess: train_loader = KAND_get_loader( dataset_train, self.args.batch_size, val_test=False ) - val_loader = KAND_get_loader(dataset_val, 1000, val_test=True) - test_loader = KAND_get_loader(dataset_test, 1000, val_test=True) + val_loader = KAND_get_loader( + dataset_val, 1000, val_test=True + ) + test_loader = KAND_get_loader( + dataset_test, 1000, val_test=True + ) else: - train_loader = KAND_get_loader(dataset_train, 1, val_test=False) - val_loader = KAND_get_loader(dataset_val, 1, val_test=True) - test_loader = KAND_get_loader(dataset_test, 1, val_test=True) + train_loader = KAND_get_loader( + dataset_train, 1, val_test=False + ) + val_loader = KAND_get_loader( + dataset_val, 1, val_test=True + ) + test_loader = KAND_get_loader( + dataset_test, 1, val_test=True + ) # self.ood_loader = get_loader(dataset_ood, self.args.batch_size, val_test=True) @@ -53,13 +74,17 @@ def get_data_loaders(self): def get_backbone(self, args=None): if self.args.preprocess: - return ResNetEncoder(z_dim=18, z_multiplier=2), DecoderConv64( + return ResNetEncoder( + z_dim=18, z_multiplier=2 + ), DecoderConv64( x_shape=(3, 64, 64), z_size=18, z_multiplier=2 ) else: return EncoderConv64( x_shape=(3, 64, 64), z_size=18, z_multiplier=2 - ), DecoderConv64(x_shape=(3, 64, 64), z_size=18, z_multiplier=2) + ), DecoderConv64( + x_shape=(3, 64, 64), z_size=18, z_multiplier=2 + ) def get_split(self): return 3, () diff --git a/XOR_MNIST/datasets/minikandinsky.py b/XOR_MNIST/datasets/minikandinsky.py index 377127b..08f1ea0 100644 --- a/XOR_MNIST/datasets/minikandinsky.py +++ b/XOR_MNIST/datasets/minikandinsky.py @@ -1,8 +1,12 @@ -from datasets.utils.base_dataset import BaseDataset, KAND_get_loader -from datasets.utils.kand_creation import KAND_Dataset, miniKAND_Dataset -from backbones.kand_encoder import TripleCNNEncoder, TripleMLP import time +from backbones.kand_encoder import TripleCNNEncoder, TripleMLP +from datasets.utils.base_dataset import BaseDataset, KAND_get_loader +from datasets.utils.kand_creation import ( + KAND_Dataset, + miniKAND_Dataset, +) + class MiniKandinsky(BaseDataset): NAME = "minikandinsky" @@ -12,11 +16,17 @@ def get_data_loaders(self): if not hasattr(self, "dataset_train"): self.dataset_train = miniKAND_Dataset( - base_path="data/kand-3k", split="train", finetuning=self.args.finetuning + base_path="data/kand-3k", + split="train", + finetuning=self.args.finetuning, ) - dataset_val = miniKAND_Dataset(base_path="data/kand-3k", split="val") - dataset_test = miniKAND_Dataset(base_path="data/kand-3k", split="test") + dataset_val = miniKAND_Dataset( + base_path="data/kand-3k", split="val" + ) + dataset_test = miniKAND_Dataset( + base_path="data/kand-3k", 
split="test" + ) # dataset_ood = KAND_Dataset(base_path='data/kandinsky/data',split='ood') print(f"Loaded datasets in {time.time()-start} s.") @@ -27,18 +37,32 @@ def get_data_loaders(self): "\n val:", len(dataset_val), ) - print(" len test:", len(dataset_test)) # , '\n len ood', len(dataset_ood)) + print( + " len test:", len(dataset_test) + ) # , '\n len ood', len(dataset_ood)) if not self.args.preprocess: train_loader = KAND_get_loader( - self.dataset_train, self.args.batch_size, val_test=False + self.dataset_train, + self.args.batch_size, + val_test=False, + ) + val_loader = KAND_get_loader( + dataset_val, 500, val_test=True + ) + test_loader = KAND_get_loader( + dataset_test, 500, val_test=True ) - val_loader = KAND_get_loader(dataset_val, 500, val_test=True) - test_loader = KAND_get_loader(dataset_test, 500, val_test=True) else: - train_loader = KAND_get_loader(self.dataset_train, 1, val_test=False) - val_loader = KAND_get_loader(dataset_val, 1, val_test=True) - test_loader = KAND_get_loader(dataset_test, 1, val_test=True) + train_loader = KAND_get_loader( + self.dataset_train, 1, val_test=False + ) + val_loader = KAND_get_loader( + dataset_val, 1, val_test=True + ) + test_loader = KAND_get_loader( + dataset_test, 1, val_test=True + ) # self.ood_loader = get_loader(dataset_ood, self.args.batch_size, val_test=True) @@ -47,20 +71,32 @@ def get_data_loaders(self): def give_full_supervision(self): if not hasattr(self, "dataset_train"): self.dataset_train = miniKAND_Dataset( - base_path="data/kand-3k", split="train", finetuning=self.args.finetuning + base_path="data/kand-3k", + split="train", + finetuning=self.args.finetuning, ) - self.dataset_train.concepts = self.dataset_train.original_concepts + self.dataset_train.concepts = ( + self.dataset_train.original_concepts + ) def give_supervision_to(self, data_idx, figure_idx, obj_idx): if not hasattr(self, "dataset_train"): self.dataset_train = miniKAND_Dataset( - base_path="data/kand-3k", split="train", finetuning=self.args.finetuning + base_path="data/kand-3k", + split="train", + finetuning=self.args.finetuning, ) - self.dataset_train.concepts = self.dataset_train.original_concepts - self.dataset_train.mask_concepts_specific(data_idx, figure_idx, obj_idx) + self.dataset_train.concepts = ( + self.dataset_train.original_concepts + ) + self.dataset_train.mask_concepts_specific( + data_idx, figure_idx, obj_idx + ) def get_train_loader_as_val(self): - return KAND_get_loader(self.dataset_train, self.args.batch_size, val_test=True) + return KAND_get_loader( + self.dataset_train, self.args.batch_size, val_test=True + ) def get_backbone(self, args=None): return TripleMLP(latent_dim=6), 0 diff --git a/XOR_MNIST/datasets/prekandinsky.py b/XOR_MNIST/datasets/prekandinsky.py index 138c77f..90556d8 100644 --- a/XOR_MNIST/datasets/prekandinsky.py +++ b/XOR_MNIST/datasets/prekandinsky.py @@ -1,8 +1,9 @@ +import time + +from backbones.disent_encoder_decoder import DecoderConv64 +from backbones.simple_encoder import SimpleMLP from datasets.utils.base_dataset import BaseDataset, KAND_get_loader from datasets.utils.kand_creation import PreKAND_Dataset -from backbones.simple_encoder import SimpleMLP -from backbones.disent_encoder_decoder import DecoderConv64 -import time class PreKandinsky(BaseDataset): @@ -11,23 +12,38 @@ class PreKandinsky(BaseDataset): def get_data_loaders(self): start = time.time() - dataset_train = PreKAND_Dataset(base_path="data/kand-preprocess", split="train") - dataset_val = PreKAND_Dataset(base_path="data/kand-preprocess", split="val") 
- dataset_test = PreKAND_Dataset(base_path="data/kand-preprocess", split="test") + dataset_train = PreKAND_Dataset( + base_path="data/kand-preprocess", split="train" + ) + dataset_val = PreKAND_Dataset( + base_path="data/kand-preprocess", split="val" + ) + dataset_test = PreKAND_Dataset( + base_path="data/kand-preprocess", split="test" + ) # dataset_ood = KAND_Dataset(base_path='data/kandinsky/data',split='ood') dataset_train.mask_concepts("red-and-squares") print(f"Loaded datasets in {time.time()-start} s.") - print("Len loaders: \n train:", len(dataset_train), "\n val:", len(dataset_val)) - print(" len test:", len(dataset_test)) # , '\n len ood', len(dataset_ood)) + print( + "Len loaders: \n train:", + len(dataset_train), + "\n val:", + len(dataset_val), + ) + print( + " len test:", len(dataset_test) + ) # , '\n len ood', len(dataset_ood)) train_loader = KAND_get_loader( dataset_train, self.args.batch_size, val_test=False ) val_loader = KAND_get_loader(dataset_val, 1000, val_test=True) - test_loader = KAND_get_loader(dataset_test, 1000, val_test=True) + test_loader = KAND_get_loader( + dataset_test, 1000, val_test=True + ) # self.ood_loader = get_loader(dataset_ood, self.args.batch_size, val_test=True) diff --git a/XOR_MNIST/datasets/restrictedmnist.py b/XOR_MNIST/datasets/restrictedmnist.py index 6e29cb1..6fb8067 100644 --- a/XOR_MNIST/datasets/restrictedmnist.py +++ b/XOR_MNIST/datasets/restrictedmnist.py @@ -1,9 +1,13 @@ +from copy import deepcopy + +import numpy as np +from backbones.addmnist_joint import ( + MNISTPairsDecoder, + MNISTPairsEncoder, +) +from backbones.addmnist_single import MNISTSingleEncoder from datasets.utils.base_dataset import BaseDataset, get_loader from datasets.utils.mnist_creation import load_2MNIST -from backbones.addmnist_joint import MNISTPairsEncoder, MNISTPairsDecoder -from backbones.addmnist_single import MNISTSingleEncoder -import numpy as np -from copy import deepcopy class RESTRICTEDMNIST(BaseDataset): @@ -12,7 +16,9 @@ class RESTRICTEDMNIST(BaseDataset): def get_data_loaders(self): dataset_train, dataset_val, dataset_test = load_2MNIST( - c_sup=self.args.c_sup, which_c=self.args.which_c, args=self.args + c_sup=self.args.c_sup, + which_c=self.args.which_c, + args=self.args, ) ood_test = self.get_ood_test(dataset_test) @@ -29,17 +35,23 @@ def get_data_loaders(self): self.train_loader = get_loader( dataset_train, self.args.batch_size, val_test=False ) - self.val_loader = get_loader(dataset_val, self.args.batch_size, val_test=True) - self.test_loader = get_loader(dataset_test, self.args.batch_size, val_test=True) - self.ood_loader = get_loader(ood_test, self.args.batch_size, val_test=False) + self.val_loader = get_loader( + dataset_val, self.args.batch_size, val_test=True + ) + self.test_loader = get_loader( + dataset_test, self.args.batch_size, val_test=True + ) + self.ood_loader = get_loader( + ood_test, self.args.batch_size, val_test=False + ) return self.train_loader, self.val_loader, self.test_loader def get_backbone(self): if not self.args.joint: - return MNISTSingleEncoder(c_dim=5, latent_dim=5), MNISTPairsDecoder( - c_dim=10, latent_dim=10 - ) + return MNISTSingleEncoder( + c_dim=5, latent_dim=5 + ), MNISTPairsDecoder(c_dim=10, latent_dim=10) else: return NotImplementedError("Wrong choice") @@ -154,9 +166,13 @@ def filtrate(self, train_dataset, val_dataset, test_dataset): val_dataset.concepts = val_dataset.concepts[val_mask] test_dataset.concepts = test_dataset.concepts[test_mask] - train_dataset.targets = 
np.array(train_dataset.targets)[train_mask] + train_dataset.targets = np.array(train_dataset.targets)[ + train_mask + ] val_dataset.targets = np.array(val_dataset.targets)[val_mask] - test_dataset.targets = np.array(test_dataset.targets)[test_mask] + test_dataset.targets = np.array(test_dataset.targets)[ + test_mask + ] return train_dataset, val_dataset, test_dataset @@ -202,7 +218,9 @@ def get_ood_test(self, test_dataset): test_mask_in_range = np.logical_and(mask_col0, mask_col1) test_mask_value = np.logical_and(~test_c_mask1, ~test_c_mask2) - test_mask = np.logical_and(test_mask_in_range, test_mask_value) + test_mask = np.logical_and( + test_mask_in_range, test_mask_value + ) ood_test.data = ood_test.data[test_mask] ood_test.concepts = ood_test.concepts[test_mask] diff --git a/XOR_MNIST/datasets/shortcutmnist.py b/XOR_MNIST/datasets/shortcutmnist.py index 39bbbd3..dc56e74 100644 --- a/XOR_MNIST/datasets/shortcutmnist.py +++ b/XOR_MNIST/datasets/shortcutmnist.py @@ -1,9 +1,13 @@ +from copy import deepcopy + +import numpy as np +from backbones.addmnist_joint import ( + MNISTPairsDecoder, + MNISTPairsEncoder, +) +from backbones.addmnist_single import MNISTSingleEncoder from datasets.utils.base_dataset import BaseDataset, get_loader from datasets.utils.mnist_creation import load_2MNIST -from backbones.addmnist_joint import MNISTPairsEncoder, MNISTPairsDecoder -from backbones.addmnist_single import MNISTSingleEncoder -import numpy as np -from copy import deepcopy class SHORTMNIST(BaseDataset): @@ -12,7 +16,9 @@ class SHORTMNIST(BaseDataset): def get_data_loaders(self): dataset_train, dataset_val, dataset_test = load_2MNIST( - c_sup=self.args.c_sup, which_c=self.args.which_c, args=self.args + c_sup=self.args.c_sup, + which_c=self.args.which_c, + args=self.args, ) ood_test = self.get_ood_test(dataset_test) @@ -27,9 +33,15 @@ def get_data_loaders(self): self.train_loader = get_loader( dataset_train, self.args.batch_size, val_test=False ) - self.val_loader = get_loader(dataset_val, self.args.batch_size, val_test=True) - self.test_loader = get_loader(dataset_test, self.args.batch_size, val_test=True) - self.ood_loader = get_loader(ood_test, self.args.batch_size, val_test=False) + self.val_loader = get_loader( + dataset_val, self.args.batch_size, val_test=True + ) + self.test_loader = get_loader( + dataset_test, self.args.batch_size, val_test=True + ) + self.ood_loader = get_loader( + ood_test, self.args.batch_size, val_test=False + ) return self.train_loader, self.val_loader, self.test_loader @@ -276,9 +288,13 @@ def filtrate(self, train_dataset, val_dataset, test_dataset): val_dataset.concepts = val_dataset.concepts[val_mask] test_dataset.concepts = test_dataset.concepts[test_mask] - train_dataset.targets = np.array(train_dataset.targets)[train_mask] + train_dataset.targets = np.array(train_dataset.targets)[ + train_mask + ] val_dataset.targets = np.array(val_dataset.targets)[val_mask] - test_dataset.targets = np.array(test_dataset.targets)[test_mask] + test_dataset.targets = np.array(test_dataset.targets)[ + test_mask + ] def get_ood_test(self, test_dataset): @@ -364,7 +380,9 @@ def get_ood_test(self, test_dataset): test_mask_in_range = np.logical_and(mask_col0, mask_col1) test_mask_value = np.logical_and(~test_c_mask1, ~test_c_mask2) - test_mask = np.logical_and(test_mask_in_range, test_mask_value) + test_mask = np.logical_and( + test_mask_in_range, test_mask_value + ) ood_test.data = ood_test.data[test_mask] ood_test.concepts = ood_test.concepts[test_mask] diff --git 
a/XOR_MNIST/datasets/utils/base_dataset.py b/XOR_MNIST/datasets/utils/base_dataset.py index dba67c2..7d637c9 100644 --- a/XOR_MNIST/datasets/utils/base_dataset.py +++ b/XOR_MNIST/datasets/utils/base_dataset.py @@ -1,13 +1,13 @@ from abc import abstractmethod from argparse import Namespace -from torch import nn as nn -from torchvision.transforms import transforms -from torch.utils.data import DataLoader -from typing import Tuple, List -from torchvision import datasets +from typing import List, Tuple + import numpy as np import torch.optim -from torch.utils.data import WeightedRandomSampler +from torch import nn as nn +from torch.utils.data import DataLoader, WeightedRandomSampler +from torchvision import datasets +from torchvision.transforms import transforms class BaseDataset: @@ -86,16 +86,24 @@ def get_loader(dataset, batch_size, num_workers=4, val_test=False): samples_weight = torch.from_numpy(samples_weight) sampler = WeightedRandomSampler( - samples_weight.type("torch.DoubleTensor"), len(samples_weight) + samples_weight.type("torch.DoubleTensor"), + len(samples_weight), ) return DataLoader( - dataset, batch_size=batch_size, num_workers=num_workers, sampler=sampler + dataset, + batch_size=batch_size, + num_workers=num_workers, + sampler=sampler, ) def KAND_get_loader( - dataset, batch_size, num_workers=4, val_test=False, preprocess=False + dataset, + batch_size, + num_workers=4, + val_test=False, + preprocess=False, ): if val_test: diff --git a/XOR_MNIST/datasets/utils/kand_creation.py b/XOR_MNIST/datasets/utils/kand_creation.py index 7f9246c..257fef1 100644 --- a/XOR_MNIST/datasets/utils/kand_creation.py +++ b/XOR_MNIST/datasets/utils/kand_creation.py @@ -1,16 +1,20 @@ +import glob +import itertools import os + +import joblib +import numpy as np import torch import torch.utils.data import torchvision.transforms as transforms -import numpy as np, joblib, glob -import itertools from PIL import Image - from torchvision.datasets.folder import pil_loader class KAND_Dataset(torch.utils.data.Dataset): - def __init__(self, base_path, split, preprocess=False, finetuning=0): + def __init__( + self, base_path, split, preprocess=False, finetuning=0 + ): # path and train/val/test type self.base_path = base_path @@ -53,7 +57,9 @@ def __init__(self, base_path, split, preprocess=False, finetuning=0): labels = np.array(labels).reshape(1, -1) self.labels.append(labels) - concepts = np.concatenate(concepts, axis=0).reshape(1, -1, 6) + concepts = np.concatenate(concepts, axis=0).reshape( + 1, -1, 6 + ) self.concepts.append(concepts) self.concepts = np.concatenate(self.concepts, axis=0) @@ -69,7 +75,11 @@ def __init__(self, base_path, split, preprocess=False, finetuning=0): def mask_concepts(self, cond): start = self.finetuning * 100 if start > 0: - print("Activate finetuning on ", start, "elements of the training set") + print( + "Activate finetuning on ", + start, + "elements of the training set", + ) if cond == "red": self.concepts[start:, :, :3] = -1 for i, j in itertools.product(range(3), range(3, 6)): @@ -106,7 +116,10 @@ def __getitem__(self, item): img_id = self.img_number[item] image_id = os.path.join( - self.base_path, self.split, "images", str(img_id).zfill(5) + ".png" + self.base_path, + self.split, + "images", + str(img_id).zfill(5) + ".png", ) image = pil_loader(image_id) @@ -120,7 +133,9 @@ def __len__(self): class PreKAND_Dataset(torch.utils.data.Dataset): - def __init__(self, base_path, split, preprocess=False, finetuning=0): + def __init__( + self, base_path, split, preprocess=False, 
finetuning=0 + ): self.base_path = base_path self.split = split @@ -138,13 +153,22 @@ def __init__(self, base_path, split, preprocess=False, finetuning=0): for item in range(len(self.list_images)): img_id = os.path.join( - self.base_path, self.split, "images", str(item).zfill(5) + ".npy" + self.base_path, + self.split, + "images", + str(item).zfill(5) + ".npy", ) tgt_id = os.path.join( - self.base_path, self.split, "labels", str(item).zfill(5) + ".npy" + self.base_path, + self.split, + "labels", + str(item).zfill(5) + ".npy", ) cnp_id = os.path.join( - self.base_path, self.split, "concepts", str(item).zfill(5) + ".npy" + self.base_path, + self.split, + "concepts", + str(item).zfill(5) + ".npy", ) self.imgs.append(np.load(img_id)) @@ -158,7 +182,11 @@ def __init__(self, base_path, split, preprocess=False, finetuning=0): def mask_concepts(self, cond): start = self.finetuning * 100 if start > 0: - print("Activate finetuning on ", start, "elements of the training set") + print( + "Activate finetuning on ", + start, + "elements of the training set", + ) if cond == "red": self.concepts[start:, :, :3] = -1 for i, j in itertools.product(range(3), range(3, 6)): @@ -190,7 +218,9 @@ def __len__(self): class miniKAND_Dataset(torch.utils.data.Dataset): - def __init__(self, base_path, split, preprocess=False, finetuning=0): + def __init__( + self, base_path, split, preprocess=False, finetuning=0 + ): # path and train/val/test type self.base_path = base_path @@ -201,7 +231,9 @@ def __init__(self, base_path, split, preprocess=False, finetuning=0): self.preprocess = preprocess # collecting images - self.list_images = glob.glob(os.path.join(self.base_path, self.split, "*")) + self.list_images = glob.glob( + os.path.join(self.base_path, self.split, "*") + ) self.list_images = list(sorted(self.list_images)) self.img_number = [i for i in range(len(self.list_images))] @@ -232,7 +264,9 @@ def __init__(self, base_path, split, preprocess=False, finetuning=0): labels = np.array(labels).reshape(1, -1) self.labels.append(labels) - concepts = np.concatenate(concepts, axis=0).reshape(1, -1, 6) + concepts = np.concatenate(concepts, axis=0).reshape( + 1, -1, 6 + ) self.concepts.append(concepts) self.concepts = np.concatenate(self.concepts, axis=0) @@ -244,7 +278,11 @@ def __init__(self, base_path, split, preprocess=False, finetuning=0): def mask_concepts(self, cond, obj=None): start = self.finetuning if start > 0: - print("Activate finetuning on ", start, "elements of the training set") + print( + "Activate finetuning on ", + start, + "elements of the training set", + ) assert obj is None or (obj >= 0 and obj <= 8) @@ -255,9 +293,13 @@ def mask_concepts(self, cond, obj=None): n_obj = obj % self.concepts.shape[1] for i, j in itertools.product(range(3), range(6)): if i == n_figure and ( - j == n_obj or j == (n_obj + self.concepts.shape[2] // 2) + j == n_obj + or j == (n_obj + self.concepts.shape[2] // 2) ): - print("Il boia colpisce", self.concepts[:start, :, :].shape) + print( + "Il boia colpisce", + self.concepts[:start, :, :].shape, + ) print("La vittima è", self.concepts[:start, i, j]) pass else: @@ -355,7 +397,6 @@ def __len__(self): # # print(train_data[i][2],'->', train_data[i][1]) if __name__ == "__main__": - print("DIO MERDA") train_data = miniKAND_Dataset("../../data/kand-3k", "train") print(len(train_data)) diff --git a/XOR_MNIST/datasets/utils/mnist_creation.py b/XOR_MNIST/datasets/utils/mnist_creation.py index 7be3f51..5bff6c1 100644 --- a/XOR_MNIST/datasets/utils/mnist_creation.py +++ 
b/XOR_MNIST/datasets/utils/mnist_creation.py @@ -1,37 +1,50 @@ +import copy +import itertools import os from itertools import product from pathlib import Path -from random import sample, choice -from torchvision import transforms +from random import choice, sample + import numpy as np import torch from torch import Tensor, load from torch.utils.data import Dataset -from torchvision import datasets +from torchvision import datasets, transforms from tqdm import tqdm -import copy, itertools def get_label(c1, c2, labels, args): if args.task == "addition": return labels - elif args.task == "product" and not args.model in ["mnistltn", "mnistltnrec"]: + elif args.task == "product" and not args.model in [ + "mnistltn", + "mnistltnrec", + ]: n_queries = [0] for i, j in itertools.product(range(1, 10), range(1, 10)): n_queries.append(i * j) n_queries = np.unique(np.array(n_queries)) prod = c1 * c2 - q = [np.argmax(np.array(n_queries == p, dtype=int)) for p in prod] + q = [ + np.argmax(np.array(n_queries == p, dtype=int)) + for p in prod + ] q = np.array(q) return q - elif args.task == "product" and args.model in ["mnistltn", "mnistltnrec"]: + elif args.task == "product" and args.model in [ + "mnistltn", + "mnistltnrec", + ]: return c1 * c2 - elif args.task == "multiop" and not args.model in ["mnistltn", "mnistltnrec"]: + elif args.task == "multiop" and not args.model in [ + "mnistltn", + "mnistltnrec", + ]: multiop = np.vstack((c1 + c2, c1 * c2)).T ls = [] @@ -50,7 +63,10 @@ def get_label(c1, c2, labels, args): ls.append(3) return np.array(ls) - elif args.task == "multiop" and args.model in ["mnistltn", "mnistltnrec"]: + elif args.task == "multiop" and args.model in [ + "mnistltn", + "mnistltnrec", + ]: multiop = c1**2 + c2**2 + c1 * c2 return multiop @@ -60,13 +76,18 @@ class nMNIST(Dataset): def __init__(self, split: str, data_path, args): print(f"Loading {split} data") - self.data, self.labels = self.read_data(path=data_path, split=split) + self.data, self.labels = self.read_data( + path=data_path, split=split + ) self.targets = self.labels[:, -1:].reshape(-1) self.concepts = self.labels[:, :-1] self.real_concepts = np.copy(self.labels[:, :-1]) self.targets = get_label( - self.real_concepts[:, 0], self.real_concepts[:, 1], self.targets, args + self.real_concepts[:, 0], + self.real_concepts[:, 1], + self.targets, + args, ) normalize = transforms.Normalize((0.1307,), (0.3081,)) @@ -110,7 +131,8 @@ def read_data(self, path, split): def reset_counter(self): self.world_counter = { - c: self.samples_x_world // self.batch_size for c in self.worlds + c: self.samples_x_world // self.batch_size + for c in self.worlds } # self.world_counter = {c: self.samples_x_world for c in self.worlds} @@ -127,10 +149,16 @@ def create_sample(X, target_sequence, digit2idx): def create_dataset( - n_digit=2, sequence_len=2, samples_x_world=100, train=True, download=False + n_digit=2, + sequence_len=2, + samples_x_world=100, + train=True, + download=False, ): # Download data - MNIST = datasets.MNIST(root="./data/raw/", train=train, download=download) + MNIST = datasets.MNIST( + root="./data/raw/", train=train, download=download + ) x, y = MNIST.data, MNIST.targets @@ -157,7 +185,9 @@ def create_dataset( for i, label in enumerate(labels): if tuple(label[:2]) == k: v.add(i) - label2idx = {k: torch.tensor(list(v)) for k, v in label2idx.items()} + label2idx = { + k: torch.tensor(list(v)) for k, v in label2idx.items() + } return np.array(imgs).astype("int32"), np.array(labels), label2idx @@ -172,8 +202,12 @@ def 
check_dataset(n_digits, data_folder, data_file, dataset_dim): print("No dataset found.") # Define dataset dimension so to have teh same number of worlds n_worlds = n_digits * n_digits - samples_x_world = {k: int(d / n_worlds) for k, d in dataset_dim.items()} - dataset_dim = {k: s * n_worlds for k, s in samples_x_world.items()} + samples_x_world = { + k: int(d / n_worlds) for k, d in dataset_dim.items() + } + dataset_dim = { + k: s * n_worlds for k, s in samples_x_world.items() + } train_imgs, train_labels, train_indexes = create_dataset( n_digit=n_digits, @@ -208,11 +242,17 @@ def check_dataset(n_digits, data_folder, data_file, dataset_dim): "test": {"images": test_imgs, "labels": test_labels}, } - indexes = {"train": train_indexes, "val": val_indexes, "test": test_indexes} + indexes = { + "train": train_indexes, + "val": val_indexes, + "test": test_indexes, + } torch.save(data, data_path) for key, value in indexes.items(): - torch.save(value, os.path.join(data_folder, f"{key}_indexes.pt")) + torch.save( + value, os.path.join(data_folder, f"{key}_indexes.pt") + ) print(f"Dataset saved in {data_folder}") @@ -226,11 +266,15 @@ def load_2MNIST( ): # Load data data_folder = os.path.dirname(os.path.abspath(__file__)) - data_folder = os.path.join(data_folder, f"2mnist_{n_digits}digits") + data_folder = os.path.join( + data_folder, f"2mnist_{n_digits}digits" + ) data_file = f"2mnist_{n_digits}digits.pt" # Check whether dataset exists, if not build it - check_dataset(n_digits, data_folder, data_file, dataset_dimensions) + check_dataset( + n_digits, data_folder, data_file, dataset_dimensions + ) train_set, val_set, test_set = load_data( data_file=data_file, data_folder=data_folder, @@ -247,7 +291,9 @@ def generate_r_seq(size): return np.random.rand(size) -def load_data(data_file, data_folder, c_sup=1, which_c=[-1], args=None): +def load_data( + data_file, data_folder, c_sup=1, which_c=[-1], args=None +): # Prepare data data_path = os.path.join(data_folder, data_file) diff --git a/XOR_MNIST/datasets/utils/old_kand_creation.py b/XOR_MNIST/datasets/utils/old_kand_creation.py index 3ea048e..3585738 100644 --- a/XOR_MNIST/datasets/utils/old_kand_creation.py +++ b/XOR_MNIST/datasets/utils/old_kand_creation.py @@ -1,9 +1,11 @@ +import glob import os + +import joblib +import numpy as np import torch import torch.utils.data import torchvision.transforms as transforms -import numpy as np, joblib, glob - from torchvision.datasets.folder import pil_loader @@ -57,7 +59,10 @@ def __getitem__(self, item): img_id = self.img_number[item] image_id = os.path.join( - self.base_path, self.split, "images", str(img_id).zfill(5) + ".png" + self.base_path, + self.split, + "images", + str(img_id).zfill(5) + ".png", ) image = pil_loader(image_id) @@ -117,8 +122,6 @@ def logic_triangle_circle(concepts): if __name__ == "__main__": - print("DIO MERDA") - train_data = KAND_Dataset("../../data/kandinsky-30k", "train") print(len(train_data)) @@ -150,7 +153,9 @@ def logic_triangle_circle(concepts): print(dset.split, " ", frac) - print(dset.split, " ", np.sum(labels[:, -1] == 1) / len(labels)) + print( + dset.split, " ", np.sum(labels[:, -1] == 1) / len(labels) + ) print(labels.shape) diff --git a/XOR_MNIST/example/dpl_models.py b/XOR_MNIST/example/dpl_models.py index e460e33..32618b4 100644 --- a/XOR_MNIST/example/dpl_models.py +++ b/XOR_MNIST/example/dpl_models.py @@ -1,6 +1,7 @@ -import torch -import numpy as np import itertools + +import numpy as np +import torch import torch.nn.functional as F @@ -8,7 +9,9 @@ def 
create_world_matrix(): W = torch.zeros((2, 2, 2, 2), dtype=torch.float) for i, j, k in itertools.product(range(2), range(2), range(2)): W[:, i, j, k] = ( - torch.tensor([1, 0]) if (i + j + k % 2 == 0) else torch.tensor([0, 1]) + torch.tensor([1, 0]) + if (i + j + k % 2 == 0) + else torch.tensor([0, 1]) ) return W @@ -25,9 +28,15 @@ def __init__(self, hidden_neurons): super(MLPRecon, self).__init__() self.hidden = hidden_neurons - self.layer1 = torch.nn.Linear(3, self.hidden, dtype=torch.float) - self.layer2 = torch.nn.Linear(self.hidden, self.hidden, dtype=torch.float) - self.layer3 = torch.nn.Linear(self.hidden, 3, dtype=torch.float) + self.layer1 = torch.nn.Linear( + 3, self.hidden, dtype=torch.float + ) + self.layer2 = torch.nn.Linear( + self.hidden, self.hidden, dtype=torch.float + ) + self.layer3 = torch.nn.Linear( + self.hidden, 3, dtype=torch.float + ) def forward(self, x): x = torch.nn.ReLU()(self.layer1(x)) @@ -49,12 +58,22 @@ def __init__(self, hidden_neurons, args): if not args.disent: self.hidden = hidden_neurons - self.layer1 = torch.nn.Linear(3, self.hidden, dtype=torch.float) - self.layer2 = torch.nn.Linear(self.hidden, self.hidden, dtype=torch.float) - self.layer3 = torch.nn.Linear(self.hidden, 3, dtype=torch.float) + self.layer1 = torch.nn.Linear( + 3, self.hidden, dtype=torch.float + ) + self.layer2 = torch.nn.Linear( + self.hidden, self.hidden, dtype=torch.float + ) + self.layer3 = torch.nn.Linear( + self.hidden, 3, dtype=torch.float + ) self.mlp = torch.nn.Sequential( - self.layer1, torch.nn.ReLU(), self.layer2, torch.nn.ReLU(), self.layer3 + self.layer1, + torch.nn.ReLU(), + self.layer2, + torch.nn.ReLU(), + self.layer3, ) elif not args.s_w: self.layer1 = torch.nn.Linear(1, 1, dtype=torch.float) @@ -88,12 +107,23 @@ def forward(self, x): pC = torch.zeros((8, 6), device=self.device) for i in range(3): - pC[:, 2 * i] = (1 - logitC[:, i].sigmoid() + 1e-5) / (1 + 2 * 1e-5) - pC[:, 2 * i + 1] = (logitC[:, i].sigmoid() + 1e-5) / (1 + 2 * 1e-5) + pC[:, 2 * i] = (1 - logitC[:, i].sigmoid() + 1e-5) / ( + 1 + 2 * 1e-5 + ) + pC[:, 2 * i + 1] = (logitC[:, i].sigmoid() + 1e-5) / ( + 1 + 2 * 1e-5 + ) pred = torch.zeros((8, 2), device=self.device) - for i, j, k in itertools.product(range(2), range(2), range(2)): - pred[:, 1] += pC[:, i] * pC[:, 2 + j] * pC[:, 4 + k] * self.W[i, j, k] + for i, j, k in itertools.product( + range(2), range(2), range(2) + ): + pred[:, 1] += ( + pC[:, i] + * pC[:, 2 + j] + * pC[:, 4 + k] + * self.W[i, j, k] + ) pred[:, 0] = 1 - pred[:, 1] pred = (pred + 1e-3) / (1 + 2 * 1e-3) # print(pC, pred) @@ -120,7 +150,9 @@ def forward(self, x): C[:, 2 * i] = logitC[:, i] / 2 C[:, 2 * i + 1] = -logitC[:, i] / 2 C[:, :2] = F.gumbel_softmax(C[:, :2], tau=1, hard=True, dim=1) - C[:, 2:4] = F.gumbel_softmax(C[:, 2:4], tau=1, hard=True, dim=1) + C[:, 2:4] = F.gumbel_softmax( + C[:, 2:4], tau=1, hard=True, dim=1 + ) C[:, 4:] = F.gumbel_softmax(C[:, 4:], tau=1, hard=True, dim=1) recon = self.decoder(C[:, [1, 3, 5]]) diff --git a/XOR_MNIST/example/dpl_train.py b/XOR_MNIST/example/dpl_train.py index 078c15f..30b4fd6 100644 --- a/XOR_MNIST/example/dpl_train.py +++ b/XOR_MNIST/example/dpl_train.py @@ -1,14 +1,13 @@ -import torch -import numpy as np import itertools -import torch.nn.functional as F -from warmup_scheduler import GradualWarmupScheduler +import numpy as np +import torch +import torch.nn.functional as F +import wandb from example.dpl_models import DPL, DPL_R -from example.xor_utils import progress_bar from example.nesy_losses import shannon_entropy - -import wandb 
+from example.xor_utils import progress_bar +from warmup_scheduler import GradualWarmupScheduler torch.set_printoptions(precision=3, sci_mode=False) @@ -19,8 +18,12 @@ def train_DPL(G, Y, args): model = DPL(3, args) # define optimizer - optim = torch.optim.Adam(model.parameters(), lr=args.lr) # lr=0.05 - scheduler = torch.optim.lr_scheduler.ExponentialLR(optim, gamma=0.99) + optim = torch.optim.Adam( + model.parameters(), lr=args.lr + ) # lr=0.05 + scheduler = torch.optim.lr_scheduler.ExponentialLR( + optim, gamma=0.99 + ) w_scheduler = GradualWarmupScheduler(optim, 1, 10) # default @@ -44,7 +47,9 @@ def train_DPL(G, Y, args): l1 = loss.item() - l0 if args.entropy: - loss += args.gamma * (1 - shannon_entropy(torch.sigmoid(c))) + loss += args.gamma * ( + 1 - shannon_entropy(torch.sigmoid(c)) + ) l2 = loss.item() - l1 - l0 if args.wandb: @@ -66,7 +71,8 @@ def train_DPL(G, Y, args): # early stopping if ( - np.abs(update_loss.mean() - update_loss[j % 10]) / update_loss.mean() + np.abs(update_loss.mean() - update_loss[j % 10]) + / update_loss.mean() < 0.0001 ): break @@ -89,7 +95,9 @@ def train_DPL_REC(G, Y, args): # define optimizer optim = torch.optim.Adam(model.parameters(), lr=args.lr) - scheduler = torch.optim.lr_scheduler.ExponentialLR(optim, gamma=0.9) + scheduler = torch.optim.lr_scheduler.ExponentialLR( + optim, gamma=0.9 + ) w_scheduler = GradualWarmupScheduler(optim, 1, 10) # default @@ -106,9 +114,15 @@ def train_DPL_REC(G, Y, args): l0, l1 = loss.item(), 0 - loss += args.gamma * F.binary_cross_entropy(recon[:, 0], G[:, 0]) - loss += args.gamma * F.binary_cross_entropy(recon[:, 1], G[:, 1]) - loss += args.gamma * F.binary_cross_entropy(recon[:, 2], G[:, 2]) + loss += args.gamma * F.binary_cross_entropy( + recon[:, 0], G[:, 0] + ) + loss += args.gamma * F.binary_cross_entropy( + recon[:, 1], G[:, 1] + ) + loss += args.gamma * F.binary_cross_entropy( + recon[:, 2], G[:, 2] + ) lrec = loss.item() - l0 @@ -119,7 +133,9 @@ def train_DPL_REC(G, Y, args): l1 = loss.item() - l0 - lrec if args.entropy: - loss += args.gamma * (1 - shannon_entropy(torch.sigmoid(c))) + loss += args.gamma * ( + 1 - shannon_entropy(torch.sigmoid(c)) + ) l2 = loss.item() - l1 - l0 - lrec if args.wandb: @@ -146,7 +162,8 @@ def train_DPL_REC(G, Y, args): # early stopping if ( - np.abs(update_loss.mean() - update_loss[j % 10]) / update_loss.mean() + np.abs(update_loss.mean() - update_loss[j % 10]) + / update_loss.mean() < 0.0001 ): break diff --git a/XOR_MNIST/example/ltn_models.py b/XOR_MNIST/example/ltn_models.py index 7b4550f..26ef11a 100644 --- a/XOR_MNIST/example/ltn_models.py +++ b/XOR_MNIST/example/ltn_models.py @@ -1,8 +1,9 @@ -import torch -import numpy as np import itertools -import torch.nn.functional as F + import ltn +import numpy as np +import torch +import torch.nn.functional as F class MLPRecon(torch.nn.Module): @@ -10,9 +11,15 @@ def __init__(self, hidden_neurons): super(MLPRecon, self).__init__() self.hidden = hidden_neurons - self.layer1 = torch.nn.Linear(3, self.hidden, dtype=torch.float) - self.layer2 = torch.nn.Linear(self.hidden, self.hidden, dtype=torch.float) - self.layer3 = torch.nn.Linear(self.hidden, 3, dtype=torch.float) + self.layer1 = torch.nn.Linear( + 3, self.hidden, dtype=torch.float + ) + self.layer2 = torch.nn.Linear( + self.hidden, self.hidden, dtype=torch.float + ) + self.layer3 = torch.nn.Linear( + self.hidden, 3, dtype=torch.float + ) def forward(self, x): x = torch.nn.ReLU()(self.layer1(x)) @@ -40,12 +47,22 @@ def __init__(self, hidden_neurons, args): if not args.disent: 
self.hidden = hidden_neurons - self.layer1 = torch.nn.Linear(3, self.hidden, dtype=torch.float) - self.layer2 = torch.nn.Linear(self.hidden, self.hidden, dtype=torch.float) - self.layer3 = torch.nn.Linear(self.hidden, 3, dtype=torch.float) + self.layer1 = torch.nn.Linear( + 3, self.hidden, dtype=torch.float + ) + self.layer2 = torch.nn.Linear( + self.hidden, self.hidden, dtype=torch.float + ) + self.layer3 = torch.nn.Linear( + self.hidden, 3, dtype=torch.float + ) self.mlp = torch.nn.Sequential( - self.layer1, torch.nn.ReLU(), self.layer2, torch.nn.ReLU(), self.layer3 + self.layer1, + torch.nn.ReLU(), + self.layer2, + torch.nn.ReLU(), + self.layer3, ) elif not args.s_w: self.layer1 = torch.nn.Linear(1, 1, dtype=torch.float) @@ -81,12 +98,18 @@ def forward(self, x): pC = torch.zeros((8, 6), device=self.device) for i in range(3): - pC[:, 2 * i] = (1 - logitC[:, i].sigmoid() + 1e-3) / (1 + 2 * 1e-3) - pC[:, 2 * i + 1] = (logitC[:, i].sigmoid() + 1e-3) / (1 + 2 * 1e-3) + pC[:, 2 * i] = (1 - logitC[:, i].sigmoid() + 1e-3) / ( + 1 + 2 * 1e-3 + ) + pC[:, 2 * i + 1] = (logitC[:, i].sigmoid() + 1e-3) / ( + 1 + 2 * 1e-3 + ) # explicit logical rule on argmax pred = ( - pC[:, :2].argmax(dim=1) + pC[:, 2:4].argmax(dim=1) + pC[:, 4:].argmax(dim=1) + pC[:, :2].argmax(dim=1) + + pC[:, 2:4].argmax(dim=1) + + pC[:, 4:].argmax(dim=1) ) % 2 # assert torch.sum(torch.round(pred, decimals=3))/len(pred) == 1, (pred,torch.sum(torch.round(pred, decimals=3))/len(pred)) # pred = torch.einsum('ki,kj,kl, mijl->km', pC[:,:2], pC[:,2:4], pC[:,4:], self.W) diff --git a/XOR_MNIST/example/ltn_train.py b/XOR_MNIST/example/ltn_train.py index c05b77d..27ee68d 100644 --- a/XOR_MNIST/example/ltn_train.py +++ b/XOR_MNIST/example/ltn_train.py @@ -1,13 +1,11 @@ -import torch import numpy as np +import torch import torch.nn.functional as F -from warmup_scheduler import GradualWarmupScheduler - +import wandb from example.ltn_models import LTN, LTN_R -from example.xor_utils import progress_bar from example.nesy_losses import sat_agg_loss, shannon_entropy - -import wandb +from example.xor_utils import progress_bar +from warmup_scheduler import GradualWarmupScheduler def train_LTN(G, Y, args): @@ -16,7 +14,9 @@ def train_LTN(G, Y, args): # define optimizer optim = torch.optim.Adam(model.parameters(), lr=args.lr) - scheduler = torch.optim.lr_scheduler.ExponentialLR(optim, gamma=0.9) + scheduler = torch.optim.lr_scheduler.ExponentialLR( + optim, gamma=0.9 + ) w_scheduler = GradualWarmupScheduler(optim, 1, 10) # default @@ -31,7 +31,9 @@ def train_LTN(G, Y, args): optim.zero_grad() pred, logitC, pC = model(G) # loss = F.cross_entropy(pred, Y) - loss = sat_agg_loss(model, pC[:, 0:2], pC[:, 2:4], pC[:, 4:], Y, grade) + loss = sat_agg_loss( + model, pC[:, 0:2], pC[:, 2:4], pC[:, 4:], Y, grade + ) l0, l1 = loss.item(), 0 @@ -42,7 +44,9 @@ def train_LTN(G, Y, args): l1 = loss.item() - l0 if args.entropy: - loss += args.gamma * (1 - shannon_entropy(torch.sigmoid(logitC))) + loss += args.gamma * ( + 1 - shannon_entropy(torch.sigmoid(logitC)) + ) l2 = loss.item() - l1 - l0 if args.wandb: @@ -71,7 +75,8 @@ def train_LTN(G, Y, args): # early stopping if ( - np.abs(update_loss.mean() - update_loss[j % 10]) / update_loss.mean() + np.abs(update_loss.mean() - update_loss[j % 10]) + / update_loss.mean() < 0.0001 ): break @@ -88,7 +93,9 @@ def train_LTN_R(G, Y, args): # define optimizer optim = torch.optim.Adam(model.parameters(), lr=args.lr) - scheduler = torch.optim.lr_scheduler.ExponentialLR(optim, gamma=0.9) + scheduler = 
torch.optim.lr_scheduler.ExponentialLR( + optim, gamma=0.9 + ) w_scheduler = GradualWarmupScheduler(optim, 1, 10) # default @@ -103,7 +110,9 @@ def train_LTN_R(G, Y, args): optim.zero_grad() pred, recon, logitC, pC = model(G) # loss = F.cross_entropy(pred, Y) - loss = sat_agg_loss(model.encoder, pC[:, 0:2], pC[:, 2:4], pC[:, 4:], Y, grade) + loss = sat_agg_loss( + model.encoder, pC[:, 0:2], pC[:, 2:4], pC[:, 4:], Y, grade + ) l0, l1 = loss.item(), 0 @@ -118,7 +127,9 @@ def train_LTN_R(G, Y, args): l1 = loss.item() - l0 - lrec if args.entropy: - loss += args.gamma * (1 - shannon_entropy(torch.sigmoid(logitC))) + loss += args.gamma * ( + 1 - shannon_entropy(torch.sigmoid(logitC)) + ) l2 = loss.item() - l1 - l0 - lrec if args.wandb: @@ -147,7 +158,8 @@ def train_LTN_R(G, Y, args): # early stopping if ( - np.abs(update_loss.mean() - update_loss[j % 10]) / update_loss.mean() + np.abs(update_loss.mean() - update_loss[j % 10]) + / update_loss.mean() < 0.0001 ): break diff --git a/XOR_MNIST/example/nesy_losses.py b/XOR_MNIST/example/nesy_losses.py index d42e1ab..55343bc 100644 --- a/XOR_MNIST/example/nesy_losses.py +++ b/XOR_MNIST/example/nesy_losses.py @@ -1,7 +1,8 @@ -import torch -import numpy as np import itertools + import ltn +import numpy as np +import torch from example.ltn_models import LTN @@ -33,7 +34,13 @@ def shannon_entropy(pC): for i in range(3): aC = torch.mean(pC, dim=0) - H += -(aC[i] * torch.log(aC[i]) + (1 - aC[i]) * torch.log(1 - aC[i])) / 3 + H += ( + -( + aC[i] * torch.log(aC[i]) + + (1 - aC[i]) * torch.log(1 - aC[i]) + ) + / 3 + ) return H / np.log(2) @@ -49,9 +56,13 @@ def semantic_loss(pC, Y): loss = 0 for i, j, k in itertools.product(range(2), range(2), range(2)): if (i + j + k) % 2 == 0: - loss += pY[:, 0] * pC[:, 0 + i] * pC[:, 2 + j] * pC[:, 4 + k] + loss += ( + pY[:, 0] * pC[:, 0 + i] * pC[:, 2 + j] * pC[:, 4 + k] + ) else: - loss += pY[:, 1] * pC[:, 0 + i] * pC[:, 2 + j] * pC[:, 4 + k] + loss += ( + pY[:, 1] * pC[:, 0 + i] * pC[:, 2 + j] * pC[:, 4 + k] + ) loss += 1e-5 return -loss.log().mean() @@ -70,15 +81,22 @@ def sat_agg_loss(model: LTN, p1, p2, p3, labels, grade): b_3 = ltn.Variable("b_3", torch.tensor(range(2))) And = ltn.Connective(ltn.fuzzy_ops.AndProd()) - Exists = ltn.Quantifier(ltn.fuzzy_ops.AggregPMean(p=2), quantifier="e") - Forall = ltn.Quantifier(ltn.fuzzy_ops.AggregPMeanError(p=2), quantifier="f") + Exists = ltn.Quantifier( + ltn.fuzzy_ops.AggregPMean(p=2), quantifier="e" + ) + Forall = ltn.Quantifier( + ltn.fuzzy_ops.AggregPMeanError(p=2), quantifier="f" + ) SatAgg = ltn.fuzzy_ops.SatAgg() sat_agg = Forall( ltn.diag(bit1, bit2, bit3, y_true), Exists( [b_1, b_2, b_3], - And(And(model.ltn(bit1, b_1), model.ltn(bit2, b_2)), model.ltn(bit3, b_3)), + And( + And(model.ltn(bit1, b_1), model.ltn(bit2, b_2)), + model.ltn(bit3, b_3), + ), cond_vars=[b_1, b_2, b_3, y_true], cond_fn=lambda b_1, b_2, b_3, z: torch.eq( (b_1.value + b_2.value + b_3.value) % 2, z.value diff --git a/XOR_MNIST/example/sl_models.py b/XOR_MNIST/example/sl_models.py index b2330a6..2f4b9d5 100644 --- a/XOR_MNIST/example/sl_models.py +++ b/XOR_MNIST/example/sl_models.py @@ -7,9 +7,15 @@ def __init__(self, hidden_neurons): super(MLPRecon, self).__init__() self.hidden = hidden_neurons - self.layer1 = torch.nn.Linear(3, self.hidden, dtype=torch.float) - self.layer2 = torch.nn.Linear(self.hidden, self.hidden, dtype=torch.float) - self.layer3 = torch.nn.Linear(self.hidden, 3, dtype=torch.float) + self.layer1 = torch.nn.Linear( + 3, self.hidden, dtype=torch.float + ) + self.layer2 = 
torch.nn.Linear( + self.hidden, self.hidden, dtype=torch.float + ) + self.layer3 = torch.nn.Linear( + self.hidden, 3, dtype=torch.float + ) def forward(self, x): x = torch.nn.ReLU()(self.layer1(x)) @@ -28,12 +34,22 @@ def __init__(self, hidden_neurons, args): self.args = args if not args.disent: - self.layer1 = torch.nn.Linear(3, self.hidden, dtype=torch.float) - self.layer2 = torch.nn.Linear(self.hidden, self.hidden, dtype=torch.float) - self.layer3 = torch.nn.Linear(self.hidden, 3, dtype=torch.float) + self.layer1 = torch.nn.Linear( + 3, self.hidden, dtype=torch.float + ) + self.layer2 = torch.nn.Linear( + self.hidden, self.hidden, dtype=torch.float + ) + self.layer3 = torch.nn.Linear( + self.hidden, 3, dtype=torch.float + ) self.mlp = torch.nn.Sequential( - self.layer1, torch.nn.ReLU(), self.layer2, torch.nn.ReLU(), self.layer3 + self.layer1, + torch.nn.ReLU(), + self.layer2, + torch.nn.ReLU(), + self.layer3, ) elif not args.s_w: self.layer1 = torch.nn.Linear(1, 1, dtype=torch.float) @@ -55,7 +71,9 @@ def __init__(self, hidden_neurons, args): self.device = "cpu" - self.W = torch.nn.Parameter(torch.rand(size=(2, 2, 2, 2))).to(self.device) + self.W = torch.nn.Parameter(torch.rand(size=(2, 2, 2, 2))).to( + self.device + ) def forward(self, x): if not self.args.disent: @@ -73,8 +91,12 @@ def forward(self, x): # get probs pC = torch.zeros((8, 6), device=self.device) for i in range(3): - pC[:, 2 * i] = (1 - logitC[:, i].sigmoid() + 1e-5) / (1 + 2 * 1e-5) - pC[:, 2 * i + 1] = (logitC[:, i].sigmoid() + 1e-5) / (1 + 2 * 1e-5) + pC[:, 2 * i] = (1 - logitC[:, i].sigmoid() + 1e-5) / ( + 1 + 2 * 1e-5 + ) + pC[:, 2 * i + 1] = (logitC[:, i].sigmoid() + 1e-5) / ( + 1 + 2 * 1e-5 + ) pred = self.predictor(torch.tanh(logitC)) return pred, logitC, pC @@ -97,7 +119,9 @@ def forward(self, x): C[:, 2 * i] = logitC[:, i] / 2 C[:, 2 * i + 1] = -logitC[:, i] / 2 C[:, :2] = F.gumbel_softmax(C[:, :2], tau=1, hard=True, dim=1) - C[:, 2:4] = F.gumbel_softmax(C[:, 2:4], tau=1, hard=True, dim=1) + C[:, 2:4] = F.gumbel_softmax( + C[:, 2:4], tau=1, hard=True, dim=1 + ) C[:, 4:] = F.gumbel_softmax(C[:, 4:], tau=1, hard=True, dim=1) recon = self.decoder(C[:, [1, 3, 5]]) diff --git a/XOR_MNIST/example/sl_train.py b/XOR_MNIST/example/sl_train.py index d20314b..d03f668 100644 --- a/XOR_MNIST/example/sl_train.py +++ b/XOR_MNIST/example/sl_train.py @@ -1,13 +1,11 @@ -import torch import numpy as np +import torch import torch.nn.functional as F -from warmup_scheduler import GradualWarmupScheduler - +import wandb +from example.nesy_losses import semantic_loss, shannon_entropy from example.sl_models import SL, SL_R from example.xor_utils import progress_bar -from example.nesy_losses import semantic_loss, shannon_entropy - -import wandb +from warmup_scheduler import GradualWarmupScheduler def train_SL(G, Y, args): @@ -16,7 +14,9 @@ def train_SL(G, Y, args): # define optimizer optim = torch.optim.Adam(model.parameters(), lr=args.lr) - scheduler = torch.optim.lr_scheduler.ExponentialLR(optim, gamma=0.9) + scheduler = torch.optim.lr_scheduler.ExponentialLR( + optim, gamma=0.9 + ) w_scheduler = GradualWarmupScheduler(optim, 1, 10) # default @@ -44,7 +44,9 @@ def train_SL(G, Y, args): l1 = loss.item() - l0 - ls if args.entropy: - loss += args.gamma * (1 - shannon_entropy(torch.sigmoid(logitC))) + loss += args.gamma * ( + 1 - shannon_entropy(torch.sigmoid(logitC)) + ) l2 = loss.item() - l1 - l0 - ls if args.wandb: @@ -71,7 +73,8 @@ def train_SL(G, Y, args): # early stopping if ( - np.abs(update_loss.mean() - update_loss[j % 10]) / 
update_loss.mean() + np.abs(update_loss.mean() - update_loss[j % 10]) + / update_loss.mean() < 0.0001 ): break @@ -88,7 +91,9 @@ def train_SL_R(G, Y, args): # define optimizer optim = torch.optim.Adam(model.parameters(), lr=args.lr) - scheduler = torch.optim.lr_scheduler.ExponentialLR(optim, gamma=0.9) + scheduler = torch.optim.lr_scheduler.ExponentialLR( + optim, gamma=0.9 + ) w_scheduler = GradualWarmupScheduler(optim, 1, 10) # default @@ -117,11 +122,20 @@ def train_SL_R(G, Y, args): l1 = loss.item() - l0 - ls - lrec if args.entropy: - loss += args.gamma * (1 - shannon_entropy(torch.sigmoid(logitC))) + loss += args.gamma * ( + 1 - shannon_entropy(torch.sigmoid(logitC)) + ) l2 = loss.item() - l1 - l0 - ls - lrec if args.wandb: - wandb.log({"y-loss": l0, "s-loss": ls, "rec-loss": lrec, "epoch": step}) + wandb.log( + { + "y-loss": l0, + "s-loss": ls, + "rec-loss": lrec, + "epoch": step, + } + ) if args.csup is not None: wandb.log({"c-loss": l1}) if args.entropy: @@ -144,7 +158,8 @@ def train_SL_R(G, Y, args): # early stopping if ( - np.abs(update_loss.mean() - update_loss[j % 10]) / update_loss.mean() + np.abs(update_loss.mean() - update_loss[j % 10]) + / update_loss.mean() < 0.0001 ): break diff --git a/XOR_MNIST/example/xor_main.py b/XOR_MNIST/example/xor_main.py index c8707a1..cb6d9c0 100644 --- a/XOR_MNIST/example/xor_main.py +++ b/XOR_MNIST/example/xor_main.py @@ -1,24 +1,29 @@ -import torch -import numpy as np +import os +import sys from argparse import ArgumentParser -import os, sys +import numpy as np +import torch import wandb - from example.dpl_train import train_DPL, train_DPL_REC -from example.sl_train import train_SL, train_SL_R from example.ltn_train import train_LTN, train_LTN_R -from example.xor_utils import show_cf, set_random_seed +from example.sl_train import train_SL, train_SL_R +from example.xor_utils import set_random_seed, show_cf def prepare_args(): parser = ArgumentParser() parser.add_argument( - "--model", type=str, default="dpl", choices=["dpl", "sl", "ltn"] + "--model", + type=str, + default="dpl", + choices=["dpl", "sl", "ltn"], ) # mitigation strategies - parser.add_argument("--rec", action="store_true", help="Activate RECON") + parser.add_argument( + "--rec", action="store_true", help="Activate RECON" + ) parser.add_argument( "--csup", type=int, @@ -32,25 +37,44 @@ def prepare_args(): help="How many concept are supervised in order.", ) parser.add_argument( - "--disent", action="store_true", default=False, help="Activate disentaglement" + "--disent", + action="store_true", + default=False, + help="Activate disentaglement", ) parser.add_argument( - "--s_w", action="store_true", default=False, help="Activate shared weights" + "--s_w", + action="store_true", + default=False, + help="Activate shared weights", ) # hyperparams - parser.add_argument("--lr", type=float, default=0.05, help="Learning rate.") parser.add_argument( - "--gamma", type=float, default=10, help="Weight of Mitigations." + "--lr", type=float, default=0.05, help="Learning rate." + ) + parser.add_argument( + "--gamma", + type=float, + default=10, + help="Weight of Mitigations.", ) # setup - parser.add_argument("--seed", type=int, default=42, help="Set random seed.") parser.add_argument( - "--wandb", action="store_true", default=False, help="Enable log in wandb" + "--seed", type=int, default=42, help="Set random seed." 
+ ) + parser.add_argument( + "--wandb", + action="store_true", + default=False, + help="Enable log in wandb", ) parser.add_argument( - "--project", type=str, default="XOR", help="Enable log in wandb" + "--project", + type=str, + default="XOR", + help="Enable log in wandb", ) args = parser.parse_args() return args @@ -83,7 +107,10 @@ def xor_run(args, _plot=False): if args.wandb: print("\n---wandb on\n") wandb.init( - project=args.project, entity="yours", name=str(args.model), config=args + project=args.project, + entity="yours", + name=str(args.model), + config=args, ) print( diff --git a/XOR_MNIST/example/xor_metrics.py b/XOR_MNIST/example/xor_metrics.py index b75f93e..24e4df4 100644 --- a/XOR_MNIST/example/xor_metrics.py +++ b/XOR_MNIST/example/xor_metrics.py @@ -1,6 +1,5 @@ import numpy as np import torch - from sklearn.metrics import accuracy_score @@ -14,4 +13,8 @@ def extract_statistics(G, Y, C, pred): ypred = np.around(pred) - return accuracy_score(Y, ypred), accuracy_score(gt, cpred), [gt, Y, cpred, ypred] + return ( + accuracy_score(Y, ypred), + accuracy_score(gt, cpred), + [gt, Y, cpred, ypred], + ) diff --git a/XOR_MNIST/example/xor_utils.py b/XOR_MNIST/example/xor_utils.py index 2724d45..41b5c03 100644 --- a/XOR_MNIST/example/xor_utils.py +++ b/XOR_MNIST/example/xor_utils.py @@ -1,19 +1,19 @@ import itertools -import torch -import numpy as np -import matplotlib.pyplot as plt -import sys import random -from typing import Union +import sys from datetime import datetime -from sklearn.metrics import accuracy_score, confusion_matrix +from typing import Union + +import matplotlib.pyplot as plt +import numpy as np +import torch import torch.nn.functional as F import wandb - -from example.xor_metrics import extract_statistics from example.dpl_models import DPL, DPL_R -from example.sl_models import SL, SL_R from example.ltn_models import LTN, LTN_R +from example.sl_models import SL, SL_R +from example.xor_metrics import extract_statistics +from sklearn.metrics import accuracy_score, confusion_matrix def set_random_seed(seed: int) -> None: @@ -27,7 +27,9 @@ def set_random_seed(seed: int) -> None: torch.cuda.manual_seed_all(seed) -def progress_bar(i: int, max_iter: int, epoch: Union[int, str], loss: float) -> None: +def progress_bar( + i: int, max_iter: int, epoch: Union[int, str], loss: float +) -> None: """ Prints out the progress bar on the stderr file. 
:param i: the current iteration @@ -38,7 +40,9 @@ def progress_bar(i: int, max_iter: int, epoch: Union[int, str], loss: float) -> """ # if not (i + 1) % 10 or (i + 1) == max_iter: progress = min(float((i + 1) / max_iter), 1) - progress_bar = ("█" * int(50 * progress)) + ("┈" * (50 - int(50 * progress))) + progress_bar = ("█" * int(50 * progress)) + ( + "┈" * (50 - int(50 * progress)) + ) print( "\r[ {} ] epoch {}: |{}| loss: {}".format( datetime.now().strftime("%m-%d | %H:%M"), @@ -60,10 +64,18 @@ def show_cf(model, G, Y, args, _plot=False): elif isinstance(model, DPL_R): preds, recon, cs = model(G) print("Recon Loss", F.binary_cross_entropy(recon, G).item()) - for i, j, k in itertools.product(range(2), range(2), range(2)): - C = torch.tensor([i, j, k], dtype=torch.float, device=model.device) + for i, j, k in itertools.product( + range(2), range(2), range(2) + ): + C = torch.tensor( + [i, j, k], dtype=torch.float, device=model.device + ) logitC = model.encoder( - torch.tensor([[i, j, k]], dtype=torch.float, device=model.device) + torch.tensor( + [[i, j, k]], + dtype=torch.float, + device=model.device, + ) )[1].detach() pC = logitC.sigmoid() @@ -73,10 +85,18 @@ def show_cf(model, G, Y, args, _plot=False): elif isinstance(model, SL_R): preds, recon, cs, pC = model(G) print("Recon Loss", F.binary_cross_entropy(recon, G).item()) - for i, j, k in itertools.product(range(2), range(2), range(2)): - C = torch.tensor([i, j, k], dtype=torch.float, device=model.device) + for i, j, k in itertools.product( + range(2), range(2), range(2) + ): + C = torch.tensor( + [i, j, k], dtype=torch.float, device=model.device + ) logitC = model.encoder( - torch.tensor([[i, j, k]], dtype=torch.float, device=model.device) + torch.tensor( + [[i, j, k]], + dtype=torch.float, + device=model.device, + ) )[1].detach() pC = logitC.sigmoid() @@ -86,10 +106,18 @@ def show_cf(model, G, Y, args, _plot=False): elif isinstance(model, LTN_R): preds, recon, cs, pC = model(G) print("Recon Loss", F.binary_cross_entropy(recon, G).item()) - for i, j, k in itertools.product(range(2), range(2), range(2)): - C = torch.tensor([i, j, k], dtype=torch.float, device=model.device) + for i, j, k in itertools.product( + range(2), range(2), range(2) + ): + C = torch.tensor( + [i, j, k], dtype=torch.float, device=model.device + ) logitC = model.encoder( - torch.tensor([[i, j, k]], dtype=torch.float, device=model.device) + torch.tensor( + [[i, j, k]], + dtype=torch.float, + device=model.device, + ) )[1].detach() pC = logitC.sigmoid() @@ -118,14 +146,20 @@ def show_cf(model, G, Y, args, _plot=False): wandb.log( { "cf-labels": wandb.plot.confusion_matrix( - None, Ys, ypred, class_names=[str(i) for i in range(2)] + None, + Ys, + ypred, + class_names=[str(i) for i in range(2)], ) } ) wandb.log( { "cf-concepts": wandb.plot.confusion_matrix( - None, Gs, cpred, class_names=[str(i) for i in range(8)] + None, + Gs, + cpred, + class_names=[str(i) for i in range(8)], ) } ) diff --git a/XOR_MNIST/experiments.py b/XOR_MNIST/experiments.py index b600a6d..b5428d0 100644 --- a/XOR_MNIST/experiments.py +++ b/XOR_MNIST/experiments.py @@ -1,7 +1,8 @@ # Experiments module # Contains args for different experiments -import itertools import copy +import itertools + from exp_best_args import * @@ -112,7 +113,9 @@ def launch_XOR(args): args_list = [] for element in itertools.product(*hyperparameters): args1 = copy.copy(args) - args1.model, args1.rec, args1.entropy, args.gamma, args.lr = element + args1.model, args1.rec, args1.entropy, args.gamma, args.lr = ( + element + ) 
print(args1, "\n") args_list.append(args1) return args_list @@ -142,7 +145,9 @@ def launch_XOR_exp1(args): args_list = [] for element in itertools.product(*hyperparameters): args1 = copy.copy(args) - args1.model, args1.disent, args1.s_w, args1.seed, args1.lr = element + args1.model, args1.disent, args1.s_w, args1.seed, args1.lr = ( + element + ) if args1.disent or not args1.s_w: args1 = set_best_args_XOR(args1) diff --git a/XOR_MNIST/main.py b/XOR_MNIST/main.py index cf5345d..ef92065 100644 --- a/XOR_MNIST/main.py +++ b/XOR_MNIST/main.py @@ -1,22 +1,25 @@ # This is the main module # It provides an overview of the program purpose and functionality. -import sys, os -import torch import argparse -import importlib -import setproctitle, socket, uuid import datetime +import importlib +import os +import signal +import socket +import sys +import uuid +import setproctitle +import torch from datasets import get_dataset from models import get_model -from utils.train import train, train_active -from utils.test import test -from utils.preprocess_resnet import preprocess -from utils.conf import * -import signal from utils.args import * -from utils.checkpoint import save_model, create_load_ckpt +from utils.checkpoint import create_load_ckpt, save_model +from utils.conf import * +from utils.preprocess_resnet import preprocess +from utils.test import test +from utils.train import train, train_active conf_path = os.getcwd() + "." sys.path.append(conf_path) @@ -34,7 +37,9 @@ def __init__(self): Returns: None: This function does not return a value. """ - super().__init__("External signal received: forcing termination") + super().__init__( + "External signal received: forcing termination" + ) def __handle_signal(signum: int, frame): @@ -83,7 +88,8 @@ def parse_args(): parser.add_argument( "--load_best_args", action="store_true", - help="Loads the best arguments for each method, " "dataset and memory buffer.", + help="Loads the best arguments for each method, " + "dataset and memory buffer.", ) torch.set_num_threads(4) @@ -96,13 +102,20 @@ def parse_args(): get_parser = getattr(mod, "get_parser") parser = get_parser() parser.add_argument( - "--project", type=str, default="Reasoning-Shortcuts", help="wandb project" + "--project", + type=str, + default="Reasoning-Shortcuts", + help="wandb project", ) add_test_args(parser) args = parser.parse_args() # this is the return # load args related to seed etc. - set_random_seed(args.seed) if args.seed is not None else set_random_seed(42) + ( + set_random_seed(args.seed) + if args.seed is not None + else set_random_seed(42) + ) return args @@ -136,7 +149,9 @@ def main(args): # set job name setproctitle.setproctitle( "{}_{}_{}".format( - args.model, args.buffer_size if "buffer_size" in args else 0, args.dataset + args.model, + args.buffer_size if "buffer_size" in args else 0, + args.dataset, ) ) @@ -149,7 +164,9 @@ def main(args): quit() if args.posthoc: - test(model, dataset, args) # test the model if post-hoc is passed + test( + model, dataset, args + ) # test the model if post-hoc is passed elif args.active_learning: train_active( model, dataset, loss, args diff --git a/XOR_MNIST/models/__init__.py b/XOR_MNIST/models/__init__.py index f21825f..f141343 100644 --- a/XOR_MNIST/models/__init__.py +++ b/XOR_MNIST/models/__init__.py @@ -1,5 +1,5 @@ -import os import importlib +import os def get_all_models(): @@ -13,13 +13,17 @@ def get_all_models(): names = {} for model in get_all_models(): mod = importlib.import_module("models." 
+ model) - class_name = {x.lower(): x for x in mod.__dir__()}[model.replace("_", "")] + class_name = {x.lower(): x for x in mod.__dir__()}[ + model.replace("_", "") + ] names[model] = getattr(mod, class_name) def get_model(args, encoder, decoder, n_images, c_split): if args.model == "cext": - return names[args.model](encoder, n_images=n_images, c_split=c_split) + return names[args.model]( + encoder, n_images=n_images, c_split=c_split + ) elif args.model in [ "mnistdpl", "mnistsl", @@ -36,5 +40,9 @@ def get_model(args, encoder, decoder, n_images, c_split): ) # only discriminative else: return names[args.model]( - encoder, decoder, n_images=n_images, c_split=c_split, args=args + encoder, + decoder, + n_images=n_images, + c_split=c_split, + args=args, ) diff --git a/XOR_MNIST/models/cext.py b/XOR_MNIST/models/cext.py index 3671c2e..1c673b8 100644 --- a/XOR_MNIST/models/cext.py +++ b/XOR_MNIST/models/cext.py @@ -1,12 +1,14 @@ import torch from models.utils.deepproblog_modules import DeepProblogModel -from utils.losses import ADDMNIST_Concept_Match, KAND_Concept_Match from utils.args import * from utils.conf import get_device +from utils.losses import ADDMNIST_Concept_Match, KAND_Concept_Match def get_parser() -> ArgumentParser: - parser = ArgumentParser(description="Learning via" "Concept Extractor .") + parser = ArgumentParser( + description="Learning via" "Concept Extractor ." + ) add_management_args(parser) add_experiment_args(parser) return parser @@ -15,7 +17,9 @@ def get_parser() -> ArgumentParser: class CExt(torch.nn.Module): NAME = "cext" - def __init__(self, encoder, n_images=1, c_split=()): # c_dim=20, latent_dim=0): + def __init__( + self, encoder, n_images=1, c_split=() + ): # c_dim=20, latent_dim=0): super(CExt, self).__init__() # bones of the model diff --git a/XOR_MNIST/models/cvae.py b/XOR_MNIST/models/cvae.py index 30bd5bb..2f9c1bd 100644 --- a/XOR_MNIST/models/cvae.py +++ b/XOR_MNIST/models/cvae.py @@ -1,13 +1,15 @@ import torch import torch.nn.functional as F from models.utils.deepproblog_modules import DeepProblogModel -from utils.losses import ADDMNIST_rec_class from utils.args import * from utils.conf import get_device +from utils.losses import ADDMNIST_rec_class def get_parser() -> ArgumentParser: - parser = ArgumentParser(description="Variational-Conceptual Autoencoders.") + parser = ArgumentParser( + description="Variational-Conceptual Autoencoders." 
+ ) add_management_args(parser) add_experiment_args(parser) return parser @@ -16,7 +18,9 @@ def get_parser() -> ArgumentParser: class CVAE(torch.nn.Module): NAME = "cvae" - def __init__(self, encoder=None, decoder=None, n_images=1, c_split=()): + def __init__( + self, encoder=None, decoder=None, n_images=1, c_split=() + ): super(CVAE, self).__init__() # bones of model @@ -52,7 +56,10 @@ def forward(self, x): for cdim in self.c_split: hard_cs.append( F.gumbel_softmax( - c[:, index : index + cdim], tau=1, hard=True, dim=-1 + c[:, index : index + cdim], + tau=1, + hard=True, + dim=-1, ) ) index += cdim @@ -67,7 +74,12 @@ def forward(self, x): mus = torch.cat(mus, dim=-1) logvars = torch.cat(logvars, dim=-1) recs = torch.cat(recs, dim=-1) - return {"RECS": recs, "CS": cs, "MUS": mus, "LOGVARS": logvars} + return { + "RECS": recs, + "CS": cs, + "MUS": mus, + "LOGVARS": logvars, + } @staticmethod def get_loss(args): diff --git a/XOR_MNIST/models/kanddpl.py b/XOR_MNIST/models/kanddpl.py index 492b6cb..9fbdb0e 100644 --- a/XOR_MNIST/models/kanddpl.py +++ b/XOR_MNIST/models/kanddpl.py @@ -1,12 +1,12 @@ # Kandinksy DPL import torch from models.utils.deepproblog_modules import DeepProblogModel +from models.utils.ops import outer_product +from models.utils.utils_problog import * from utils.args import * from utils.conf import get_device -from models.utils.utils_problog import * -from utils.losses import * from utils.dpl_loss import KAND_DPL -from models.utils.ops import outer_product +from utils.losses import * def get_parser() -> ArgumentParser: @@ -15,7 +15,9 @@ def get_parser() -> ArgumentParser: Returns: argparse: argument parser """ - parser = ArgumentParser(description="Learning via" "Concept Extractor .") + parser = ArgumentParser( + description="Learning via" "Concept Extractor ." 
+ ) add_management_args(parser) add_experiment_args(parser) return parser @@ -117,12 +119,24 @@ def forward(self, x, activate_simple_concepts=False): clen = len(cs[0].shape) - cs = torch.stack(cs, dim=1) if clen > 1 else torch.cat(cs, dim=1) - pCs = torch.stack(pCs, dim=1) if clen > 1 else torch.cat(pCs, dim=1) + cs = ( + torch.stack(cs, dim=1) + if clen > 1 + else torch.cat(cs, dim=1) + ) + pCs = ( + torch.stack(pCs, dim=1) + if clen > 1 + else torch.cat(pCs, dim=1) + ) py = self.combine_queries(preds) - preds = torch.stack(preds, dim=1) if clen > 1 else torch.cat(preds, dim=1) + preds = ( + torch.stack(preds, dim=1) + if clen > 1 + else torch.cat(preds, dim=1) + ) # Problog inference to compute worlds and query probability distributions # py, worlds_prob = self.problog_inference(pCs) @@ -152,11 +166,15 @@ def problog_inference(self, pCs, query=None): worlds_prob = worlds_tensor.reshape(-1, 3**self.n_facts) # Compute query probability P(q) - query_prob = torch.zeros(size=(len(pCs), self.n_predicates), device=pCs.device) + query_prob = torch.zeros( + size=(len(pCs), self.n_predicates), device=pCs.device + ) for i in range(self.n_predicates): query = i - query_prob[:, i] = self.compute_query(query, worlds_prob).view(-1) + query_prob[:, i] = self.compute_query( + query, worlds_prob + ).view(-1) # add a small offset # query_prob += 1e-5 @@ -178,11 +196,16 @@ def combine_queries(self, preds): """ y_worlds = outer_product(*preds).reshape(-1, 9**self.n_images) - py = torch.zeros(size=(len(preds[0]), self.nr_classes), device=preds[0].device) + py = torch.zeros( + size=(len(preds[0]), self.nr_classes), + device=preds[0].device, + ) for i in range(self.nr_classes): and_rule = self.and_rule[:, i] - query_prob = torch.sum(and_rule * y_worlds, dim=1, keepdim=True) + query_prob = torch.sum( + and_rule * y_worlds, dim=1, keepdim=True + ) py[:, i] = query_prob.view(-1) @@ -230,7 +253,9 @@ def soft_clamp(h, dim=-1): return h # TODO: the 3 here is hardcoded, relax to arbitrary concept encodings? - pCi = torch.split(z, 3, dim=-1) # [batch_size, 24] -> [8, batch_size, 3] + pCi = torch.split( + z, 3, dim=-1 + ) # [batch_size, 24] -> [8, batch_size, 3] norm_concepts = torch.cat( [soft_clamp(c) for c in pCi], dim=-1 @@ -251,7 +276,11 @@ def get_loss(args): Raises: err: NotImplementedError """ - if args.dataset in ["kandinsky", "prekandinsky", "minikandinsky"]: + if args.dataset in [ + "kandinsky", + "prekandinsky", + "minikandinsky", + ]: return KAND_DPL(KAND_Cumulative) else: return NotImplementedError("Wrong dataset choice") diff --git a/XOR_MNIST/models/kandpreprocess.py b/XOR_MNIST/models/kandpreprocess.py index d43578a..e88f45f 100644 --- a/XOR_MNIST/models/kandpreprocess.py +++ b/XOR_MNIST/models/kandpreprocess.py @@ -1,8 +1,8 @@ import torch from models.utils.deepproblog_modules import DeepProblogModel +from models.utils.utils_problog import * from utils.args import * from utils.conf import get_device -from models.utils.utils_problog import * from utils.losses import * @@ -12,7 +12,9 @@ def get_parser() -> ArgumentParser: Returns: argparse: argument parser """ - parser = ArgumentParser(description="Learning via" "Concept Extractor .") + parser = ArgumentParser( + description="Learning via" "Concept Extractor ." 
+ ) add_management_args(parser) add_experiment_args(parser) return parser @@ -89,7 +91,11 @@ def forward(self, x): clen = len(cs[0].shape) - embs = torch.stack(cs, dim=1) if clen > 1 else torch.cat(cs, dim=1) + embs = ( + torch.stack(cs, dim=1) + if clen > 1 + else torch.cat(cs, dim=1) + ) return {"EMBS": embs} diff --git a/XOR_MNIST/models/minikanddpl.py b/XOR_MNIST/models/minikanddpl.py index 9f4de7d..36333fb 100644 --- a/XOR_MNIST/models/minikanddpl.py +++ b/XOR_MNIST/models/minikanddpl.py @@ -1,12 +1,12 @@ # MINI KANDINKSY for DPL import torch from models.utils.deepproblog_modules import DeepProblogModel +from models.utils.ops import outer_product +from models.utils.utils_problog import * from utils.args import * from utils.conf import get_device -from models.utils.utils_problog import * -from utils.losses import * from utils.dpl_loss import KAND_DPL -from models.utils.ops import outer_product +from utils.losses import * def get_parser() -> ArgumentParser: @@ -15,7 +15,9 @@ def get_parser() -> ArgumentParser: Returns: argpars: argument parser """ - parser = ArgumentParser(description="Learning via" "Concept Extractor .") + parser = ArgumentParser( + description="Learning via" "Concept Extractor ." + ) add_management_args(parser) add_experiment_args(parser) return parser @@ -68,8 +70,10 @@ def __init__( # Worlds-queries matrix # if args.task == 'base': self.n_facts = 6 - self.w_q, self.and_rule, self.or_rule = build_worlds_queries_matrix_KAND( - self.n_images, self.n_facts, 3, task=args.task + self.w_q, self.and_rule, self.or_rule = ( + build_worlds_queries_matrix_KAND( + self.n_images, self.n_facts, 3, task=args.task + ) ) self.n_predicates = 3 self.nr_classes = 2 @@ -114,24 +118,46 @@ def forward(self, x, activate_simple_concepts=False): shapes_prob, colors_prob = self.problog_inference(pc) - cs.append(lc), pCs.append(pc), spreds.append(shapes_prob), cpreds.append( - colors_prob - ) + cs.append(lc), pCs.append(pc), spreds.append( + shapes_prob + ), cpreds.append(colors_prob) clen = len(cs[0].shape) - cs = torch.stack(cs, dim=1) if clen > 1 else torch.cat(cs, dim=1) - pCs = torch.stack(pCs, dim=1) if clen > 1 else torch.cat(pCs, dim=1) + cs = ( + torch.stack(cs, dim=1) + if clen > 1 + else torch.cat(cs, dim=1) + ) + pCs = ( + torch.stack(pCs, dim=1) + if clen > 1 + else torch.cat(pCs, dim=1) + ) py = self.combine_queries(spreds, cpreds) - spreds = torch.stack(spreds, dim=1) if clen > 1 else torch.cat(spreds, dim=1) - cpreds = torch.stack(cpreds, dim=1) if clen > 1 else torch.cat(cpreds, dim=1) + spreds = ( + torch.stack(spreds, dim=1) + if clen > 1 + else torch.cat(spreds, dim=1) + ) + cpreds = ( + torch.stack(cpreds, dim=1) + if clen > 1 + else torch.cat(cpreds, dim=1) + ) # Problog inference to compute worlds and query probability distributions # py, worlds_prob = self.problog_inference(pCs) - return {"CS": cs, "YS": py, "pCS": pCs, "sPREDS": spreds, "cPREDS": cpreds} + return { + "CS": cs, + "YS": py, + "pCS": pCs, + "sPREDS": spreds, + "cPREDS": cpreds, + } def problog_inference(self, pCs, query=None): """Problog inference @@ -157,8 +183,12 @@ def problog_inference(self, pCs, query=None): *all_c_s[3:] # [batch_size, 1, 3*8] -> [6, batch_size, 3] ) # [6, batch_size, 3] -> [batch_size, 3,3,3, 3,3,3] - shapes_prob = shapes_worlds_tensor.reshape(-1, 3 ** (self.n_facts // 2)) - colors_prob = colors_worlds_tensor.reshape(-1, 3 ** (self.n_facts // 2)) + shapes_prob = shapes_worlds_tensor.reshape( + -1, 3 ** (self.n_facts // 2) + ) + colors_prob = colors_worlds_tensor.reshape( + -1, 3 ** 
(self.n_facts // 2) + ) # Compute query probability shapes_query_prob = torch.zeros( @@ -170,8 +200,12 @@ def problog_inference(self, pCs, query=None): for i in range(self.n_predicates): query = i - shapes_query_prob[:, i] = self.compute_query(query, shapes_prob).view(-1) - colors_query_prob[:, i] = self.compute_query(query, colors_prob).view(-1) + shapes_query_prob[:, i] = self.compute_query( + query, shapes_prob + ).view(-1) + colors_query_prob[:, i] = self.compute_query( + query, colors_prob + ).view(-1) # shapes_check = torch.zeros(size=(len(pCs), self.nr_classes), device=pCs.device) # colors_check = torch.zeros(size=(len(pCs), self.nr_classes), device=pCs.device) @@ -195,20 +229,30 @@ def combine_queries(self, spreds, cpreds): Returns: py: pattern prediction """ - s_worlds = outer_product(*spreds).reshape(-1, 3**self.n_images) - c_worlds = outer_product(*cpreds).reshape(-1, 3**self.n_images) + s_worlds = outer_product(*spreds).reshape( + -1, 3**self.n_images + ) + c_worlds = outer_product(*cpreds).reshape( + -1, 3**self.n_images + ) ps = torch.zeros( - size=(len(spreds[0]), self.nr_classes), device=spreds[0].device + size=(len(spreds[0]), self.nr_classes), + device=spreds[0].device, ) pc = torch.zeros( - size=(len(cpreds[0]), self.nr_classes), device=cpreds[0].device + size=(len(cpreds[0]), self.nr_classes), + device=cpreds[0].device, ) for i in range(self.nr_classes): and_rule = self.and_rule[:, i] - prob_s = torch.sum(and_rule * s_worlds, dim=1, keepdim=True) - prob_c = torch.sum(and_rule * c_worlds, dim=1, keepdim=True) + prob_s = torch.sum( + and_rule * s_worlds, dim=1, keepdim=True + ) + prob_c = torch.sum( + and_rule * c_worlds, dim=1, keepdim=True + ) ps[:, i] = prob_s.view(-1) pc[:, i] = prob_c.view(-1) @@ -216,12 +260,15 @@ def combine_queries(self, spreds, cpreds): total_prob = outer_product(ps, pc).reshape(-1, 4) py = torch.zeros( - size=(len(spreds[0]), self.nr_classes), device=spreds[0].device + size=(len(spreds[0]), self.nr_classes), + device=spreds[0].device, ) for i in range(self.nr_classes): or_rule = self.or_rule[:, i] - query_prob = torch.sum(or_rule * total_prob, dim=1, keepdim=True) + query_prob = torch.sum( + or_rule * total_prob, dim=1, keepdim=True + ) py[:, i] = query_prob.view(-1) @@ -269,7 +316,9 @@ def soft_clamp(h, dim=-1): return h # TODO: the 3 here is hardcoded, relax to arbitrary concept encodings? 
- pCi = torch.split(z, 3, dim=-1) # [batch_size, 24] -> [8, batch_size, 3] + pCi = torch.split( + z, 3, dim=-1 + ) # [batch_size, 24] -> [8, batch_size, 3] norm_concepts = torch.cat( [soft_clamp(c) for c in pCi], dim=-1 @@ -290,7 +339,11 @@ def get_loss(args): Raises: err: NotImplementedError if loss is not specified """ - if args.dataset in ["kandinsky", "prekandinsky", "minikandinsky"]: + if args.dataset in [ + "kandinsky", + "prekandinsky", + "minikandinsky", + ]: return KAND_DPL(KAND_Cumulative) else: return NotImplementedError("Wrong dataset choice") diff --git a/XOR_MNIST/models/mnistdpl.py b/XOR_MNIST/models/mnistdpl.py index a64c51a..790492c 100644 --- a/XOR_MNIST/models/mnistdpl.py +++ b/XOR_MNIST/models/mnistdpl.py @@ -1,12 +1,14 @@ # DPL model for MNIST import torch -from models.utils.deepproblog_modules import DeepProblogModel +from models.utils.deepproblog_modules import ( + DeepProblogModel, + GraphSemiring, +) +from models.utils.utils_problog import * from utils.args import * from utils.conf import get_device -from models.utils.deepproblog_modules import GraphSemiring -from models.utils.utils_problog import * -from utils.losses import * from utils.dpl_loss import ADDMNIST_DPL +from utils.losses import * def get_parser() -> ArgumentParser: @@ -15,7 +17,9 @@ def get_parser() -> ArgumentParser: Returns: argparse: argument parser """ - parser = ArgumentParser(description="Learning via" "Concept Extractor .") + parser = ArgumentParser( + description="Learning via" "Concept Extractor ." + ) add_management_args(parser) add_experiment_args(parser) return parser @@ -68,19 +72,31 @@ def __init__( # Worlds-queries matrix if args.task == "addition": self.n_facts = ( - 10 if not args.dataset in ["halfmnist", "restrictedmnist"] else 5 + 10 + if not args.dataset + in ["halfmnist", "restrictedmnist"] + else 5 + ) + self.w_q = build_worlds_queries_matrix( + 2, self.n_facts, "addmnist" ) - self.w_q = build_worlds_queries_matrix(2, self.n_facts, "addmnist") self.nr_classes = 19 elif args.task == "product": self.n_facts = ( - 10 if not args.dataset in ["halfmnist", "restrictedmnist"] else 5 + 10 + if not args.dataset + in ["halfmnist", "restrictedmnist"] + else 5 + ) + self.w_q = build_worlds_queries_matrix( + 2, self.n_facts, "productmnist" ) - self.w_q = build_worlds_queries_matrix(2, self.n_facts, "productmnist") self.nr_classes = 37 elif args.task == "multiop": self.n_facts = 5 - self.w_q = build_worlds_queries_matrix(2, self.n_facts, "multiopmnist") + self.w_q = build_worlds_queries_matrix( + 2, self.n_facts, "multiopmnist" + ) self.nr_classes = 3 # opt and device @@ -106,7 +122,11 @@ def forward(self, x): cs.append(lc) clen = len(cs[0].shape) - cs = torch.stack(cs, dim=1) if clen == 2 else torch.cat(cs, dim=1) + cs = ( + torch.stack(cs, dim=1) + if clen == 2 + else torch.cat(cs, dim=1) + ) # normalize concept preditions pCs = self.normalize_concepts(cs) @@ -152,7 +172,9 @@ def problog_inference(self, pCs, query=None): for i in range(self.nr_classes): query = i - query_prob[:, i] = self.compute_query(query, worlds_prob).view(-1) + query_prob[:, i] = self.compute_query( + query, worlds_prob + ).view(-1) # add a small offset query_prob += 1e-5 @@ -213,7 +235,9 @@ def normalize_concepts(self, z, split=2): Z2 = torch.sum(prob_digit2, dim=-1, keepdim=True) prob_digit2 = prob_digit2 / Z2 # Normalization - return torch.stack([prob_digit1, prob_digit2], dim=1).view(-1, 2, self.n_facts) + return torch.stack([prob_digit1, prob_digit2], dim=1).view( + -1, 2, self.n_facts + ) @staticmethod def 
get_loss(args): @@ -228,7 +252,12 @@ def get_loss(args): Raises: err: NotImplementedError if the loss function is not available """ - if args.dataset in ["addmnist", "shortmnist", "restrictedmnist", "halfmnist"]: + if args.dataset in [ + "addmnist", + "shortmnist", + "restrictedmnist", + "halfmnist", + ]: return ADDMNIST_DPL(ADDMNIST_Cumulative) else: return NotImplementedError("Wrong dataset choice") diff --git a/XOR_MNIST/models/mnistdplrec.py b/XOR_MNIST/models/mnistdplrec.py index 15b149a..759d952 100644 --- a/XOR_MNIST/models/mnistdplrec.py +++ b/XOR_MNIST/models/mnistdplrec.py @@ -1,11 +1,13 @@ # MNIST DPL with Reconstruction import torch -from models.utils.deepproblog_modules import DeepProblogModel -from utils.dpl_loss import ADDMNIST_DPL +from models.utils.deepproblog_modules import ( + DeepProblogModel, + GraphSemiring, +) +from models.utils.utils_problog import * from utils.args import * from utils.conf import get_device -from models.utils.deepproblog_modules import GraphSemiring -from models.utils.utils_problog import * +from utils.dpl_loss import ADDMNIST_DPL from utils.losses import * @@ -15,7 +17,9 @@ def get_parser() -> ArgumentParser: Returns: argpars: argument parser """ - parser = ArgumentParser(description="Learning via" "Concept Extractor .") + parser = ArgumentParser( + description="Learning via" "Concept Extractor ." + ) add_management_args(parser) add_experiment_args(parser) return parser @@ -107,7 +111,11 @@ def forward(self, x): latents.append((mu + eps * logvar.exp()).view(L, -1)) for i in range(len(self.c_split)): - latents.append(F.gumbel_softmax(c[:, i, :], tau=1, hard=True, dim=-1)) + latents.append( + F.gumbel_softmax( + c[:, i, :], tau=1, hard=True, dim=-1 + ) + ) latents = torch.cat(latents, dim=1) @@ -116,10 +124,20 @@ def forward(self, x): # return everything clen = len(cs[0].shape) - cs = torch.stack(cs, dim=1) if clen == 2 else torch.cat(cs, dim=1) - mus = torch.stack(mus, dim=-1) if clen == 2 else torch.cat(mus, dim=1) + cs = ( + torch.stack(cs, dim=1) + if clen == 2 + else torch.cat(cs, dim=1) + ) + mus = ( + torch.stack(mus, dim=-1) + if clen == 2 + else torch.cat(mus, dim=1) + ) logvars = ( - torch.stack(logvars, dim=-1) if clen == 2 else torch.cat(logvars, dim=1) + torch.stack(logvars, dim=-1) + if clen == 2 + else torch.cat(logvars, dim=1) ) # normalize concept preditions @@ -164,7 +182,9 @@ def problog_inference(self, pCs, query=None): # print(probs.size()) - worlds_prob = probs.reshape(-1, self.c_split[0] * self.c_split[0]) + worlds_prob = probs.reshape( + -1, self.c_split[0] * self.c_split[0] + ) # Compute query probability P(q) query_prob = torch.zeros( @@ -173,7 +193,9 @@ def problog_inference(self, pCs, query=None): for i in range(self.nr_classes): query = i - query_prob[:, i] = self.compute_query(query, worlds_prob).view(-1) + query_prob[:, i] = self.compute_query( + query, worlds_prob + ).view(-1) # add a small offset query_prob += 1e-5 @@ -234,7 +256,9 @@ def normalize_concepts(self, z, split=2): Z2 = torch.sum(prob_digit2, dim=-1, keepdim=True) prob_digit2 = prob_digit2 / Z2 # Normalization - return torch.stack([prob_digit1, prob_digit2], dim=1).view(-1, 2, 10) + return torch.stack([prob_digit1, prob_digit2], dim=1).view( + -1, 2, 10 + ) @staticmethod def get_loss(args): diff --git a/XOR_MNIST/models/mnistltn.py b/XOR_MNIST/models/mnistltn.py index 2fecd8f..68de66a 100644 --- a/XOR_MNIST/models/mnistltn.py +++ b/XOR_MNIST/models/mnistltn.py @@ -1,10 +1,10 @@ # LTN architecture for MNIST import torch +from models.cext import CExt +from 
models.utils.utils_problog import build_worlds_queries_matrix from utils.args import * from utils.conf import get_device from utils.losses import * -from models.cext import CExt -from models.utils.utils_problog import build_worlds_queries_matrix from utils.ltn_loss import ADDMNIST_SAT_AGG @@ -14,7 +14,9 @@ def get_parser() -> ArgumentParser: Returns: argpars: argument parser """ - parser = ArgumentParser(description="Learning via" "Concept Extractor .") + parser = ArgumentParser( + description="Learning via" "Concept Extractor ." + ) add_management_args(parser) add_experiment_args(parser) return parser @@ -49,17 +51,26 @@ def __init__(self, encoder, n_images=2, c_split=(), args=None): if args.task == "addition": self.n_facts = ( - 10 if not args.dataset in ["halfmnist", "restrictedmnist"] else 5 + 10 + if not args.dataset + in ["halfmnist", "restrictedmnist"] + else 5 ) self.nr_classes = 19 elif args.task == "product": self.n_facts = ( - 10 if not args.dataset in ["halfmnist", "restrictedmnist"] else 5 + 10 + if not args.dataset + in ["halfmnist", "restrictedmnist"] + else 5 ) self.nr_classes = 37 elif args.task == "multiop": self.n_facts = ( - 10 if not args.dataset in ["halfmnist", "restrictedmnist"] else 5 + 10 + if not args.dataset + in ["halfmnist", "restrictedmnist"] + else 5 ) self.nr_classes = 3 @@ -85,7 +96,11 @@ def forward(self, x): lc, _, _ = self.encoder(xs[i]) # sizes are ok cs.append(lc) clen = len(cs[0].shape) - cs = torch.stack(cs, dim=1) if clen == 2 else torch.cat(cs, dim=1) + cs = ( + torch.stack(cs, dim=1) + if clen == 2 + else torch.cat(cs, dim=1) + ) # normalize concept preditions pCs = self.normalize_concepts(cs) @@ -143,7 +158,9 @@ def normalize_concepts(self, z, split=2): Z2 = torch.sum(prob_digit2, dim=-1, keepdim=True) prob_digit2 = prob_digit2 / Z2 # Normalization - return torch.stack([prob_digit1, prob_digit2], dim=1).view(-1, 2, self.n_facts) + return torch.stack([prob_digit1, prob_digit2], dim=1).view( + -1, 2, self.n_facts + ) def get_loss(self, args): """Returns the loss function for the architecture @@ -158,7 +175,12 @@ def get_loss(self, args): Raises: err: NotImplementedError if loss is not implemented """ - if args.dataset in ["addmnist", "shortmnist", "restrictedmnist", "halfmnist"]: + if args.dataset in [ + "addmnist", + "shortmnist", + "restrictedmnist", + "halfmnist", + ]: return ADDMNIST_SAT_AGG(ADDMNIST_Cumulative, self.task) else: return NotImplementedError("Wrong dataset choice") diff --git a/XOR_MNIST/models/mnistltnrec.py b/XOR_MNIST/models/mnistltnrec.py index 7b61ed8..8232233 100644 --- a/XOR_MNIST/models/mnistltnrec.py +++ b/XOR_MNIST/models/mnistltnrec.py @@ -1,10 +1,10 @@ # Mnist LTN with Reconstruction module import torch +from models.cext import CExt +from models.utils.utils_problog import build_worlds_queries_matrix from utils.args import * from utils.conf import get_device from utils.losses import * -from models.cext import CExt -from models.utils.utils_problog import build_worlds_queries_matrix from utils.ltn_loss import ADDMNIST_SAT_AGG @@ -14,7 +14,9 @@ def get_parser() -> ArgumentParser: Returns: argpars: argumentparser """ - parser = ArgumentParser(description="Learning via" "Concept Extractor .") + parser = ArgumentParser( + description="Learning via" "Concept Extractor ." + ) add_management_args(parser) add_experiment_args(parser) return parser @@ -28,7 +30,9 @@ class MnistLTNRec(CExt): MNIST OPERATIONS AMONG TWO DIGITS. IT WORKS ONLY IN THIS CONFIGURATION. 
""" - def __init__(self, encoder, decoder, n_images=2, c_split=(), args=None): + def __init__( + self, encoder, decoder, n_images=2, c_split=(), args=None + ): super(MnistLTNRec, self).__init__( encoder=encoder, n_images=n_images, c_split=c_split ) @@ -81,7 +85,11 @@ def forward(self, x): latents.append((mu + eps * logvar.exp()).view(L, -1)) for i in range(len(self.c_split)): - latents.append(F.gumbel_softmax(c[:, i, :], tau=1, hard=True, dim=-1)) + latents.append( + F.gumbel_softmax( + c[:, i, :], tau=1, hard=True, dim=-1 + ) + ) latents = torch.cat(latents, dim=1) @@ -90,10 +98,20 @@ def forward(self, x): # return everything clen = len(cs[0].shape) - cs = torch.stack(cs, dim=1) if clen == 2 else torch.cat(cs, dim=1) - mus = torch.stack(mus, dim=-1) if clen == 2 else torch.cat(mus, dim=1) + cs = ( + torch.stack(cs, dim=1) + if clen == 2 + else torch.cat(cs, dim=1) + ) + mus = ( + torch.stack(mus, dim=-1) + if clen == 2 + else torch.cat(mus, dim=1) + ) logvars = ( - torch.stack(logvars, dim=-1) if clen == 2 else torch.cat(logvars, dim=1) + torch.stack(logvars, dim=-1) + if clen == 2 + else torch.cat(logvars, dim=1) ) # normalize concept preditions @@ -101,7 +119,9 @@ def forward(self, x): # normalize concept preditions - pred = torch.argmax(pCs[:, 0, :], dim=-1) + torch.argmax(pCs[:, 1, :], dim=-1) + pred = torch.argmax(pCs[:, 0, :], dim=-1) + torch.argmax( + pCs[:, 1, :], dim=-1 + ) return { "YS": F.one_hot(pred, 19), @@ -142,7 +162,9 @@ def normalize_concepts(self, z, split=2): Z2 = torch.sum(prob_digit2, dim=-1, keepdim=True) prob_digit2 = prob_digit2 / Z2 # Normalization - return torch.stack([prob_digit1, prob_digit2], dim=1).view(-1, 2, 10) + return torch.stack([prob_digit1, prob_digit2], dim=1).view( + -1, 2, 10 + ) def get_loss(self, args): """Returns the loss function for this architecture diff --git a/XOR_MNIST/models/mnistpcbmdpl.py b/XOR_MNIST/models/mnistpcbmdpl.py index 9ec30d3..b59f91e 100644 --- a/XOR_MNIST/models/mnistpcbmdpl.py +++ b/XOR_MNIST/models/mnistpcbmdpl.py @@ -1,13 +1,15 @@ # MNIST PCBM with DPL architecture import torch -from models.utils.deepproblog_modules import DeepProblogModel from models.mnistdpl import MnistDPL +from models.utils.deepproblog_modules import ( + DeepProblogModel, + GraphSemiring, +) +from models.utils.utils_problog import * from utils.args import * from utils.conf import get_device -from models.utils.deepproblog_modules import GraphSemiring -from models.utils.utils_problog import * -from utils.losses import * from utils.dpl_loss import ADDMNIST_DPL +from utils.losses import * def get_parser() -> ArgumentParser: @@ -16,7 +18,9 @@ def get_parser() -> ArgumentParser: Returns: argpars: argument parser """ - parser = ArgumentParser(description="Learning via" "Concept Extractor .") + parser = ArgumentParser( + description="Learning via" "Concept Extractor ." + ) add_management_args(parser) add_experiment_args(parser) return parser @@ -56,7 +60,13 @@ def __init__( None: This function does not return a value. 
""" super(MnistPcbmDPL, self).__init__( - encoder, n_images, c_split, args, model_dict, n_facts, nr_classes + encoder, + n_images, + c_split, + args, + model_dict, + n_facts, + nr_classes, ) self.positives = torch.nn.Parameter( @@ -87,7 +97,9 @@ def batchwise_cdist(self, samples1, samples2, eps=1e-6): Returns: dist: batchwise distance """ - if len(samples1.size()) not in [3, 4, 5] or len(samples2.size()) not in [ + if len(samples1.size()) not in [3, 4, 5] or len( + samples2.size() + ) not in [ 3, 4, 5, @@ -109,7 +121,9 @@ def batchwise_cdist(self, samples1, samples2, eps=1e-6): samples2 = samples2.unsqueeze(3) samples1 = samples1.unsqueeze(1) samples2 = samples2.unsqueeze(0) - result = torch.sqrt(((samples1 - samples2) ** 2).sum(-1) + eps) + result = torch.sqrt( + ((samples1 - samples2) ** 2).sum(-1) + eps + ) return result.view(*result.shape[:-2], -1) else: raise RuntimeError( @@ -129,7 +143,12 @@ def batchwise_cdist(self, samples1, samples2, eps=1e-6): # return torch.sqrt(((samples1 - samples2) ** 2).sum(-1) + eps).view(batch_size, -1) def compute_distance( - self, pred_embeddings, z_tot, negative_scale=None, shift=None, reduction="mean" + self, + pred_embeddings, + z_tot, + negative_scale=None, + shift=None, + reduction="mean", ): """Compute distances between predicted embeddings and latents z @@ -146,7 +165,9 @@ def compute_distance( probability: mean probability """ negative_scale = ( - self.negative_scale if negative_scale is None else negative_scale + self.negative_scale + if negative_scale is None + else negative_scale ) distance = self.batchwise_cdist(pred_embeddings, z_tot) @@ -206,11 +227,15 @@ def forward(self, x): logsigma = torch.clip(logvars[i], max=10) # [batch, n_concepts, n_samples, latent_dim] - pred_embeddings = sample_gaussian_tensors(latents, logsigma, 10) + pred_embeddings = sample_gaussian_tensors( + latents, logsigma, 10 + ) # print("Pred embeddings", pred_embeddings.shape) - concept_logit, concept_prob = self.compute_distance(pred_embeddings, z_tot) + concept_logit, concept_prob = self.compute_distance( + pred_embeddings, z_tot + ) # print("concept_logit", concept_logit.shape) # print("concept_prob", concept_prob.shape) @@ -229,7 +254,13 @@ def forward(self, x): # Problog inference to compute worlds and query probability distributions py, worlds_prob = self.problog_inference(pCs) - return {"CS": pCs, "YS": py, "pCS": pCs, "MUS": mus, "LOGVARS": logvars} + return { + "CS": pCs, + "YS": py, + "pCS": pCs, + "MUS": mus, + "LOGVARS": logvars, + } def normalize_concepts(self, z, split=2): """Computes the probability for each ProbLog fact given the latent vector z @@ -266,7 +297,9 @@ def normalize_concepts(self, z, split=2): Z2 = torch.sum(prob_digit2, dim=-1, keepdim=True) prob_digit2 = prob_digit2 / Z2 # Normalization - return torch.stack([prob_digit1, prob_digit2], dim=1).view(-1, 2, self.n_facts) + return torch.stack([prob_digit1, prob_digit2], dim=1).view( + -1, 2, self.n_facts + ) @staticmethod def get_loss(args): @@ -281,7 +314,12 @@ def get_loss(args): Raises: err: NotImplementedError if the loss function is not present """ - if args.dataset in ["addmnist", "shortmnist", "restrictedmnist", "halfmnist"]: + if args.dataset in [ + "addmnist", + "shortmnist", + "restrictedmnist", + "halfmnist", + ]: return ADDMNIST_DPL(ADDMNIST_Cumulative, pcbm=True) else: return NotImplementedError("Wrong dataset choice") @@ -297,7 +335,8 @@ def start_optim(self, args): None: This function does not return a value. 
""" self.opt = torch.optim.Adam( - [*self.parameters(), self.positives, self.negatives], args.lr + [*self.parameters(), self.positives, self.negatives], + args.lr, ) diff --git a/XOR_MNIST/models/mnistpcbmltn.py b/XOR_MNIST/models/mnistpcbmltn.py index fe74969..3b66f10 100644 --- a/XOR_MNIST/models/mnistpcbmltn.py +++ b/XOR_MNIST/models/mnistpcbmltn.py @@ -1,15 +1,17 @@ import torch from models.mnistltn import MnistLTN -from utils.args import * -from utils.conf import get_device from models.utils.deepproblog_modules import GraphSemiring from models.utils.utils_problog import * +from utils.args import * +from utils.conf import get_device from utils.losses import * from utils.ltn_loss import ADDMNIST_SAT_AGG def get_parser() -> ArgumentParser: - parser = ArgumentParser(description="Learning via" "Concept Extractor .") + parser = ArgumentParser( + description="Learning via" "Concept Extractor ." + ) add_management_args(parser) add_experiment_args(parser) return parser @@ -23,7 +25,10 @@ class MnistPcbmLTN(MnistLTN): def __init__(self, encoder, n_images=2, c_split=(), args=None): super(MnistPcbmLTN, self).__init__( - encoder=encoder, n_images=n_images, c_split=c_split, args=args + encoder=encoder, + n_images=n_images, + c_split=c_split, + args=args, ) self.positives = torch.nn.Parameter( @@ -43,7 +48,9 @@ def __init__(self, encoder, n_images=2, c_split=(), args=None): self.shift = torch.ones(0, device=self.device) def batchwise_cdist(self, samples1, samples2, eps=1e-6): - if len(samples1.size()) not in [3, 4, 5] or len(samples2.size()) not in [ + if len(samples1.size()) not in [3, 4, 5] or len( + samples2.size() + ) not in [ 3, 4, 5, @@ -65,7 +72,9 @@ def batchwise_cdist(self, samples1, samples2, eps=1e-6): samples2 = samples2.unsqueeze(3) samples1 = samples1.unsqueeze(1) samples2 = samples2.unsqueeze(0) - result = torch.sqrt(((samples1 - samples2) ** 2).sum(-1) + eps) + result = torch.sqrt( + ((samples1 - samples2) ** 2).sum(-1) + eps + ) return result.view(*result.shape[:-2], -1) else: raise RuntimeError( @@ -74,10 +83,17 @@ def batchwise_cdist(self, samples1, samples2, eps=1e-6): ) def compute_distance( - self, pred_embeddings, z_tot, negative_scale=None, shift=None, reduction="mean" + self, + pred_embeddings, + z_tot, + negative_scale=None, + shift=None, + reduction="mean", ): negative_scale = ( - self.negative_scale if negative_scale is None else negative_scale + self.negative_scale + if negative_scale is None + else negative_scale ) distance = self.batchwise_cdist(pred_embeddings, z_tot) @@ -128,11 +144,15 @@ def forward(self, x): logsigma = torch.clip(logvars[i], max=10) # [batch, n_concepts, n_samples, latent_dim] - pred_embeddings = sample_gaussian_tensors(latents, logsigma, 10) + pred_embeddings = sample_gaussian_tensors( + latents, logsigma, 10 + ) # print("Pred embeddings", pred_embeddings.shape) - concept_logit, concept_prob = self.compute_distance(pred_embeddings, z_tot) + concept_logit, concept_prob = self.compute_distance( + pred_embeddings, z_tot + ) # print("concept_logit", concept_logit.shape) # print("concept_prob", concept_prob.shape) @@ -170,7 +190,13 @@ def forward(self, x): pred[mask] = torch.tensor(15, device=pred.device) pred = F.one_hot(pred, 16) - return {"CS": pCs, "YS": pred, "pCS": pCs, "MUS": mus, "LOGVARS": logvars} + return { + "CS": pCs, + "YS": pred, + "pCS": pCs, + "MUS": mus, + "LOGVARS": logvars, + } def normalize_concepts(self, z, split=2): """Computes the probability for each ProbLog fact given the latent vector z""" @@ -198,17 +224,27 @@ def 
normalize_concepts(self, z, split=2): Z2 = torch.sum(prob_digit2, dim=-1, keepdim=True) prob_digit2 = prob_digit2 / Z2 # Normalization - return torch.stack([prob_digit1, prob_digit2], dim=1).view(-1, 2, self.n_facts) + return torch.stack([prob_digit1, prob_digit2], dim=1).view( + -1, 2, self.n_facts + ) def get_loss(self, args): - if args.dataset in ["addmnist", "shortmnist", "restrictedmnist", "halfmnist"]: - return ADDMNIST_SAT_AGG(ADDMNIST_Cumulative, self.task, pcbm=True) + if args.dataset in [ + "addmnist", + "shortmnist", + "restrictedmnist", + "halfmnist", + ]: + return ADDMNIST_SAT_AGG( + ADDMNIST_Cumulative, self.task, pcbm=True + ) else: return NotImplementedError("Wrong dataset choice") def start_optim(self, args): self.opt = torch.optim.Adam( - [*self.parameters(), self.positives, self.negatives], args.lr + [*self.parameters(), self.positives, self.negatives], + args.lr, ) diff --git a/XOR_MNIST/models/mnistpcbmsl.py b/XOR_MNIST/models/mnistpcbmsl.py index dcfef2f..1838fa3 100644 --- a/XOR_MNIST/models/mnistpcbmsl.py +++ b/XOR_MNIST/models/mnistpcbmsl.py @@ -1,15 +1,17 @@ import torch from models.mnistsl import MnistSL -from utils.args import * -from utils.conf import get_device from models.utils.deepproblog_modules import GraphSemiring from models.utils.utils_problog import * +from utils.args import * +from utils.conf import get_device from utils.losses import * from utils.semantic_loss import ADDMNIST_SL def get_parser() -> ArgumentParser: - parser = ArgumentParser(description="Learning via" "Concept Extractor .") + parser = ArgumentParser( + description="Learning via" "Concept Extractor ." + ) add_management_args(parser) add_experiment_args(parser) return parser @@ -22,14 +24,28 @@ class MnistPcbmSL(MnistSL): """ def __init__( - self, encoder, n_images=2, c_split=(), args=None, n_facts=20, nr_classes=19 + self, + encoder, + n_images=2, + c_split=(), + args=None, + n_facts=20, + nr_classes=19, ): super(MnistPcbmSL, self).__init__( - encoder=encoder, n_images=n_images, c_split=c_split, args=args + encoder=encoder, + n_images=n_images, + c_split=c_split, + args=args, ) self.mlp = torch.nn.Sequential( - torch.nn.Linear(self.encoder.latent_dim * self.n_facts * self.n_images, 50), + torch.nn.Linear( + self.encoder.latent_dim + * self.n_facts + * self.n_images, + 50, + ), torch.nn.ReLU(), torch.nn.Linear(50, 50), torch.nn.ReLU(), @@ -47,7 +63,9 @@ def __init__( self.shift = torch.ones(0, device=self.device) def batchwise_cdist(self, samples1, samples2, eps=1e-6): - if len(samples1.size()) not in [3, 4, 5] or len(samples2.size()) not in [ + if len(samples1.size()) not in [3, 4, 5] or len( + samples2.size() + ) not in [ 3, 4, 5, @@ -69,7 +87,9 @@ def batchwise_cdist(self, samples1, samples2, eps=1e-6): samples2 = samples2.unsqueeze(3) samples1 = samples1.unsqueeze(1) samples2 = samples2.unsqueeze(0) - result = torch.sqrt(((samples1 - samples2) ** 2).sum(-1) + eps) + result = torch.sqrt( + ((samples1 - samples2) ** 2).sum(-1) + eps + ) return result.view(*result.shape[:-2], -1) else: raise RuntimeError( @@ -78,10 +98,17 @@ def batchwise_cdist(self, samples1, samples2, eps=1e-6): ) def compute_distance( - self, pred_embeddings, z_tot, negative_scale=None, shift=None, reduction="mean" + self, + pred_embeddings, + z_tot, + negative_scale=None, + shift=None, + reduction="mean", ): negative_scale = ( - self.negative_scale if negative_scale is None else negative_scale + self.negative_scale + if negative_scale is None + else negative_scale ) distance = self.batchwise_cdist(pred_embeddings, 
z_tot) @@ -134,13 +161,19 @@ def forward(self, x): n_samples = 10 # [batch, n_concepts, n_samples, latent_dim] - pred_embeddings = sample_gaussian_tensors(latents, logsigma, n_samples) + pred_embeddings = sample_gaussian_tensors( + latents, logsigma, n_samples + ) - z_logits.append(pred_embeddings.permute(0, 2, 1, 3).unsqueeze(-1)) + z_logits.append( + pred_embeddings.permute(0, 2, 1, 3).unsqueeze(-1) + ) # print("Pred embeddings", pred_embeddings.shape) - concept_logit, concept_prob = self.compute_distance(pred_embeddings, z_tot) + concept_logit, concept_prob = self.compute_distance( + pred_embeddings, z_tot + ) # print("concept_logit", concept_logit.shape) # print("concept_prob", concept_prob.shape) @@ -152,7 +185,9 @@ def forward(self, x): z_logits = torch.cat(z_logits, dim=-1) z_logits = z_logits.view( - -1, n_samples, self.encoder.latent_dim * self.n_facts * self.n_images + -1, + n_samples, + self.encoder.latent_dim * self.n_facts * self.n_images, ) preds = torch.zeros((B, self.nr_classes), device=self.device) @@ -171,7 +206,13 @@ def forward(self, x): # Problog inference to compute worlds and query probability distributions # preds = self.mlp(c_logits.view(-1,self.n_facts*2)) - return {"CS": pCs, "YS": preds, "pCS": pCs, "MUS": mus, "LOGVARS": logvars} + return { + "CS": pCs, + "YS": preds, + "pCS": pCs, + "MUS": mus, + "LOGVARS": logvars, + } def normalize_concepts(self, z, split=2): """Computes the probability for each ProbLog fact given the latent vector z""" @@ -199,17 +240,27 @@ def normalize_concepts(self, z, split=2): Z2 = torch.sum(prob_digit2, dim=-1, keepdim=True) prob_digit2 = prob_digit2 / Z2 # Normalization - return torch.stack([prob_digit1, prob_digit2], dim=1).view(-1, 2, self.n_facts) + return torch.stack([prob_digit1, prob_digit2], dim=1).view( + -1, 2, self.n_facts + ) def get_loss(self, args): - if args.dataset in ["addmnist", "shortmnist", "restrictedmnist", "halfmnist"]: - return ADDMNIST_SL(ADDMNIST_Cumulative, self.logic, args, pcbm=True) + if args.dataset in [ + "addmnist", + "shortmnist", + "restrictedmnist", + "halfmnist", + ]: + return ADDMNIST_SL( + ADDMNIST_Cumulative, self.logic, args, pcbm=True + ) else: return NotImplementedError("Wrong dataset choice") def start_optim(self, args): self.opt = torch.optim.Adam( - [*self.parameters(), self.positives, self.negatives], args.lr + [*self.parameters(), self.positives, self.negatives], + args.lr, ) diff --git a/XOR_MNIST/models/mnistsl.py b/XOR_MNIST/models/mnistsl.py index 978535e..55a0d19 100644 --- a/XOR_MNIST/models/mnistsl.py +++ b/XOR_MNIST/models/mnistsl.py @@ -1,10 +1,10 @@ # mnist sl module import torch +from models.cext import CExt +from models.utils.utils_problog import build_worlds_queries_matrix from utils.args import * from utils.conf import get_device from utils.losses import * -from models.cext import CExt -from models.utils.utils_problog import build_worlds_queries_matrix from utils.semantic_loss import ADDMNIST_SL @@ -14,7 +14,9 @@ def get_parser() -> ArgumentParser: Returns: argpars: argument parser """ - parser = ArgumentParser(description="Learning via" "Concept Extractor .") + parser = ArgumentParser( + description="Learning via" "Concept Extractor ." 
+ ) add_management_args(parser) add_experiment_args(parser) return parser @@ -29,7 +31,13 @@ class MnistSL(CExt): """ def __init__( - self, encoder, n_images=2, c_split=(), args=None, n_facts=20, nr_classes=19 + self, + encoder, + n_images=2, + c_split=(), + args=None, + n_facts=20, + nr_classes=19, ): """Initialize method @@ -52,19 +60,31 @@ def __init__( # Worlds-queries matrix if args.task == "addition": self.n_facts = ( - 10 if not args.dataset in ["halfmnist", "restrictedmnist"] else 5 + 10 + if not args.dataset + in ["halfmnist", "restrictedmnist"] + else 5 + ) + self.logic = build_worlds_queries_matrix( + 2, self.n_facts, "addmnist" ) - self.logic = build_worlds_queries_matrix(2, self.n_facts, "addmnist") self.nr_classes = 19 elif args.task == "product": self.n_facts = ( - 10 if not args.dataset in ["halfmnist", "restrictedmnist"] else 5 + 10 + if not args.dataset + in ["halfmnist", "restrictedmnist"] + else 5 + ) + self.logic = build_worlds_queries_matrix( + 2, self.n_facts, "productmnist" ) - self.logic = build_worlds_queries_matrix(2, self.n_facts, "productmnist") self.nr_classes = 37 elif args.task == "multiop": self.n_facts = 5 - self.logic = build_worlds_queries_matrix(2, self.n_facts, "multiopmnist") + self.logic = build_worlds_queries_matrix( + 2, self.n_facts, "multiopmnist" + ) self.nr_classes = 3 self.mlp = torch.nn.Sequential( @@ -97,7 +117,11 @@ def forward(self, x): lc, _, _ = self.encoder(xs[i]) # sizes are ok cs.append(lc) clen = len(cs[0].shape) - cs = torch.stack(cs, dim=1) if clen == 2 else torch.cat(cs, dim=1) + cs = ( + torch.stack(cs, dim=1) + if clen == 2 + else torch.cat(cs, dim=1) + ) pCs = self.normalize_concepts(cs) @@ -137,7 +161,9 @@ def normalize_concepts(self, z, split=2): Z2 = torch.sum(prob_digit2, dim=-1, keepdim=True) prob_digit2 = prob_digit2 / Z2 # Normalization - return torch.stack([prob_digit1, prob_digit2], dim=1).view(-1, 2, self.n_facts) + return torch.stack([prob_digit1, prob_digit2], dim=1).view( + -1, 2, self.n_facts + ) def get_loss(self, args): """Returns the loss function for this architecture @@ -152,7 +178,12 @@ def get_loss(self, args): Raises: err: NotImplementedError if the loss function is not available """ - if args.dataset in ["addmnist", "shortmnist", "restrictedmnist", "halfmnist"]: + if args.dataset in [ + "addmnist", + "shortmnist", + "restrictedmnist", + "halfmnist", + ]: return ADDMNIST_SL(ADDMNIST_Cumulative, self.logic, args) else: return NotImplementedError("Wrong dataset choice") diff --git a/XOR_MNIST/models/mnistslrec.py b/XOR_MNIST/models/mnistslrec.py index e04e977..72a2875 100644 --- a/XOR_MNIST/models/mnistslrec.py +++ b/XOR_MNIST/models/mnistslrec.py @@ -1,10 +1,10 @@ # mnistslrec module import torch +from models.cext import CExt +from models.utils.utils_problog import build_worlds_queries_matrix from utils.args import * from utils.conf import get_device from utils.losses import * -from models.cext import CExt -from models.utils.utils_problog import build_worlds_queries_matrix from utils.semantic_loss import ADDMNIST_SL @@ -14,7 +14,9 @@ def get_parser() -> ArgumentParser: Returns: pars: argument parser """ - parser = ArgumentParser(description="Learning via" "Concept Extractor .") + parser = ArgumentParser( + description="Learning via" "Concept Extractor ." 
+ ) add_management_args(parser) add_experiment_args(parser) return parser @@ -101,7 +103,11 @@ def forward(self, x): latents.append((mu + eps * logvar.exp()).view(L, -1)) for i in range(len(self.c_split)): - latents.append(F.gumbel_softmax(c[:, i, :], tau=1, hard=True, dim=-1)) + latents.append( + F.gumbel_softmax( + c[:, i, :], tau=1, hard=True, dim=-1 + ) + ) latents = torch.cat(latents, dim=1) @@ -110,10 +116,20 @@ def forward(self, x): # return everything clen = len(cs[0].shape) - cs = torch.stack(cs, dim=1) if clen == 2 else torch.cat(cs, dim=1) - mus = torch.stack(mus, dim=-1) if clen == 2 else torch.cat(mus, dim=1) + cs = ( + torch.stack(cs, dim=1) + if clen == 2 + else torch.cat(cs, dim=1) + ) + mus = ( + torch.stack(mus, dim=-1) + if clen == 2 + else torch.cat(mus, dim=1) + ) logvars = ( - torch.stack(logvars, dim=-1) if clen == 2 else torch.cat(logvars, dim=1) + torch.stack(logvars, dim=-1) + if clen == 2 + else torch.cat(logvars, dim=1) ) # normalize concept preditions @@ -160,7 +176,9 @@ def normalize_concepts(self, z, split=2): Z2 = torch.sum(prob_digit2, dim=-1, keepdim=True) prob_digit2 = prob_digit2 / Z2 # Normalization - return torch.stack([prob_digit1, prob_digit2], dim=1).view(-1, 2, 10) + return torch.stack([prob_digit1, prob_digit2], dim=1).view( + -1, 2, 10 + ) def get_loss(self, args): """Returns the loss function for MNIST SL Rec diff --git a/XOR_MNIST/models/utils/deepproblog_modules.py b/XOR_MNIST/models/utils/deepproblog_modules.py index 3c10866..53e9e77 100644 --- a/XOR_MNIST/models/utils/deepproblog_modules.py +++ b/XOR_MNIST/models/utils/deepproblog_modules.py @@ -1,6 +1,6 @@ import torch -from torch import nn from problog.evaluator import Semiring +from torch import nn from utils.conf import get_device @@ -64,10 +64,14 @@ def value(self, a): class DeepProblogModel(nn.Module): - def __init__(self, encoder, model_dict=None, n_facts=20, nr_classes=19): + def __init__( + self, encoder, model_dict=None, n_facts=20, nr_classes=19 + ): super(DeepProblogModel, self).__init__() self.encoder = encoder - self.model_dict = model_dict # Dictionary of pre-compiled ProbLog models + self.model_dict = ( + model_dict # Dictionary of pre-compiled ProbLog models + ) self.device = get_device() self.nr_classes = nr_classes @@ -83,7 +87,9 @@ def forward(self, x): # normalize concept preditions self.facts_probs = self.normalize_concepts(z) # Problog inference to compute worlds and query probability distributions - self.query_prob, self.worlds_prob = self.problog_inference(self.facts_probs) + self.query_prob, self.worlds_prob = self.problog_inference( + self.facts_probs + ) return self.query_prob, self.facts_probs diff --git a/XOR_MNIST/models/utils/ops.py b/XOR_MNIST/models/utils/ops.py index b6b7550..05a4278 100644 --- a/XOR_MNIST/models/utils/ops.py +++ b/XOR_MNIST/models/utils/ops.py @@ -11,14 +11,20 @@ def outer_product(*tensors): # Check if all tensors have the same shape shape = tensors[0].shape if any(tensor.shape != shape for tensor in tensors): - raise ValueError("All input tensors must have the same shape.") + raise ValueError( + "All input tensors must have the same shape." 
+ ) # Create the einsum string dynamically based on the number of tensors - einsum_string = ",".join(f"z{chr(97 + i)}" for i in range(len(tensors))) + einsum_string = ",".join( + f"z{chr(97 + i)}" for i in range(len(tensors)) + ) # Calculate the outer product result = torch.einsum( - einsum_string + "->z" + "".join(chr(97 + i) for i in range(len(tensors))), + einsum_string + + "->z" + + "".join(chr(97 + i) for i in range(len(tensors))), *tensors, ) diff --git a/XOR_MNIST/models/utils/utils_problog.py b/XOR_MNIST/models/utils/utils_problog.py index 4a38bc9..854d274 100644 --- a/XOR_MNIST/models/utils/utils_problog.py +++ b/XOR_MNIST/models/utils/utils_problog.py @@ -1,20 +1,19 @@ +import itertools import os.path import random from datetime import datetime from itertools import product from math import isnan from pathlib import Path -from time import time, sleep +from time import sleep, time + import numpy as np import pandas as pd import torch -from problog.formula import LogicFormula, LogicDAG +from problog.formula import LogicDAG, LogicFormula +from problog.logic import AnnotatedDisjunction, Constant, Term, Var from problog.sdd_formula import SDD from torch import nn -import itertools - -from problog.logic import Term, Constant -from problog.logic import Var, AnnotatedDisjunction def lock_resource(lock_filename): @@ -41,7 +40,12 @@ def create_facts(sequence_len, n_digits=10): digit = Term("digit") X = Var("X") facts = [ - digit(X, Constant(pos), Constant(y), p="p_" + str(pos) + str(y)) + digit( + X, + Constant(pos), + Constant(y), + p="p_" + str(pos) + str(y), + ) for y in range(n_digits) ] annot_disj += str(AnnotatedDisjunction(facts, None)) + "." @@ -51,7 +55,9 @@ def create_facts(sequence_len, n_digits=10): return ad -def define_ProbLog_model(facts, rules, label, digit_query=None, mode="query"): +def define_ProbLog_model( + facts, rules, label, digit_query=None, mode="query" +): """Build the ProbLog model using teh given facts, rules, evidence and query.""" model = "" # Empty program @@ -81,7 +87,9 @@ def define_ProbLog_model(facts, rules, label, digit_query=None, mode="query"): return model -def update_resource(log_filepath, update_info, lock_filename="access.lock"): +def update_resource( + log_filepath, update_info, lock_filename="access.lock" +): # {'Experiment_ID': 0, 'Run_ID': 1, ...} print("Updating resource with: {}".format(update_info)) @@ -125,14 +133,18 @@ def load_mnist_classifier(checkpoint_path, device): clf = clf.to(device) else: clf.load_state_dict( - torch.load(checkpoint_path, map_location=torch.device("cpu")) + torch.load( + checkpoint_path, map_location=torch.device("cpu") + ) ) return clf def define_experiment(exp_folder, exp_class, params, exp_counter): - log_file = Path(os.path.join(exp_folder, exp_class, exp_class + ".csv")) + log_file = Path( + os.path.join(exp_folder, exp_class, exp_class + ".csv") + ) params_columns = [ "latent_dim_sub", "latent_dim_sym", @@ -147,13 +159,17 @@ def define_experiment(exp_folder, exp_class, params, exp_counter): ] if log_file.is_file(): # Load file - log_csv = pd.read_csv(os.path.join(exp_folder, exp_class, exp_class + ".csv")) + log_csv = pd.read_csv( + os.path.join(exp_folder, exp_class, exp_class + ".csv") + ) # Check if the required number of test has been already satisfied required_exp = params["n_exp"] if len(log_csv) > 0: - query = "".join(f" {key} == {params[key]} &" for key in params_columns)[:-1] + query = "".join( + f" {key} == {params[key]} &" for key in params_columns + )[:-1] n_exp = len(log_csv.query(query)) if 
n_exp == 0: exp_ID = log_csv["exp_ID"].max() + 1 @@ -163,30 +179,40 @@ def define_experiment(exp_folder, exp_class, params, exp_counter): print( "\n\n{} compatible experiments found in file {} -> {} experiments to run.".format( n_exp, - os.path.join(exp_folder, exp_class, exp_class + ".csv"), + os.path.join( + exp_folder, exp_class, exp_class + ".csv" + ), counter, ) ) - run_ID = datetime.today().strftime("%d-%m-%Y-%H-%M-%S") + run_ID = datetime.today().strftime( + "%d-%m-%Y-%H-%M-%S" + ) elif n_exp < required_exp: exp_ID = log_csv.query(query)["exp_ID"].values[0] counter = required_exp - n_exp print( "\n\n{} compatible experiments found in file {} -> {} experiments to run.".format( n_exp, - os.path.join(exp_folder, exp_class, exp_class + ".csv"), + os.path.join( + exp_folder, exp_class, exp_class + ".csv" + ), counter, ) ) - run_ID = datetime.today().strftime("%d-%m-%Y-%H-%M-%S") + run_ID = datetime.today().strftime( + "%d-%m-%Y-%H-%M-%S" + ) else: print( "\n\n{} compatible experiments found in file {} -> No experiments to run.".format( n_exp, - os.path.join(exp_folder, exp_class, exp_class + ".csv"), + os.path.join( + exp_folder, exp_class, exp_class + ".csv" + ), 0, ) ) @@ -208,7 +234,10 @@ def define_experiment(exp_folder, exp_class, params, exp_counter): else: counter = params["n_exp"] # Create log file - log_file = open(os.path.join(exp_folder, exp_class, exp_class + ".csv"), "w") + log_file = open( + os.path.join(exp_folder, exp_class, exp_class + ".csv"), + "w", + ) header = ( "exp_ID,run_ID," + "".join(str(key) + "," for key in params_columns) @@ -225,7 +254,9 @@ def define_experiment(exp_folder, exp_class, params, exp_counter): print("-" * 40) print( "\nNO csv file found -> new file created {}".format( - os.path.join(exp_folder, exp_class, exp_class + ".csv") + os.path.join( + exp_folder, exp_class, exp_class + ".csv" + ) ) ) print("-" * 40) @@ -249,14 +280,22 @@ def build_model_dict(sequence_len, n_digits): rules = "addition(X,N) :- digit(X,1,N1), digit(X,2,N2), N is N1 + N2.\ndigits(X,Y):-digit(img,1,X), digit(img,2,Y)." 
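For context on the rules string above: create_facts emits one annotated disjunction of digit/3 facts per position, and define_ProbLog_model concatenates facts, rules, and a query (or evidence) into a ProbLog program. A rough, hand-written illustration of such a program for two positions and two digit values (not output generated by the repository code; numeric placeholders stand in for the symbolic labels "p_" + str(pos) + str(y) that create_facts attaches to each fact):

```python
# Illustrative only: what the assembled ProbLog program roughly looks like.
program = """
0.5::digit(X,1,0); 0.5::digit(X,1,1).
0.5::digit(X,2,0); 0.5::digit(X,2,1).
addition(X,N) :- digit(X,1,N1), digit(X,2,N2), N is N1 + N2.
digits(X,Y) :- digit(img,1,X), digit(img,2,Y).
query(digits(X,Y)).
"""
print(program)
```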
facts = create_facts(sequence_len, n_digits=n_digits) model_dict = { - "query": {add: "EMPTY" for add in possible_query_add[sequence_len]}, - "evidence": {add: "EMPTY" for add in possible_query_add[sequence_len]}, + "query": { + add: "EMPTY" for add in possible_query_add[sequence_len] + }, + "evidence": { + add: "EMPTY" for add in possible_query_add[sequence_len] + }, } for mode in ["query", "evidence"]: for add in model_dict[mode]: problog_model = define_ProbLog_model( - facts, rules, label=add, digit_query="digits(X,Y)", mode=mode + facts, + rules, + label=add, + digit_query="digits(X,Y)", + mode=mode, ) lf = LogicFormula.create_from(problog_model) dag = LogicDAG.create_from(lf) @@ -266,7 +305,9 @@ def build_model_dict(sequence_len, n_digits): return model_dict -def build_worlds_queries_matrix_kandinsky(sequence_len=0, n_facts=0, n_shapes=0): +def build_worlds_queries_matrix_kandinsky( + sequence_len=0, n_facts=0, n_shapes=0 +): """Build Worlds Queries Matrices The Kandinsky Figure has two pairs of objects with the same shape, @@ -281,7 +322,9 @@ def find_equal_indices(vector): index_dict[value].append(i) equal_indices = { - key: indices for key, indices in index_dict.items() if len(indices) > 1 + key: indices + for key, indices in index_dict.items() + if len(indices) > 1 } return list(equal_indices.values()) @@ -289,7 +332,9 @@ def two_pairs(shapesf1, shapesf2): c_f1 = Counter(shapesf1) c_f2 = Counter(shapesf2) - idx_1, idx_2 = find_equal_indices(shapesf1), find_equal_indices(shapesf2) + idx_1, idx_2 = find_equal_indices( + shapesf1 + ), find_equal_indices(shapesf2) to_rtn_idx = [idx_1, idx_2] if len(c_f1) == 2 and len(c_f2) == 2: @@ -306,7 +351,9 @@ def two_pairs(shapesf1, shapesf2): product(range(n_facts * n_shapes), repeat=sequence_len) ) # 576 n_worlds = len(possible_worlds) - n_queries = len(range(0, 2)) # 2 possible queries (it is or it is not, right?) + n_queries = len( + range(0, 2) + ) # 2 possible queries (it is or it is not, right?) 
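The build_worlds_queries_matrix helpers in this file encode, for every joint assignment of the concepts (a "world"), which query label that world entails. A minimal sketch for the two-digit addition case (the w_q semantics are inferred from the shapes and comments in the patch; the repository's builder allocates len(range(0, 20)) columns, of which only sums 0..18 are reachable for two digits):

```python
# Worlds-queries matrix for two-digit addition: row = world (d1, d2),
# column = query "the sum equals q", entry 1 where the world entails the query.
from itertools import product
import torch

n_digits, sequence_len = 10, 2
possible_worlds = list(product(range(n_digits), repeat=sequence_len))  # 100 worlds
n_queries = 2 * n_digits - 1                                           # sums 0..18
w_q = torch.zeros(len(possible_worlds), n_queries)
for w, (d1, d2) in enumerate(possible_worlds):
    w_q[w, d1 + d2] = 1.0

# For a factorized concept distribution, p(world) @ w_q yields p(sum).
print(w_q.shape)  # torch.Size([100, 19])
```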
look_up = {i: c for i, c in zip(range(n_worlds), possible_worlds)} w_q = torch.zeros(n_worlds, n_queries) # (576, 2) @@ -327,13 +374,19 @@ def two_pairs(shapesf1, shapesf2): return w_q -def build_worlds_queries_matrix(sequence_len=0, n_digits=0, task="addmnist"): +def build_worlds_queries_matrix( + sequence_len=0, n_digits=0, task="addmnist" +): """Build Worlds-Queries matrix""" if task == "addmnist": - possible_worlds = list(product(range(n_digits), repeat=sequence_len)) + possible_worlds = list( + product(range(n_digits), repeat=sequence_len) + ) n_worlds = len(possible_worlds) n_queries = len(range(0, 10 + 10)) - look_up = {i: c for i, c in zip(range(n_worlds), possible_worlds)} + look_up = { + i: c for i, c in zip(range(n_worlds), possible_worlds) + } w_q = torch.zeros(n_worlds, n_queries) # (100, 20) for w in range(n_worlds): digit1, digit2 = look_up[w] @@ -343,14 +396,18 @@ def build_worlds_queries_matrix(sequence_len=0, n_digits=0, task="addmnist"): return w_q elif task == "productmnist": - possible_worlds = list(product(range(n_digits), repeat=sequence_len)) + possible_worlds = list( + product(range(n_digits), repeat=sequence_len) + ) n_worlds = len(possible_worlds) n_queries = [0] for i, j in itertools.product(range(1, 10), range(1, 10)): n_queries.append(i * j) n_queries = np.unique(np.array(n_queries)) - look_up = {i: c for i, c in zip(range(n_worlds), possible_worlds)} + look_up = { + i: c for i, c in zip(range(n_worlds), possible_worlds) + } w_q = torch.zeros(n_worlds, len(n_queries)) # (100, boh) for w in range(n_worlds): digit1, digit2 = look_up[w] @@ -361,12 +418,16 @@ def build_worlds_queries_matrix(sequence_len=0, n_digits=0, task="addmnist"): return w_q elif task == "multiopmnist": - possible_worlds = list(product(range(n_digits), repeat=sequence_len)) + possible_worlds = list( + product(range(n_digits), repeat=sequence_len) + ) n_worlds = len(possible_worlds) n_queries = np.array([0, 1, 2, 3]) w_q = torch.zeros(n_worlds, len(n_queries)) # (16, 4) - look_up = {i: c for i, c in zip(range(n_worlds), possible_worlds)} + look_up = { + i: c for i, c in zip(range(n_worlds), possible_worlds) + } for w in range(n_worlds): digit1, digit2 = look_up[w] for i, q in enumerate(n_queries): @@ -418,7 +479,9 @@ def build_worlds_queries_matrix_KAND( n_queries = 3 - look_up = {i: c for i, c in zip(range(n_worlds), possible_worlds)} + look_up = { + i: c for i, c in zip(range(n_worlds), possible_worlds) + } w_q = torch.zeros(n_worlds, n_queries) # (3^6, 9) for w in range(n_worlds): s1, s2, s3 = look_up[w] @@ -459,12 +522,16 @@ def build_worlds_queries_matrix_KAND( else: and_or_rule[p, 0] = 1 - possible_worlds = list(product(range(n_poss), repeat=n_concepts)) + possible_worlds = list( + product(range(n_poss), repeat=n_concepts) + ) n_worlds = len(possible_worlds) n_queries = 9 - look_up = {i: c for i, c in zip(range(n_worlds), possible_worlds)} + look_up = { + i: c for i, c in zip(range(n_worlds), possible_worlds) + } w_q = torch.zeros(n_worlds, n_queries) # (3^6, 9) for w in range(n_worlds): s1, s2, s3, c1, c2, c3 = look_up[w] @@ -502,12 +569,16 @@ def build_worlds_queries_matrix_KAND( else: and_or_rule[p, 0] = 1 - possible_worlds = list(product(range(n_poss), repeat=n_concepts)) + possible_worlds = list( + product(range(n_poss), repeat=n_concepts) + ) n_worlds = len(possible_worlds) n_queries = 9 - look_up = {i: c for i, c in zip(range(n_worlds), possible_worlds)} + look_up = { + i: c for i, c in zip(range(n_worlds), possible_worlds) + } w_q = torch.zeros(n_worlds, n_queries) # (3^6, 9) for 
w in range(n_worlds): s1, s2, s3, c1, c2, c3 = look_up[w] @@ -531,12 +602,16 @@ def build_worlds_queries_matrix_KAND( return w_q, and_or_rule elif task == "red_triangle": - possible_worlds = list(product(range(n_poss), repeat=n_concepts)) + possible_worlds = list( + product(range(n_poss), repeat=n_concepts) + ) n_worlds = len(possible_worlds) n_queries = 2 - look_up = {i: c for i, c in zip(range(n_worlds), possible_worlds)} + look_up = { + i: c for i, c in zip(range(n_worlds), possible_worlds) + } w_q = torch.zeros(n_worlds, n_queries) # (3^8, 2) for w in range(n_worlds): s1, s2, s3, c1, c2, c3 = look_up[w] @@ -553,12 +628,16 @@ def build_worlds_queries_matrix_KAND( return w_q, and_rule elif task == "base": - possible_worlds = list(product(range(n_poss), repeat=n_concepts)) + possible_worlds = list( + product(range(n_poss), repeat=n_concepts) + ) n_worlds = len(possible_worlds) n_queries = 2 - look_up = {i: c for i, c in zip(range(n_worlds), possible_worlds)} + look_up = { + i: c for i, c in zip(range(n_worlds), possible_worlds) + } w_q = torch.zeros(n_worlds, n_queries) # (3^8, 2) for w in range(n_worlds): s1, s2, s3, s4, c1, c2, c3, c4 = look_up[w] @@ -594,9 +673,9 @@ def build_worlds_queries_matrix_KAND( if __name__ == "__main__": - print("DIOCANE") - - w_q, and_or_rule = build_worlds_queries_matrix_KAND(task="mini_patterns") + w_q, and_or_rule = build_worlds_queries_matrix_KAND( + task="mini_patterns" + ) # print(w_q) diff --git a/XOR_MNIST/server.py b/XOR_MNIST/server.py index 6539bb1..dabc561 100644 --- a/XOR_MNIST/server.py +++ b/XOR_MNIST/server.py @@ -1,10 +1,12 @@ # Server module # Allows for the execution of multiple parameters +import os +import sys + import submitit -from main import parse_args, main from experiments import * -import os, sys +from main import main, parse_args conf_path = os.getcwd() + "." 
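The server.py hunk that continues below wraps the training entry point in a submitit SLURM executor. A hypothetical minimal usage sketch of that pattern (the toy job function and timeout_min are illustrative assumptions; the executor arguments mirror the ones visible in the patch):

```python
# Sketch of the submitit launch pattern used by server.py (illustrative job).
import submitit

def job(x):
    return x * x

executor = submitit.AutoExecutor(folder="./logs", slurm_max_num_timeout=150)
executor.update_parameters(mem_gb=4, gpus_per_node=1, timeout_min=60)
jobs = [executor.submit(job, i) for i in range(3)]
print([j.result() for j in jobs])  # blocks until the SLURM (or local fallback) jobs finish
```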
sys.path.append(conf_path) @@ -13,7 +15,9 @@ # start_main() args = parse_args() # # args = prepare_args() # parse_args() # - executor = submitit.AutoExecutor(folder="./logs", slurm_max_num_timeout=150) + executor = submitit.AutoExecutor( + folder="./logs", slurm_max_num_timeout=150 + ) executor.update_parameters( mem_gb=4, gpus_per_node=1, diff --git a/XOR_MNIST/utils/__init__.py b/XOR_MNIST/utils/__init__.py index 703d15d..6e4b7b2 100644 --- a/XOR_MNIST/utils/__init__.py +++ b/XOR_MNIST/utils/__init__.py @@ -1,6 +1,6 @@ # init utils module -import os import builtins +import os import sys diff --git a/XOR_MNIST/utils/args.py b/XOR_MNIST/utils/args.py index 8dcee7b..45f82bd 100644 --- a/XOR_MNIST/utils/args.py +++ b/XOR_MNIST/utils/args.py @@ -1,6 +1,7 @@ # Args module from argparse import ArgumentParser + from datasets import NAMES as DATASET_NAMES from models import get_all_models @@ -80,34 +81,63 @@ def add_experiment_args(parser: ArgumentParser) -> None: ) # weights of logic parser.add_argument( - "--w_sl", type=float, default=10, help="Weight of Semantic Loss" + "--w_sl", + type=float, + default=10, + help="Weight of Semantic Loss", ) # weight of mitigation - parser.add_argument("--gamma", type=float, default=1, help="Weight of mitigation") + parser.add_argument( + "--gamma", type=float, default=1, help="Weight of mitigation" + ) # additional hyperparams parser.add_argument( - "--w_rec", type=float, default=1, help="Weight of Reconstruction" + "--w_rec", + type=float, + default=1, + help="Weight of Reconstruction", + ) + parser.add_argument( + "--beta", type=float, default=2, help="Multiplier of KL" + ) + parser.add_argument( + "--w_h", type=float, default=1, help="Weight of entropy" + ) + parser.add_argument( + "--w_c", type=float, default=1, help="Weight of concept sup" ) - parser.add_argument("--beta", type=float, default=2, help="Multiplier of KL") - parser.add_argument("--w_h", type=float, default=1, help="Weight of entropy") - parser.add_argument("--w_c", type=float, default=1, help="Weight of concept sup") # optimization params - parser.add_argument("--lr", type=float, default=0.001, help="Learning rate.") - parser.add_argument("--warmup_steps", type=int, default=2, help="Warmup epochs.") parser.add_argument( - "--exp_decay", type=float, default=0.99, help="Exp decay of learning rate." + "--lr", type=float, default=0.001, help="Learning rate." + ) + parser.add_argument( + "--warmup_steps", type=int, default=2, help="Warmup epochs." + ) + parser.add_argument( + "--exp_decay", + type=float, + default=0.99, + help="Exp decay of learning rate.", ) # learning hyperams parser.add_argument( - "--n_epochs", type=int, default=50, help="Number of epochs per task." + "--n_epochs", + type=int, + default=50, + help="Number of epochs per task.", + ) + parser.add_argument( + "--batch_size", type=int, default=64, help="Batch size." ) - parser.add_argument("--batch_size", type=int, default=64, help="Batch size.") # deep ensembles parser.add_argument( - "--n_ensembles", type=int, default=5, help="Number of model in DeepEnsembles" + "--n_ensembles", + type=int, + default=5, + help="Number of model in DeepEnsembles", ) @@ -121,9 +151,13 @@ def add_management_args(parser: ArgumentParser) -> None: None: This function does not return a value. """ # random seed - parser.add_argument("--seed", type=int, default=None, help="The random seed.") + parser.add_argument( + "--seed", type=int, default=None, help="The random seed." 
+ ) # verbosity - parser.add_argument("--notes", type=str, default=None, help="Notes for this run.") + parser.add_argument( + "--notes", type=str, default=None, help="Notes for this run." + ) parser.add_argument("--non_verbose", action="store_true") # logging parser.add_argument( @@ -172,7 +206,10 @@ def add_management_args(parser: ArgumentParser) -> None: help="Used to preprocess dataset", ) parser.add_argument( - "--finetuning", type=int, default=0, help="Phase of active learning" + "--finetuning", + type=int, + default=0, + help="Phase of active learning", ) @@ -187,14 +224,22 @@ def add_test_args(parser: ArgumentParser) -> None: """ # random seed parser.add_argument( - "--use_ood", action="store_true", help="Use Out of Distribution test samples." + "--use_ood", + action="store_true", + help="Use Out of Distribution test samples.", ) # verbosity parser.add_argument( "--type", type=str, default="frequentist", - choices=["frequentist", "mcdropout", "ensemble", "laplace", "bears"], + choices=[ + "frequentist", + "mcdropout", + "ensemble", + "laplace", + "bears", + ], help="Evaluation type.", ) parser.add_argument( @@ -210,7 +255,10 @@ def add_test_args(parser: ArgumentParser) -> None: help="Employ a knowledge aware KL.", ) parser.add_argument( - "--real-kl", action="store_true", default=False, help="Real paper KL." + "--real-kl", + action="store_true", + default=False, + help="Real paper KL.", ) parser.add_argument( "--evaluate-all", @@ -220,7 +268,10 @@ def add_test_args(parser: ArgumentParser) -> None: ) # weight of deep separation kl parser.add_argument( - "--lambda_h", type=float, default=1, help="Lambda for the KL divergence" + "--lambda_h", + type=float, + default=1, + help="Lambda for the KL divergence", ) parser.add_argument( "--skip_laplace", diff --git a/XOR_MNIST/utils/bayes.py b/XOR_MNIST/utils/bayes.py index e00b8a2..ed123c4 100644 --- a/XOR_MNIST/utils/bayes.py +++ b/XOR_MNIST/utils/bayes.py @@ -1,27 +1,25 @@ # Bayes module +import math +from typing import List, Tuple + +import numpy as np import torch import torch.nn as nn -import numpy as np -from numpy import ndarray -from typing import Tuple, List -from utils.checkpoint import load_checkpoint - -from utils.metrics import evaluate_metrics, vector_to_parameters -from utils.conf import set_random_seed -from models import get_model -from warmup_scheduler import GradualWarmupScheduler +import wandb +from datasets import get_dataset from datasets.utils.base_dataset import BaseDataset - from laplace import Laplace -from utils import fprint -from datasets import get_dataset -from scipy.special import softmax +from models import get_model from models.utils.utils_problog import build_worlds_queries_matrix -import math +from numpy import ndarray +from scipy.special import softmax +from torch.utils.data import DataLoader, Dataset +from utils import fprint +from utils.checkpoint import load_checkpoint +from utils.conf import set_random_seed +from utils.metrics import evaluate_metrics, vector_to_parameters from utils.status import progress_bar -import wandb - -from torch.utils.data import Dataset, DataLoader +from warmup_scheduler import GradualWarmupScheduler class DatasetPcX(Dataset): @@ -232,18 +230,26 @@ def montecarlo_dropout_single_batch( # activate dropout during evaluation activate_dropout(model) - output_dicts = [model(batch_samples) for _ in range(num_mc_samples)] # 30 + output_dicts = [ + model(batch_samples) for _ in range(num_mc_samples) + ] # 30 label_prob = [ - out_dict["YS"].detach().cpu().numpy() for out_dict in 
output_dicts + out_dict["YS"].detach().cpu().numpy() + for out_dict in output_dicts ] # 30 concept_logit = [ - out_dict["CS"].detach().cpu().numpy() for out_dict in output_dicts + out_dict["CS"].detach().cpu().numpy() + for out_dict in output_dicts ] # 30 concept_prob = [ - out_dict["pCS"].detach().cpu().numpy() for out_dict in output_dicts + out_dict["pCS"].detach().cpu().numpy() + for out_dict in output_dicts + ] # 30 + ll = [ + out_dict["pCS"].detach().cpu().numpy() + for out_dict in output_dicts ] # 30 - ll = [out_dict["pCS"].detach().cpu().numpy() for out_dict in output_dicts] # 30 label_prob = np.stack(label_prob, axis=0) concept_logit = np.stack(concept_logit, axis=0) @@ -310,14 +316,24 @@ def montecarlo_dropout( c_pred = concept_logit_ens pc_pred = concept_prob_ens else: - y_true = np.concatenate([y_true, labels.detach().cpu().numpy()], axis=0) - c_true = np.concatenate([c_true, concepts.detach().cpu().numpy()], axis=0) + y_true = np.concatenate( + [y_true, labels.detach().cpu().numpy()], axis=0 + ) + c_true = np.concatenate( + [c_true, concepts.detach().cpu().numpy()], axis=0 + ) y_pred = np.concatenate([y_pred, label_prob_ens], axis=1) - c_pred = np.concatenate([c_pred, concept_logit_ens], axis=1) - pc_pred = np.concatenate([pc_pred, concept_prob_ens], axis=1) + c_pred = np.concatenate( + [c_pred, concept_logit_ens], axis=1 + ) + pc_pred = np.concatenate( + [pc_pred, concept_prob_ens], axis=1 + ) # Compute the final arrangements - gs = np.split(c_true, c_true.shape[1], axis=1) # splitted groundtruth concepts + gs = np.split( + c_true, c_true.shape[1], axis=1 + ) # splitted groundtruth concepts cs = np.split( c_pred, c_pred.shape[2], axis=2 ) # splitted concepts # (nmod, data, 10, 10) @@ -334,15 +350,21 @@ def montecarlo_dropout( gs = np.where(gs == "-1-1", -1, gs) gs = gs.squeeze(-1).astype(int) - p_cs_1 = np.expand_dims(p_cs[0].squeeze(2), axis=-1) # 30, 256, 10, 1 - p_cs_2 = np.expand_dims(p_cs[1].squeeze(2), axis=-2) # 30, 256, 1, 10 + p_cs_1 = np.expand_dims( + p_cs[0].squeeze(2), axis=-1 + ) # 30, 256, 10, 1 + p_cs_2 = np.expand_dims( + p_cs[1].squeeze(2), axis=-2 + ) # 30, 256, 1, 10 p_cs = np.matmul( p_cs_1, p_cs_2 ) # 30, 256, 10, 10 -> # [#modelli, #data, #facts, #facts] p_cs = np.reshape( p_cs, (*p_cs.shape[:-2], p_cs.shape[-1] * p_cs.shape[-2]) ) # -> [#modelli, #data, #facts^2] - p_cs = np.mean(p_cs, axis=0) # avg[#modelli, #data, #facts^2] = [#data, #facts^2] + p_cs = np.mean( + p_cs, axis=0 + ) # avg[#modelli, #data, #facts^2] = [#data, #facts^2] # mean probabilities of the output p_ys = np.mean(y_pred, axis=0) # -> (256, 19) @@ -350,8 +372,12 @@ def montecarlo_dropout( p_ys_full = p_ys # all the items of probabilities are considered (#data, #facts^2) p_cs_full = p_cs # all the items of probabilities are considered (#data, #facts^2) - cs = p_cs.argmax(axis=1) # the predicted concept is the argument maximum - p_cs = p_cs.max(axis=1) # only the maximum one is considered (#data,) + cs = p_cs.argmax( + axis=1 + ) # the predicted concept is the argument maximum + p_cs = p_cs.max( + axis=1 + ) # only the maximum one is considered (#data,) p_ys = p_ys.max(axis=1) # only the maximum one is considered assert gs.shape == cs.shape, f"gs: {gs.shape}, cs: {cs.shape}" @@ -427,7 +453,9 @@ def early_stop(self, model, validation_loss): self.counter = 0 self.best_weights = model.state_dict() self.stuck = False - elif validation_loss > (self.min_validation_loss + self.min_delta): + elif validation_loss > ( + self.min_validation_loss + self.min_delta + ): self.counter += 1 if 
self.counter >= self.patience: model.load_state_dict(self.best_weights) @@ -487,7 +515,9 @@ def deep_ensemble_active( # lambda for kl lambda_h = 0.01 - def kl_paper(p_model: torch.tensor, p_rest: torch.tensor, k, last_hope=True): + def kl_paper( + p_model: torch.tensor, p_rest: torch.tensor, k, last_hope=True + ): """KL of the paper Args: @@ -507,17 +537,23 @@ def kl_paper(p_model: torch.tensor, p_rest: torch.tensor, k, last_hope=True): p_rest = p_rest / (1 + (p_model.shape[1] * 1e-5)) if last_hope: - kl_ew = torch.sum(p_model * torch.log(p_model + (k - 1) * p_rest), dim=1) + kl_ew = torch.sum( + p_model * torch.log(p_model + (k - 1) * p_rest), dim=1 + ) else: ratio = torch.div(p_rest, p_model) - kl_ew = torch.sum(p_model * torch.log(1 + (k - 1) * ratio), dim=1) + kl_ew = torch.sum( + p_model * torch.log(1 + (k - 1) * ratio), dim=1 + ) return torch.mean(kl_ew, dim=0) dataset = get_dataset(args) if len(supervision) > 0: - dataset.give_supervision_to(supervision[0], supervision[1], supervision[2]) + dataset.give_supervision_to( + supervision[0], supervision[1], supervision[2] + ) # Load dataset, model, loss, and optimizer encoder, decoder = dataset.get_backbone() @@ -547,8 +583,12 @@ def kl_paper(p_model: torch.tensor, p_rest: torch.tensor, k, last_hope=True): tag = ["shapes", "colors"] for i in range(2): - all_c = np.concatenate((cs[i][:, 0], cs[i][:, 1], cs[i][:, 2])) - all_g = np.concatenate((gs[i][:, 0], gs[i][:, 1], gs[i][:, 2])) + all_c = np.concatenate( + (cs[i][:, 0], cs[i][:, 1], cs[i][:, 2]) + ) + all_g = np.concatenate( + (gs[i][:, 0], gs[i][:, 1], gs[i][:, 2]) + ) if use_wandb: wandb.log( @@ -562,11 +602,19 @@ def kl_paper(p_model: torch.tensor, p_rest: torch.tensor, k, last_hope=True): cs = np.split(c, 2, axis=1) gs = np.split(g, 2, axis=1) - shapes_pred = np.concatenate((cs[0][:, 0], cs[0][:, 1], cs[0][:, 2])) - shapes_true = np.concatenate((gs[0][:, 0], gs[0][:, 1], gs[0][:, 2])) + shapes_pred = np.concatenate( + (cs[0][:, 0], cs[0][:, 1], cs[0][:, 2]) + ) + shapes_true = np.concatenate( + (gs[0][:, 0], gs[0][:, 1], gs[0][:, 2]) + ) - colors_pred = np.concatenate((cs[1][:, 0], cs[1][:, 1], cs[1][:, 2])) - colors_true = np.concatenate((gs[1][:, 0], gs[1][:, 1], gs[1][:, 2])) + colors_pred = np.concatenate( + (cs[1][:, 0], cs[1][:, 1], cs[1][:, 2]) + ) + colors_true = np.concatenate( + (gs[1][:, 0], gs[1][:, 1], gs[1][:, 2]) + ) all_c = shapes_pred * 3 + colors_pred all_g = shapes_true * 3 + colors_true @@ -597,8 +645,12 @@ def kl_paper(p_model: torch.tensor, p_rest: torch.tensor, k, last_hope=True): model.to(model.device) train_loader, val_loader, _ = dataset.get_data_loaders() - scheduler = torch.optim.lr_scheduler.ExponentialLR(model.opt, args.exp_decay) - w_scheduler = GradualWarmupScheduler(model.opt, 1.0, args.warmup_steps) + scheduler = torch.optim.lr_scheduler.ExponentialLR( + model.opt, args.exp_decay + ) + w_scheduler = GradualWarmupScheduler( + model.opt, 1.0, args.warmup_steps + ) # default for warm-up model.opt.zero_grad() @@ -618,7 +670,11 @@ def kl_paper(p_model: torch.tensor, p_rest: torch.tensor, k, last_hope=True): # Generate adversarial examples using x_batch out_dict = model(images) out_dict.update( - {"INPUTS": images, "LABELS": labels, "CONCEPTS": concepts} + { + "INPUTS": images, + "LABELS": labels, + "CONCEPTS": concepts, + } ) model.opt.zero_grad() @@ -628,9 +684,13 @@ def kl_paper(p_model: torch.tensor, p_rest: torch.tensor, k, last_hope=True): c_worlds = [] for i in range(3): - list_concepts = torch.split(c_prb[:, i, :], 3, dim=-1) + list_concepts = 
torch.split( + c_prb[:, i, :], 3, dim=-1 + ) p_w_image = ( - outer_product(*list_concepts).unsqueeze(1).view(-1, 1, 3**6) + outer_product(*list_concepts) + .unsqueeze(1) + .view(-1, 1, 3**6) ) c_worlds.append(p_w_image) @@ -654,7 +714,9 @@ def kl_paper(p_model: torch.tensor, p_rest: torch.tensor, k, last_hope=True): other_c_worlds = [] for i in range(3): - list_concepts = torch.split(c_prb[:, i, :], 3, dim=-1) + list_concepts = torch.split( + c_prb[:, i, :], 3, dim=-1 + ) p_w_image = ( outer_product(*list_concepts) .unsqueeze(1) @@ -669,14 +731,18 @@ def kl_paper(p_model: torch.tensor, p_rest: torch.tensor, k, last_hope=True): pcx_list_ensemble.append(c_prb_other) pcx_list_ensemble = torch.stack(pcx_list_ensemble) - other_m_pc_x_mean = torch.mean(pcx_list_ensemble, dim=0) + other_m_pc_x_mean = torch.mean( + pcx_list_ensemble, dim=0 + ) for i in range(3): p_t = model_itself_pc_x[:, i, :] p_ens = other_m_pc_x_mean[:, i, :] distance = lambda_h * ( 1 - + kl_paper(p_t, p_ens, len(ensemble) + 1, True) + + kl_paper( + p_t, p_ens, len(ensemble) + 1, True + ) / (6 * math.log(3)) ) # remove last hope @@ -700,7 +766,12 @@ def kl_paper(p_model: torch.tensor, p_rest: torch.tensor, k, last_hope=True): model.opt.step() if ti % 10 == 0: - progress_bar(ti, len(train_loader) - 9, epoch, loss_original.item()) + progress_bar( + ti, + len(train_loader) - 9, + epoch, + loss_original.item(), + ) # update at end of the epoch if epoch < args.warmup_steps: @@ -713,9 +784,16 @@ def kl_paper(p_model: torch.tensor, p_rest: torch.tensor, k, last_hope=True): model.eval() # Evaluate performances on VAL - y_true, c_true, y_pred, c_pred, p_cs, p_ys, p_cs_all, p_ys_all = ( - evaluate_metrics(model, val_loader, args, last=True) - ) + ( + y_true, + c_true, + y_pred, + c_pred, + p_cs, + p_ys, + p_cs_all, + p_ys_all, + ) = evaluate_metrics(model, val_loader, args, last=True) if True: import os @@ -772,8 +850,12 @@ def kl_paper(p_model: torch.tensor, p_rest: torch.tensor, k, last_hope=True): tag = ["shapes", "colors"] for i in range(2): - all_c = np.concatenate((cs[i][:, 0], cs[i][:, 1], cs[i][:, 2])) - all_g = np.concatenate((gs[i][:, 0], gs[i][:, 1], gs[i][:, 2])) + all_c = np.concatenate( + (cs[i][:, 0], cs[i][:, 1], cs[i][:, 2]) + ) + all_g = np.concatenate( + (gs[i][:, 0], gs[i][:, 1], gs[i][:, 2]) + ) wandb.log( { @@ -786,11 +868,19 @@ def kl_paper(p_model: torch.tensor, p_rest: torch.tensor, k, last_hope=True): cs = np.split(c, 2, axis=1) gs = np.split(g, 2, axis=1) - shapes_pred = np.concatenate((cs[0][:, 0], cs[0][:, 1], cs[0][:, 2])) - shapes_true = np.concatenate((gs[0][:, 0], gs[0][:, 1], gs[0][:, 2])) + shapes_pred = np.concatenate( + (cs[0][:, 0], cs[0][:, 1], cs[0][:, 2]) + ) + shapes_true = np.concatenate( + (gs[0][:, 0], gs[0][:, 1], gs[0][:, 2]) + ) - colors_pred = np.concatenate((cs[1][:, 0], cs[1][:, 1], cs[1][:, 2])) - colors_true = np.concatenate((gs[1][:, 0], gs[1][:, 1], gs[1][:, 2])) + colors_pred = np.concatenate( + (cs[1][:, 0], cs[1][:, 1], cs[1][:, 2]) + ) + colors_true = np.concatenate( + (gs[1][:, 0], gs[1][:, 1], gs[1][:, 2]) + ) all_c = shapes_pred * 3 + colors_pred all_g = shapes_true * 3 + colors_true @@ -846,8 +936,8 @@ def deep_ensemble( Returns: ensemble: models ensemble """ - from datasets.utils.base_dataset import get_loader import wandb + from datasets.utils.base_dataset import get_loader def wandb_log_step_resense(i, epoch, loss_ce, loss_kl, prefix): wandb.log( @@ -871,7 +961,11 @@ def wandb_log_step_deep_ens(i, epoch, loss_ce, loss_adv, prefix): def wandb_log_val(i, epoch, loss, 
prefix): wandb.log( - {f"{prefix}loss-val": loss, f"{prefix}epoch": epoch, f"{prefix}step": i} + { + f"{prefix}loss-val": loss, + f"{prefix}epoch": epoch, + f"{prefix}step": i, + } ) def wandb_log_epoch(**kwargs): @@ -900,7 +994,9 @@ def wandb_log_epoch(**kwargs): if separate_from_others: print("Doing a separation with KL...") - def kl_paper(p_model: torch.tensor, p_rest: torch.tensor, k, last_hope=True): + def kl_paper( + p_model: torch.tensor, p_rest: torch.tensor, k, last_hope=True + ): p_model = p_model + 1e-5 p_rest = p_rest + 1e-5 p_model = p_model / (1 + (p_model.shape[1] * 1e-5)) @@ -908,9 +1004,13 @@ def kl_paper(p_model: torch.tensor, p_rest: torch.tensor, k, last_hope=True): if not last_hope: ratio = torch.div(p_rest, p_model) - kl_ew = torch.sum(p_model * torch.log(1 + (k - 1) * ratio), dim=1) + kl_ew = torch.sum( + p_model * torch.log(1 + (k - 1) * ratio), dim=1 + ) else: - kl_ew = torch.sum(p_model * torch.log(p_model + (k - 1) * p_rest), dim=1) + kl_ew = torch.sum( + p_model * torch.log(p_model + (k - 1) * p_rest), dim=1 + ) return torch.mean(kl_ew, dim=0) @@ -927,7 +1027,10 @@ def kl_paper(p_model: torch.tensor, p_rest: torch.tensor, k, last_hope=True): n_images, c_split = dataset.get_split() train_loader, val_loader, _ = dataset.get_data_loaders() pcx_loader = get_loader( - dataset.dataset_train, args.batch_size, num_workers=4, val_test=True + dataset.dataset_train, + args.batch_size, + num_workers=4, + val_test=True, ) # model @@ -936,8 +1039,12 @@ def kl_paper(p_model: torch.tensor, p_rest: torch.tensor, k, last_hope=True): model.start_optim(args) model.to(model.device) - scheduler = torch.optim.lr_scheduler.ExponentialLR(model.opt, args.exp_decay) - w_scheduler = GradualWarmupScheduler(model.opt, 1.0, args.warmup_steps) + scheduler = torch.optim.lr_scheduler.ExponentialLR( + model.opt, args.exp_decay + ) + w_scheduler = GradualWarmupScheduler( + model.opt, 1.0, args.warmup_steps + ) # Training loop for one model in the ensemble model.train() @@ -950,7 +1057,9 @@ def kl_paper(p_model: torch.tensor, p_rest: torch.tensor, k, last_hope=True): wq = wq.to(model.device) # early stopper - early_stopper = EarlyStopper(patience=5, min_delta=0.001) # prev 0.01 + early_stopper = EarlyStopper( + patience=5, min_delta=0.001 + ) # prev 0.01 for epoch in range(num_epochs): model.train() @@ -971,7 +1080,11 @@ def kl_paper(p_model: torch.tensor, p_rest: torch.tensor, k, last_hope=True): # Generate adversarial examples using x_batch out_dict = model(images) out_dict.update( - {"INPUTS": images, "LABELS": labels, "CONCEPTS": concepts} + { + "INPUTS": images, + "LABELS": labels, + "CONCEPTS": concepts, + } ) loss_original, _ = criterion(out_dict, args) @@ -981,15 +1094,23 @@ def kl_paper(p_model: torch.tensor, p_rest: torch.tensor, k, last_hope=True): loss_original.backward() # Generate adversarial examples - adversarial_batch = images + epsilon * images.grad.sign() + adversarial_batch = ( + images + epsilon * images.grad.sign() + ) # Compute adversarial loss out_dict_adversarial = model(adversarial_batch) out_dict_adversarial.update( - {"INPUTS": images, "LABELS": labels, "CONCEPTS": concepts} + { + "INPUTS": images, + "LABELS": labels, + "CONCEPTS": concepts, + } ) - loss_adversarial, _ = criterion(out_dict_adversarial, args) + loss_adversarial, _ = criterion( + out_dict_adversarial, args + ) # Minimize the combined loss l(θm, xbatch, ybatch) + l(θm, advbatch, advbatch) w.r.t. 
θm loss_adversarial.backward() @@ -1011,16 +1132,24 @@ def kl_paper(p_model: torch.tensor, p_rest: torch.tensor, k, last_hope=True): from models.utils.ops import outer_product - model_itself_pc_x = outer_product(c_prb_1, c_prb_2).view( - c_prb_1.shape[0], c_prb_1.shape[1] * c_prb_1.shape[1] + model_itself_pc_x = outer_product( + c_prb_1, c_prb_2 + ).view( + c_prb_1.shape[0], + c_prb_1.shape[1] * c_prb_1.shape[1], ) total_dist = 0 if len(ensemble) > 0: if knowledge_aware_kl: - model_itself_pc_x = compute_pw_knowledge_filter( - c_prb_1=c_prb_1, c_prb_2=c_prb_2, labels=labels, wq=wq + model_itself_pc_x = ( + compute_pw_knowledge_filter( + c_prb_1=c_prb_1, + c_prb_2=c_prb_2, + labels=labels, + wq=wq, + ) ) # Create the ensembles world probabilities @@ -1040,8 +1169,12 @@ def kl_paper(p_model: torch.tensor, p_rest: torch.tensor, k, last_hope=True): pcx_list_ensemble.append(other_pcx) if real_kl: - pcx_list_ensemble = torch.stack(pcx_list_ensemble) - other_m_pc_x_mean = torch.mean(pcx_list_ensemble, dim=0) + pcx_list_ensemble = torch.stack( + pcx_list_ensemble + ) + other_m_pc_x_mean = torch.mean( + pcx_list_ensemble, dim=0 + ) distance = lambda_h * kl_paper( model_itself_pc_x, other_m_pc_x_mean, @@ -1049,11 +1182,16 @@ def kl_paper(p_model: torch.tensor, p_rest: torch.tensor, k, last_hope=True): False, ) else: - crit_kl = nn.KLDivLoss(reduction="batchmean")() + crit_kl = nn.KLDivLoss( + reduction="batchmean" + )() distance = -lambda_h * torch.mean( torch.stack( list( - crit_kl(model_itself_pc_x.log(), other_m_pc_x) + crit_kl( + model_itself_pc_x.log(), + other_m_pc_x, + ) for other_m_pc_x in pcx_list_ensemble ) ), @@ -1098,7 +1236,11 @@ def kl_paper(p_model: torch.tensor, p_rest: torch.tensor, k, last_hope=True): out_dict = model(images) out_dict.update( - {"INPUTS": images, "LABELS": labels, "CONCEPTS": concepts} + { + "INPUTS": images, + "LABELS": labels, + "CONCEPTS": concepts, + } ) curr_val_loss, _ = criterion(out_dict, args) @@ -1115,7 +1257,9 @@ def kl_paper(p_model: torch.tensor, p_rest: torch.tensor, k, last_hope=True): validation_loss = validation_loss / len(val_loader) - tloss, cacc, yacc, f1 = evaluate_metrics(model, val_loader, args) + tloss, cacc, yacc, f1 = evaluate_metrics( + model, val_loader, args + ) if use_wandb is not None: prefix = ( f"seed_{seed}_biretta-" @@ -1131,7 +1275,9 @@ def kl_paper(p_model: torch.tensor, p_rest: torch.tensor, k, last_hope=True): lr=float(scheduler.get_last_lr()[0]), ) - fprint(f"># Epoch {epoch}: val loss equal to {validation_loss}") + fprint( + f"># Epoch {epoch}: val loss equal to {validation_loss}" + ) if early_stopper.early_stop(model, validation_loss): break @@ -1185,7 +1331,9 @@ def update_pcx_dataset(model, dataset, pcx_loader, batch_size): c_prb_2 = c_prb[:, 1, :] to_append = torch.chunk( - torch.cat((c_prb_1, c_prb_2), axis=1), images.size(0), dim=0 + torch.cat((c_prb_1, c_prb_2), axis=1), + images.size(0), + dim=0, ) j = 0 @@ -1230,7 +1378,9 @@ def populate_pcx_dataset( # Append the new tensor c_prb_tmp = torch.chunk( - torch.cat((c_prb_1, c_prb_2), axis=1), images.size(0), dim=0 + torch.cat((c_prb_1, c_prb_2), axis=1), + images.size(0), + dim=0, ) tmp_img_list = torch.chunk(images, images.size(0), dim=0) @@ -1342,12 +1492,16 @@ def populate_pcx_dataset_knowledge_aware( for sublist in w_prob_tmp: w_prb_filtered_list.append([sublist]) - dataset = DatasetPcX(images=images_list, pcx=w_prb_filtered_list, wq=wq) + dataset = DatasetPcX( + images=images_list, pcx=w_prb_filtered_list, wq=wq + ) return dataset -def 
update_pcx_dataset_knowledge_aware(model, dataset, pcx_loader, n_facts): +def update_pcx_dataset_knowledge_aware( + model, dataset, pcx_loader, n_facts +): """Update the dataset with knowledge aware p(w|x) Args: @@ -1424,12 +1578,16 @@ def get_predictions(model, loader): ) out_dict = model(images) - out_dict.update({"INPUTS": images, "LABELS": labels, "CONCEPTS": concepts}) + out_dict.update( + {"INPUTS": images, "LABELS": labels, "CONCEPTS": concepts} + ) return out_dict def ensemble_single_predict( - models: List[nn.Module], batch_samples: torch.tensor, apply_softmax=False + models: List[nn.Module], + batch_samples: torch.tensor, + apply_softmax=False, ): """Single prediction from an ensemble of models @@ -1449,9 +1607,18 @@ def ensemble_single_predict( output_dicts = [model(batch_samples) for model in models] # get out the different output - label_prob = [out_dict["YS"].detach().cpu().numpy() for out_dict in output_dicts] - concept_logit = [out_dict["CS"].detach().cpu().numpy() for out_dict in output_dicts] - concept_prob = [out_dict["pCS"].detach().cpu().numpy() for out_dict in output_dicts] + label_prob = [ + out_dict["YS"].detach().cpu().numpy() + for out_dict in output_dicts + ] + concept_logit = [ + out_dict["CS"].detach().cpu().numpy() + for out_dict in output_dicts + ] + concept_prob = [ + out_dict["pCS"].detach().cpu().numpy() + for out_dict in output_dicts + ] label_prob_ens = np.stack(label_prob, axis=0) concept_logit_ens = np.stack(concept_logit, axis=0) @@ -1464,7 +1631,9 @@ def ensemble_single_predict( def ensemble_single_la_predict( - models: List[nn.Module], batch_samples: torch.tensor, apply_softmax=False + models: List[nn.Module], + batch_samples: torch.tensor, + apply_softmax=False, ): """Single prediction from an ensemble of Laplace models @@ -1484,9 +1653,18 @@ def ensemble_single_la_predict( output_dicts = [model(batch_samples) for model in models] # get out the different output - label_prob = [out_dict["YS"].detach().cpu().numpy() for out_dict in output_dicts] - concept_logit = [out_dict["CS"].detach().cpu().numpy() for out_dict in output_dicts] - concept_prob = [out_dict["pCS"].detach().cpu().numpy() for out_dict in output_dicts] + label_prob = [ + out_dict["YS"].detach().cpu().numpy() + for out_dict in output_dicts + ] + concept_logit = [ + out_dict["CS"].detach().cpu().numpy() + for out_dict in output_dicts + ] + concept_prob = [ + out_dict["pCS"].detach().cpu().numpy() + for out_dict in output_dicts + ] label_prob_ens = np.stack(label_prob, axis=0) concept_logit_ens = np.stack(concept_logit, axis=0) @@ -1533,8 +1711,8 @@ def ensemble_predict( ) # Call Ensemble predict - (label_prob_ens, concept_logit_ens, concept_prob_ens) = ensemble_single_predict( - ensemble, images, apply_softmax + (label_prob_ens, concept_logit_ens, concept_prob_ens) = ( + ensemble_single_predict(ensemble, images, apply_softmax) ) # Concatenate the output @@ -1545,14 +1723,24 @@ def ensemble_predict( c_pred = concept_logit_ens pc_pred = concept_prob_ens else: - y_true = np.concatenate([y_true, labels.detach().cpu().numpy()], axis=0) - c_true = np.concatenate([c_true, concepts.detach().cpu().numpy()], axis=0) + y_true = np.concatenate( + [y_true, labels.detach().cpu().numpy()], axis=0 + ) + c_true = np.concatenate( + [c_true, concepts.detach().cpu().numpy()], axis=0 + ) y_pred = np.concatenate([y_pred, label_prob_ens], axis=1) - c_pred = np.concatenate([c_pred, concept_logit_ens], axis=1) - pc_pred = np.concatenate([pc_pred, concept_prob_ens], axis=1) + c_pred = np.concatenate( + [c_pred, 
concept_logit_ens], axis=1 + ) + pc_pred = np.concatenate( + [pc_pred, concept_prob_ens], axis=1 + ) # Compute the final arrangements - gs = np.split(c_true, c_true.shape[1], axis=1) # splitted groundtruth concepts + gs = np.split( + c_true, c_true.shape[1], axis=1 + ) # splitted groundtruth concepts cs = np.split( c_pred, c_pred.shape[2], axis=2 ) # splitted concepts # (nmod, data, 10, 10) @@ -1569,15 +1757,21 @@ def ensemble_predict( gs = np.where(gs == "-1-1", -1, gs) gs = gs.squeeze(-1).astype(int) - p_cs_1 = np.expand_dims(p_cs[0].squeeze(2), axis=-1) # 30, 256, 10, 1 - p_cs_2 = np.expand_dims(p_cs[1].squeeze(2), axis=-2) # 30, 256, 1, 10 + p_cs_1 = np.expand_dims( + p_cs[0].squeeze(2), axis=-1 + ) # 30, 256, 10, 1 + p_cs_2 = np.expand_dims( + p_cs[1].squeeze(2), axis=-2 + ) # 30, 256, 1, 10 p_cs = np.matmul( p_cs_1, p_cs_2 ) # 30, 256, 10, 10 -> # [#modelli, #data, #facts, #facts] p_cs = np.reshape( p_cs, (*p_cs.shape[:-2], p_cs.shape[-1] * p_cs.shape[-2]) ) # -> [#modelli, #data, #facts^2] - p_cs = np.mean(p_cs, axis=0) # avg[#modelli, #data, #facts^2] = [#data, #facts^2] + p_cs = np.mean( + p_cs, axis=0 + ) # avg[#modelli, #data, #facts^2] = [#data, #facts^2] # mean probabilities of the output p_ys = np.mean(y_pred, axis=0) # -> (256, 19) @@ -1585,8 +1779,12 @@ def ensemble_predict( p_ys_full = p_ys # all the items of probabilities are considered (#data, #facts^2) p_cs_full = p_cs # all the items of probabilities are considered (#data, #facts^2) - cs = p_cs.argmax(axis=1) # the predicted concept is the argument maximum - p_cs = p_cs.max(axis=1) # only the maximum one is considered (#data,) + cs = p_cs.argmax( + axis=1 + ) # the predicted concept is the argument maximum + p_cs = p_cs.max( + axis=1 + ) # only the maximum one is considered (#data,) p_ys = p_ys.max(axis=1) # only the maximum one is considered assert gs.shape == cs.shape, f"gs: {gs.shape}, cs: {cs.shape}" @@ -1612,7 +1810,9 @@ def ensemble_predict( ) -def laplace_approximation(model: nn.Module, device, train_loader, val_loader): +def laplace_approximation( + model: nn.Module, device, train_loader, val_loader +): """Performs the Laplace Approximation Args: @@ -1624,8 +1824,8 @@ def laplace_approximation(model: nn.Module, device, train_loader, val_loader): Returns: la: laplace model """ - from torch.utils.data import DataLoader from laplace.curvature import AsdlGGN + from torch.utils.data import DataLoader # Wrapper DataLoader class WrapperDataLoader(DataLoader): @@ -1656,10 +1856,15 @@ def __iter__(self): iter: dataset iterator """ # Get the iterator from the original DataLoader - original_iterator = super(WrapperDataLoader, self).__iter__() + original_iterator = super( + WrapperDataLoader, self + ).__iter__() for original_batch in original_iterator: - modified_batch = [original_batch[0], original_batch[1].to(torch.long)] + modified_batch = [ + original_batch[0], + original_batch[1].to(torch.long), + ] yield modified_batch # Wrapper Model @@ -1713,7 +1918,12 @@ def forward(self, input_batch): # I want to flat all the tensors in this way: return torch.cat( - (ys, py.reshape(batch_size, -1), pCS.reshape(batch_size, -1)), dim=1 + ( + ys, + py.reshape(batch_size, -1), + pCS.reshape(batch_size, -1), + ), + dim=1, ) def get_ensembles(self, la_model, n_models): @@ -1732,15 +1942,21 @@ def get_ensembles(self, la_model, n_models): ensembles = [] for i, mp in enumerate(self.model_possibilities): # substituting to the current model one of the possible parameters - vector_to_parameters(mp, la_model.model.last_layer.parameters()) + 
vector_to_parameters( + mp, la_model.model.last_layer.parameters() + ) # Retrieve the current model and append it - ensembles.append(copy.deepcopy(la_model.model.model.original_model)) + ensembles.append( + copy.deepcopy(la_model.model.model.original_model) + ) if i == n_models - 1: break # restore original model - vector_to_parameters(la_model.mean, la_model.model.last_layer.parameters()) + vector_to_parameters( + la_model.mean, la_model.model.last_layer.parameters() + ) # return an ensembles of models return ensembles @@ -1762,7 +1978,9 @@ def get_ensembles(self, la_model, n_models): ) la.fit(la_training_loader) - la.optimize_prior_precision(method="marglik", val_loader=la_val_loader) + la.optimize_prior_precision( + method="marglik", val_loader=la_val_loader + ) # Enabling last layer output all la.model.model.output_all = True @@ -1823,7 +2041,11 @@ def sample(self, n_samples=100): It seems like it batch multply the samples (from a gaussian centered in zero with 1 as variance), and then rescale it """ recovered_pred = recover_predictions_from_laplace( - pred, sample_batch.shape[0], output_classes, num_concepts, apply_softmax + pred, + sample_batch.shape[0], + output_classes, + num_concepts, + apply_softmax, ) return recovered_pred @@ -1849,15 +2071,21 @@ def recover_predictions_from_laplace( out_dict: dictionary of predictions """ # Recovering shape - ys = la_prediction[:, :output_classes] # take all until output_classes + ys = la_prediction[ + :, :output_classes + ] # take all until output_classes py = la_prediction[ :, output_classes : output_classes + 2 * num_concepts ] # take all from output_classes until output_classes+2*num_concepts - py = py.reshape(batch_size, 2, num_concepts) # reshape it correctly + py = py.reshape( + batch_size, 2, num_concepts + ) # reshape it correctly pCS = la_prediction[ :, output_classes + 2 * num_concepts : ] # take all from the previous to the end - pCS = pCS.reshape(batch_size, 2, num_concepts) # reshape it correctly + pCS = pCS.reshape( + batch_size, 2, num_concepts + ) # reshape it correctly if apply_softmax: import torch.nn.functional as F @@ -1911,15 +2139,23 @@ def laplace_prediction( # prediction _ = laplace_single_prediction( - laplace_model, images, output_classes, num_concepts, apply_softmax + laplace_model, + images, + output_classes, + num_concepts, + apply_softmax, ) # Call Laplace ensembles - ensemble = laplace_model.model.model.get_ensembles(laplace_model, n_ensembles) + ensemble = laplace_model.model.model.get_ensembles( + laplace_model, n_ensembles + ) # Call Ensemble predict (label_prob_ens, concept_logit_ens, concept_prob_ens) = ( - ensemble_single_la_predict(ensemble, images, apply_softmax) + ensemble_single_la_predict( + ensemble, images, apply_softmax + ) ) # Concatenate the output @@ -1930,14 +2166,24 @@ def laplace_prediction( c_pred = concept_logit_ens pc_pred = concept_prob_ens else: - y_true = np.concatenate([y_true, labels.detach().cpu().numpy()], axis=0) - c_true = np.concatenate([c_true, concepts.detach().cpu().numpy()], axis=0) + y_true = np.concatenate( + [y_true, labels.detach().cpu().numpy()], axis=0 + ) + c_true = np.concatenate( + [c_true, concepts.detach().cpu().numpy()], axis=0 + ) y_pred = np.concatenate([y_pred, label_prob_ens], axis=1) - c_pred = np.concatenate([c_pred, concept_logit_ens], axis=1) - pc_pred = np.concatenate([pc_pred, concept_prob_ens], axis=1) + c_pred = np.concatenate( + [c_pred, concept_logit_ens], axis=1 + ) + pc_pred = np.concatenate( + [pc_pred, concept_prob_ens], axis=1 + ) # Compute the 
final arrangements - gs = np.split(c_true, c_true.shape[1], axis=1) # splitted groundtruth concepts + gs = np.split( + c_true, c_true.shape[1], axis=1 + ) # splitted groundtruth concepts cs = np.split( c_pred, c_pred.shape[2], axis=2 ) # splitted concepts # (nmod, data, 10, 10) @@ -1954,15 +2200,21 @@ def laplace_prediction( gs = np.where(gs == "-1-1", -1, gs) gs = gs.squeeze(-1).astype(int) - p_cs_1 = np.expand_dims(p_cs[0].squeeze(2), axis=-1) # 30, 256, 10, 1 - p_cs_2 = np.expand_dims(p_cs[1].squeeze(2), axis=-2) # 30, 256, 1, 10 + p_cs_1 = np.expand_dims( + p_cs[0].squeeze(2), axis=-1 + ) # 30, 256, 10, 1 + p_cs_2 = np.expand_dims( + p_cs[1].squeeze(2), axis=-2 + ) # 30, 256, 1, 10 p_cs = np.matmul( p_cs_1, p_cs_2 ) # 30, 256, 10, 10 -> # [#modelli, #data, #facts, #facts] p_cs = np.reshape( p_cs, (*p_cs.shape[:-2], p_cs.shape[-1] * p_cs.shape[-2]) ) # -> [#modelli, #data, #facts^2] - p_cs = np.mean(p_cs, axis=0) # avg[#modelli, #data, #facts^2] = [#data, #facts^2] + p_cs = np.mean( + p_cs, axis=0 + ) # avg[#modelli, #data, #facts^2] = [#data, #facts^2] # mean probabilities of the output p_ys = np.mean(y_pred, axis=0) # -> (256, 19) @@ -1970,8 +2222,12 @@ def laplace_prediction( p_ys_full = p_ys # all the items of probabilities are considered (#data, #facts^2) p_cs_full = p_cs # all the items of probabilities are considered (#data, #facts^2) - cs = p_cs.argmax(axis=1) # the predicted concept is the argument maximum - p_cs = p_cs.max(axis=1) # only the maximum one is considered (#data,) + cs = p_cs.argmax( + axis=1 + ) # the predicted concept is the argument maximum + p_cs = p_cs.max( + axis=1 + ) # only the maximum one is considered (#data,) p_ys = p_ys.max(axis=1) # only the maximum one is considered assert gs.shape == cs.shape, f"gs: {gs.shape}, cs: {cs.shape}" diff --git a/XOR_MNIST/utils/checkpoint.py b/XOR_MNIST/utils/checkpoint.py index 1813e1c..8fdf1e8 100644 --- a/XOR_MNIST/utils/checkpoint.py +++ b/XOR_MNIST/utils/checkpoint.py @@ -1,6 +1,7 @@ # Checkpoint module -import torch import os + +import torch from utils.conf import create_path @@ -14,7 +15,10 @@ def _get_tag(args): tag (str): tag for the model name """ tag = "dis" if not args.joint else "joint" - if args.task == "product" and args.model in ["mnistsl", "mnistslrec"]: + if args.task == "product" and args.model in [ + "mnistsl", + "mnistslrec", + ]: tag = tag + "-prod" if args.task == "multiop": tag = tag + "-multiop" @@ -82,7 +86,9 @@ def get_model_name(args): Returns: name (str): name of the model """ - return f"dset_{args.dataset}-model_{args.model}-tag_{_get_tag(args)}" + return ( + f"dset_{args.dataset}-model_{args.model}-tag_{_get_tag(args)}" + ) def load_checkpoint(model, args, checkin=None): @@ -105,7 +111,9 @@ def load_checkpoint(model, args, checkin=None): PATH = f"data/ckpts/{args.dataset}-{args.model}-{tag}-{args.seed}-end.pt" if not os.path.exists(PATH): - raise ValueError(f"You have to train the model first, missing {PATH}") + raise ValueError( + f"You have to train the model first, missing {PATH}" + ) print("Loaded", PATH, "\n") print("Path", PATH) diff --git a/XOR_MNIST/utils/conf.py b/XOR_MNIST/utils/conf.py index e1b3177..65846ff 100644 --- a/XOR_MNIST/utils/conf.py +++ b/XOR_MNIST/utils/conf.py @@ -1,9 +1,10 @@ # config module +import os import random -import torch + import numpy as np -import os +import torch def get_device() -> torch.device: @@ -13,7 +14,9 @@ def get_device() -> torch.device: device: device """ # return torch.device('cpu') #debug - return torch.device("cuda" if 
torch.cuda.is_available() else "cpu") + return torch.device( + "cuda" if torch.cuda.is_available() else "cpu" + ) def base_path() -> str: diff --git a/XOR_MNIST/utils/generative.py b/XOR_MNIST/utils/generative.py index 7a8ead3..e647d88 100644 --- a/XOR_MNIST/utils/generative.py +++ b/XOR_MNIST/utils/generative.py @@ -17,17 +17,25 @@ def conditional_gen(model, pC=None): # select whether generate at random or not if pC is None: pC = 5 * torch.randn( - (8, model.n_images, model.encoder.c_dim), device=model.device + (8, model.n_images, model.encoder.c_dim), + device=model.device, ) # pC = torch.softmax(pC, dim=-1) - zs = torch.randn((8, model.n_images, model.encoder.latent_dim), device=model.device) + zs = torch.randn( + (8, model.n_images, model.encoder.latent_dim), + device=model.device, + ) latents = [] for _ in range(model.n_images): for i in range(len(model.c_split)): latents.append(zs[:, i, :]) - latents.append(F.gumbel_softmax(pC[:, i, :], tau=1, hard=True, dim=-1)) + latents.append( + F.gumbel_softmax( + pC[:, i, :], tau=1, hard=True, dim=-1 + ) + ) # generated images decode = model.decoder(torch.cat(latents, dim=1)).detach() diff --git a/XOR_MNIST/utils/losses.py b/XOR_MNIST/utils/losses.py index 7d152c9..5534674 100644 --- a/XOR_MNIST/utils/losses.py +++ b/XOR_MNIST/utils/losses.py @@ -1,6 +1,6 @@ # Losses module -import torch import numpy as np +import torch import torch.nn.functional as F @@ -84,8 +84,13 @@ def ADDMNIST_REC_Match(out_dict: dict, args): assert inputs.size() == recs.size(), f"{len(inputs)}-{len(recs)}" - recon = F.binary_cross_entropy(recs.view(L, -1), inputs.view(L, -1)) - kld = (-0.5 * (1 + logvars - mus**2 - logvars.exp()).sum(1).mean() - 1).abs() + recon = F.binary_cross_entropy( + recs.view(L, -1), inputs.view(L, -1) + ) + kld = ( + -0.5 * (1 + logvars - mus**2 - logvars.exp()).sum(1).mean() + - 1 + ).abs() losses = {"recon-loss": recon.item(), "kld": kld.item()} @@ -116,7 +121,11 @@ def ADDMNIST_Entropy(out_dict, args): loss = 0 for i in range(p_mean.size(0)): - loss -= torch.sum(p_mean[i] * p_mean[i].log()) / np.log(10) / p_mean.size(0) + loss -= ( + torch.sum(p_mean[i] * p_mean[i].log()) + / np.log(10) + / p_mean.size(0) + ) losses = {"H-loss": 1 - loss} @@ -206,8 +215,12 @@ def KAND_Classification(out_dict: dict, args): ) final_weight = torch.tensor([0.5, 0.5], device=preds.device) elif args.task == "red_triangle": - weight = torch.tensor([0.35538, 1 - 0.35538], device=preds.device) - final_weight = torch.tensor([0.04685, 1 - 0.04685], device=preds.device) + weight = torch.tensor( + [0.35538, 1 - 0.35538], device=preds.device + ) + final_weight = torch.tensor( + [0.04685, 1 - 0.04685], device=preds.device + ) else: weight = torch.tensor([0.5, 0.5], device=out.device) final_weight = torch.tensor([0.5, 0.5], device=out.device) @@ -215,7 +228,10 @@ def KAND_Classification(out_dict: dict, args): if args.model in ["kanddpl"]: # loss = torch.tensor(1e-5) loss = F.nll_loss( - out.log(), final_labels, reduction="mean", weight=final_weight + out.log(), + final_labels, + reduction="mean", + weight=final_weight, ) # for i in range( inter_labels.shape[-1]): @@ -303,7 +319,11 @@ def KAND_Entropy(out_dict, args): loss = 0 for i in range(p_mean.size(0)): - loss -= torch.sum(p_mean[i] * p_mean[i].log()) / np.log(10) / p_mean.size(0) + loss -= ( + torch.sum(p_mean[i] * p_mean[i].log()) + / np.log(10) + / p_mean.size(0) + ) losses = {"H-loss": 1 - loss} diff --git a/XOR_MNIST/utils/ltn_loss.py b/XOR_MNIST/utils/ltn_loss.py index f03b5b0..c4e80c0 100644 --- 
a/XOR_MNIST/utils/ltn_loss.py +++ b/XOR_MNIST/utils/ltn_loss.py @@ -1,7 +1,8 @@ # Module which identifies an LTN loss +import itertools + import ltn import torch -import itertools from utils.normal_kl_divergence import kl_divergence @@ -24,7 +25,9 @@ def forward(self, x, d): class ADDMNIST_SAT_AGG(torch.nn.Module): - def __init__(self, loss, task="addition", nr_classes=19, pcbm=False) -> None: + def __init__( + self, loss, task="addition", nr_classes=19, pcbm=False + ) -> None: super().__init__() self.base_loss = loss self.task = task @@ -145,8 +148,12 @@ def ADDMNISTsat_agg_loss(eltn, p1, p2, labels, grade): b_2 = ltn.Variable("b_2", torch.tensor(range(max_c))) And = ltn.Connective(ltn.fuzzy_ops.AndProd()) - Exists = ltn.Quantifier(ltn.fuzzy_ops.AggregPMean(p=2), quantifier="e") - Forall = ltn.Quantifier(ltn.fuzzy_ops.AggregPMeanError(p=2), quantifier="f") + Exists = ltn.Quantifier( + ltn.fuzzy_ops.AggregPMean(p=2), quantifier="e" + ) + Forall = ltn.Quantifier( + ltn.fuzzy_ops.AggregPMeanError(p=2), quantifier="f" + ) SatAgg = ltn.fuzzy_ops.SatAgg() sat_agg = Forall( @@ -155,7 +162,9 @@ def ADDMNISTsat_agg_loss(eltn, p1, p2, labels, grade): [b_1, b_2], And(eltn(bit1, b_1), eltn(bit2, b_2)), cond_vars=[b_1, b_2, y_true], - cond_fn=lambda d1, d2, z: torch.eq((d1.value + d2.value), z.value), + cond_fn=lambda d1, d2, z: torch.eq( + (d1.value + d2.value), z.value + ), p=grade, ), ) @@ -192,8 +201,12 @@ def PRODMNISTsat_agg_loss(eltn, p1, p2, labels, grade): b_2 = ltn.Variable("b_2", torch.tensor(range(max_c))) And = ltn.Connective(ltn.fuzzy_ops.AndProd()) - Exists = ltn.Quantifier(ltn.fuzzy_ops.AggregPMean(p=2), quantifier="e") - Forall = ltn.Quantifier(ltn.fuzzy_ops.AggregPMeanError(p=2), quantifier="f") + Exists = ltn.Quantifier( + ltn.fuzzy_ops.AggregPMean(p=2), quantifier="e" + ) + Forall = ltn.Quantifier( + ltn.fuzzy_ops.AggregPMeanError(p=2), quantifier="f" + ) SatAgg = ltn.fuzzy_ops.SatAgg() sat_agg = Forall( @@ -202,7 +215,9 @@ def PRODMNISTsat_agg_loss(eltn, p1, p2, labels, grade): [b_1, b_2], And(eltn(bit1, b_1), eltn(bit2, b_2)), cond_vars=[b_1, b_2, y_true], - cond_fn=lambda b_1, b_2, z: torch.eq(b_1.value * b_2.value, z.value), + cond_fn=lambda b_1, b_2, z: torch.eq( + b_1.value * b_2.value, z.value + ), p=grade, ), ).value @@ -234,8 +249,12 @@ def MULTIOPsat_agg_loss(eltn, p1, p2, labels, grade): b_2 = ltn.Variable("b_2", torch.tensor(range(max_c))) And = ltn.Connective(ltn.fuzzy_ops.AndProd()) - Exists = ltn.Quantifier(ltn.fuzzy_ops.AggregPMean(p=2), quantifier="e") - Forall = ltn.Quantifier(ltn.fuzzy_ops.AggregPMeanError(p=2), quantifier="f") + Exists = ltn.Quantifier( + ltn.fuzzy_ops.AggregPMean(p=2), quantifier="e" + ) + Forall = ltn.Quantifier( + ltn.fuzzy_ops.AggregPMeanError(p=2), quantifier="f" + ) SatAgg = ltn.fuzzy_ops.SatAgg() sat_agg = Forall( @@ -245,7 +264,8 @@ def MULTIOPsat_agg_loss(eltn, p1, p2, labels, grade): And(eltn(bit1, b_1), eltn(bit2, b_2)), cond_vars=[b_1, b_2, y_true], cond_fn=lambda b_1, b_2, z: torch.eq( - b_1.value**2 + b_2.value**2 + b_1.value * b_2.value, z.value + b_1.value**2 + b_2.value**2 + b_1.value * b_2.value, + z.value, ), p=grade, ), diff --git a/XOR_MNIST/utils/metrics.py b/XOR_MNIST/utils/metrics.py index 3e635d2..77849fd 100644 --- a/XOR_MNIST/utils/metrics.py +++ b/XOR_MNIST/utils/metrics.py @@ -1,15 +1,20 @@ # This module contains the computation of the metrics -import torch -import numpy as np import math -from numpy import ndarray -from typing import Tuple, Dict, List -import torch.nn as nn +from typing import Dict, List, Tuple + 
+import numpy as np import torch +import torch.nn as nn import torch.nn.functional as F -from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score +from numpy import ndarray from scipy.special import softmax +from sklearn.metrics import ( + accuracy_score, + f1_score, + precision_score, + recall_score, +) def evaluate_mix(true, pred): @@ -73,7 +78,9 @@ def evaluate_metrics( ) out_dict = model(images) - out_dict.update({"INPUTS": images, "LABELS": labels, "CONCEPTS": concepts}) + out_dict.update( + {"INPUTS": images, "LABELS": labels, "CONCEPTS": concepts} + ) logits = out_dict["YS"] if last and i == 0: @@ -83,24 +90,43 @@ def evaluate_metrics( c_pred = out_dict["CS"].detach().cpu().numpy() pc_pred = out_dict["pCS"].detach().cpu().numpy() elif last and i > 0: - y_true = np.concatenate([y_true, labels.detach().cpu().numpy()], axis=0) - c_true = np.concatenate([c_true, concepts.detach().cpu().numpy()], axis=0) + y_true = np.concatenate( + [y_true, labels.detach().cpu().numpy()], axis=0 + ) + c_true = np.concatenate( + [c_true, concepts.detach().cpu().numpy()], axis=0 + ) y_pred = np.concatenate( - [y_pred, out_dict["YS"].detach().cpu().numpy()], axis=0 + [y_pred, out_dict["YS"].detach().cpu().numpy()], + axis=0, ) c_pred = np.concatenate( - [c_pred, out_dict["CS"].detach().cpu().numpy()], axis=0 + [c_pred, out_dict["CS"].detach().cpu().numpy()], + axis=0, ) pc_pred = np.concatenate( - [pc_pred, out_dict["pCS"].detach().cpu().numpy()], axis=0 + [pc_pred, out_dict["pCS"].detach().cpu().numpy()], + axis=0, ) if ( - args.dataset in ["addmnist", "shortmnist", "restrictedmnist", "halfmnist"] + args.dataset + in [ + "addmnist", + "shortmnist", + "restrictedmnist", + "halfmnist", + ] and not last ): - loss, ac, acc = ADDMNIST_eval_tloss_cacc_acc(out_dict, concepts) - elif args.dataset in ["kandinsky", "prekandinsky", "minikandinsky"]: + loss, ac, acc = ADDMNIST_eval_tloss_cacc_acc( + out_dict, concepts + ) + elif args.dataset in [ + "kandinsky", + "prekandinsky", + "minikandinsky", + ]: loss, ac, acc, f1 = KAND_eval_tloss_cacc_acc(out_dict) else: NotImplementedError() @@ -127,7 +153,11 @@ def evaluate_metrics( gs = np.concatenate(gs, axis=0).squeeze(1) - if args.dataset not in ["kandinsky", "prekandinsky", "minikandinsky"]: + if args.dataset not in [ + "kandinsky", + "prekandinsky", + "minikandinsky", + ]: cs = np.concatenate(cs, axis=0).squeeze(1).argmax(axis=-1) p_cs = ( np.concatenate(p_cs, axis=0).squeeze(1).max(axis=-1) @@ -135,17 +165,24 @@ def evaluate_metrics( else: cs = np.concatenate(cs, axis=0).squeeze(1) - p_cs = np.concatenate(p_cs, axis=0).squeeze(1).max(axis=-1) + p_cs = ( + np.concatenate(p_cs, axis=0).squeeze(1).max(axis=-1) + ) cs = np.split(cs, 6, axis=-1) p_cs = np.split(p_cs, 6, axis=-1) cs = np.concatenate( - [np.argmax(c, axis=-1).reshape(-1, 1) for c in cs], axis=-1 + [np.argmax(c, axis=-1).reshape(-1, 1) for c in cs], + axis=-1, ) p_cs = np.concatenate( - [np.argmax(pc, axis=-1).reshape(-1, 1) for pc in p_cs], axis=-1 + [ + np.argmax(pc, axis=-1).reshape(-1, 1) + for pc in p_cs + ], + axis=-1, ) p_cs_all = np.concatenate(p_cs_all, axis=0).squeeze( @@ -161,7 +198,11 @@ def evaluate_metrics( cs = c_pred.argmax(axis=2) return y_true, gs, ys, cs, p_cs, p_ys, p_cs_all, p_ys_all else: - if args.dataset in ["kandinsky", "prekandinsky", "minikandinsky"]: + if args.dataset in [ + "kandinsky", + "prekandinsky", + "minikandinsky", + ]: return tloss / L, cacc / L, yacc / L, f1sc / L else: return tloss / L, cacc / L, yacc / L, 0 @@ -196,7 +237,13 @@ def 
evaluate_metrics_ensemble(ensemble, loader, args, last=True): ) out_dict = model(images) - out_dict.update({"INPUTS": images, "LABELS": labels, "CONCEPTS": concepts}) + out_dict.update( + { + "INPUTS": images, + "LABELS": labels, + "CONCEPTS": concepts, + } + ) if last and i == 0: y_true = labels.detach().cpu().numpy() @@ -205,18 +252,23 @@ def evaluate_metrics_ensemble(ensemble, loader, args, last=True): c_pred = out_dict["CS"].detach().cpu().numpy() pc_pred = out_dict["pCS"].detach().cpu().numpy() elif last and i > 0: - y_true = np.concatenate([y_true, labels.detach().cpu().numpy()], axis=0) + y_true = np.concatenate( + [y_true, labels.detach().cpu().numpy()], axis=0 + ) c_true = np.concatenate( [c_true, concepts.detach().cpu().numpy()], axis=0 ) y_pred = np.concatenate( - [y_pred, out_dict["YS"].detach().cpu().numpy()], axis=0 + [y_pred, out_dict["YS"].detach().cpu().numpy()], + axis=0, ) c_pred = np.concatenate( - [c_pred, out_dict["CS"].detach().cpu().numpy()], axis=0 + [c_pred, out_dict["CS"].detach().cpu().numpy()], + axis=0, ) pc_pred = np.concatenate( - [pc_pred, out_dict["pCS"].detach().cpu().numpy()], axis=0 + [pc_pred, out_dict["pCS"].detach().cpu().numpy()], + axis=0, ) y_true_all.append(y_true) @@ -296,10 +348,14 @@ def evaluate_metrics_ensemble(ensemble, loader, args, last=True): cacc /= n_figures * n_concepts y = torch.tensor(y_pred_all) # out_dict['YS'] - y_true = torch.tensor(y_true_all)[:, -1] # out_dict['LABELS'][:,-1] + y_true = torch.tensor(y_true_all)[ + :, -1 + ] # out_dict['LABELS'][:,-1] y_pred = torch.argmax(y, dim=-1) - assert y_pred.size() == y_true.size(), f"size c_pred: {c_pred.size()}" + assert ( + y_pred.size() == y_true.size() + ), f"size c_pred: {c_pred.size()}" acc = (y_pred == y_true).sum().item() / len(y_true) @@ -451,7 +507,9 @@ def get_world_probabilities_matrix( c_prb_2_expanded = np.expand_dims(c_prb_2, axis=1) # 256, 1, 10 # Compute the outer product to get c_prbs - decomposed_world_prob = np.matmul(c_prb_1_expanded, c_prb_2_expanded) # 256, 10, 10 + decomposed_world_prob = np.matmul( + c_prb_1_expanded, c_prb_2_expanded + ) # 256, 10, 10 # Reshape c_prbs to get worlds_prob worlds_prob = decomposed_world_prob.reshape( @@ -490,14 +548,18 @@ def get_mean_world_probability( world_counter[world_label] = 0 # get that world concept probability - mean_world_prob[world_label] += decomposed_world_prob[i, c_1, c_2] + mean_world_prob[world_label] += decomposed_world_prob[ + i, c_1, c_2 + ] # count that world world_counter[world_label] += 1 # Normalize for el in mean_world_prob: mean_world_prob[el] = ( - 0 if world_counter[el] == 0 else mean_world_prob[el] / world_counter[el] + 0 + if world_counter[el] == 0 + else mean_world_prob[el] / world_counter[el] ) return mean_world_prob, world_counter @@ -574,7 +636,9 @@ def get_alpha_single( # loop over the grountruth concept of each sample (concept) for i, c in enumerate(c_true): - world_label = str(c) # get the world as a string for the key of the dictionary + world_label = str( + c + ) # get the world as a string for the key of the dictionary # fill if it is zero if world_label not in world_counter: @@ -631,9 +695,12 @@ def get_concept_probability(model, loader): c_true = concepts.detach().cpu().numpy() else: c_prb = np.concatenate( - [c_prb, out_dict["pCS"].detach().cpu().numpy()], axis=0 + [c_prb, out_dict["pCS"].detach().cpu().numpy()], + axis=0, + ) + c_true = np.concatenate( + [c_true, concepts.detach().cpu().numpy()], axis=0 ) - c_true = np.concatenate([c_true, concepts.detach().cpu().numpy()], axis=0) 
c_prb_1 = c_prb[:, 0, :] # [#dati, #facts] c_prb_2 = c_prb[:, 1, :] # [#dati, #facts] @@ -686,7 +753,8 @@ def get_concept_probability_ensemble(models, loader): c_prb = out_dict["pCS"].detach().cpu().numpy() else: c_prb = np.concatenate( - [c_prb, out_dict["pCS"].detach().cpu().numpy()], axis=0 + [c_prb, out_dict["pCS"].detach().cpu().numpy()], + axis=0, ) c_prb_1 = c_prb[:, 0, :] @@ -696,7 +764,9 @@ def get_concept_probability_ensemble(models, loader): ensemble_c_prb_2.append(c_prb_2) # Average for each model in the ensemble - return calculate_mean_pCs(ensemble_c_prb_1, ensemble_c_prb_2, len(ensemble_c_prb_1)) + return calculate_mean_pCs( + ensemble_c_prb_1, ensemble_c_prb_2, len(ensemble_c_prb_1) + ) def get_concept_probability_factorized_ensemble(models, loader): @@ -734,7 +804,8 @@ def get_concept_probability_factorized_ensemble(models, loader): c_true = concepts.detach().cpu().numpy() else: c_prb = np.concatenate( - [c_prb, out_dict["pCS"].detach().cpu().numpy()], axis=0 + [c_prb, out_dict["pCS"].detach().cpu().numpy()], + axis=0, ) c_true = np.concatenate( [c_true, concepts.detach().cpu().numpy()], axis=0 @@ -798,7 +869,8 @@ def get_concept_probability_factorized_mcdropout( c_true = concepts.detach().cpu().numpy() else: c_prb = np.concatenate( - [c_prb, out_dict["pCS"].detach().cpu().numpy()], axis=0 + [c_prb, out_dict["pCS"].detach().cpu().numpy()], + axis=0, ) c_true = np.concatenate( [c_true, concepts.detach().cpu().numpy()], axis=0 @@ -819,7 +891,11 @@ def get_concept_probability_factorized_mcdropout( avg_c_prb_1 = np.mean(ensemble_c_prb_1, axis=0) avg_c_prb_2 = np.mean(ensemble_c_prb_2, axis=0) - return avg_c_prb_1, avg_c_prb_2, gt_factorized # (6000,10), (6000,10) + return ( + avg_c_prb_1, + avg_c_prb_2, + gt_factorized, + ) # (6000,10), (6000,10) def get_concept_probability_mcdropout( @@ -859,7 +935,8 @@ def get_concept_probability_mcdropout( c_prb = out_dict["pCS"].detach().cpu().numpy() else: c_prb = np.concatenate( - [c_prb, out_dict["pCS"].detach().cpu().numpy()], axis=0 + [c_prb, out_dict["pCS"].detach().cpu().numpy()], + axis=0, ) c_prb_1 = c_prb[:, 0, :] @@ -869,11 +946,15 @@ def get_concept_probability_mcdropout( ensemble_c_prb_2.append(c_prb_2) # Average for each model in the ensemble - return calculate_mean_pCs(ensemble_c_prb_1, ensemble_c_prb_2, len(ensemble_c_prb_1)) + return calculate_mean_pCs( + ensemble_c_prb_1, ensemble_c_prb_2, len(ensemble_c_prb_1) + ) def calculate_mean_pCs( - ensemble_c_prb_1: ndarray, ensemble_c_prb_2: ndarray, ensamble_len: int + ensemble_c_prb_1: ndarray, + ensemble_c_prb_2: ndarray, + ensamble_len: int, ) -> ndarray: """Get concept probabilities for mcdropout @@ -908,7 +989,13 @@ def _bin_initializer(num_bins: int) -> Dict[int, Dict[str, int]]: """ # Builds the bin return { - i: {"COUNT": 0, "CONF": 0, "ACC": 0, "BIN_ACC": 0, "BIN_CONF": 0} + i: { + "COUNT": 0, + "CONF": 0, + "ACC": 0, + "BIN_ACC": 0, + "BIN_CONF": 0, + } for i in range(num_bins) } @@ -951,7 +1038,10 @@ def _populate_bins( def expected_calibration_error( - confs: ndarray, preds: ndarray, labels: ndarray, num_bins: int = 10 + confs: ndarray, + preds: ndarray, + labels: ndarray, + num_bins: int = 10, ) -> Tuple[float, Dict[str, float]]: """Computes the ECE @@ -967,7 +1057,9 @@ def expected_calibration_error( # Perfect calibration is achieved when the ECE is zero # Formula: ECE = sum 1 upto M of number of elements in bin m|Bm| over number of samples across all bins (n), times |(Accuracy of Bin m Bm) - Confidence of Bin m Bm)| - bin_dict = _populate_bins(confs, preds, 
labels, num_bins) # populate the bins + bin_dict = _populate_bins( + confs, preds, labels, num_bins + ) # populate the bins num_samples = len(labels) # number of samples (n) ece = sum( (bin_info["BIN_ACC"] - bin_info["BIN_CONF"]).__abs__() @@ -1020,9 +1112,9 @@ def entropy(probabilities: ndarray, n_values: int): probabilities += 1e-5 probabilities /= 1 + (n_values * 1e-5) - entropy_values = -np.sum(probabilities * np.log(probabilities), axis=1) / np.log( - n_values - ) + entropy_values = -np.sum( + probabilities * np.log(probabilities), axis=1 + ) / np.log(n_values) return entropy_values @@ -1054,9 +1146,9 @@ def variance(probabilities: np.ndarray, n_values: int): # Compute variance along the columns mean_values = np.mean(probabilities, axis=1, keepdims=True) # Var(X) = E[(X - mu)^2] c= 1/n-1 (X - mu)**2 (unbiased estimator) - variance_values = np.sum((probabilities - mean_values) ** 2, axis=1) / ( - n_values - 1 - ) + variance_values = np.sum( + (probabilities - mean_values) ** 2, axis=1 + ) / (n_values - 1) return variance_values @@ -1097,7 +1189,9 @@ def class_mean_entropy( class_counts = np.zeros(n_classes) for i in range(num_samples): - sample_entropy = entropy(np.expand_dims(probabilities[i], axis=0), num_classes) + sample_entropy = entropy( + np.expand_dims(probabilities[i], axis=0), num_classes + ) class_mean_entropy_values[true_classes[i]] += sample_entropy class_counts[true_classes[i]] += 1 @@ -1146,7 +1240,12 @@ def class_mean_variance( def get_concept_probability_factorized_laplace( - device, loader, laplace_single_prediction, la, output_classes, num_concepts + device, + loader, + laplace_single_prediction, + la, + output_classes, + num_concepts, ): """Function which computes the factorized concept probability for Laplace @@ -1171,16 +1270,21 @@ def get_concept_probability_factorized_laplace( concepts.to(device), ) - out_dict = laplace_single_prediction(la, images, output_classes, num_concepts) + out_dict = laplace_single_prediction( + la, images, output_classes, num_concepts + ) if i == 0: c_prb = out_dict["pCS"].detach().cpu().numpy() c_true = concepts.detach().cpu().numpy() else: c_prb = np.concatenate( - [c_prb, out_dict["pCS"].detach().cpu().numpy()], axis=0 + [c_prb, out_dict["pCS"].detach().cpu().numpy()], + axis=0, + ) + c_true = np.concatenate( + [c_true, concepts.detach().cpu().numpy()], axis=0 ) - c_true = np.concatenate([c_true, concepts.detach().cpu().numpy()], axis=0) c_prb_1 = c_prb[:, 0, :] c_prb_2 = c_prb[:, 1, :] @@ -1193,7 +1297,9 @@ def get_concept_probability_factorized_laplace( return c_prb_1, c_prb_2, gt_factorized -def get_concept_probability_laplace(device, loader, laplace_model, n_ensembles): +def get_concept_probability_laplace( + device, loader, laplace_model, n_ensembles +): """Function which gets the concept probability for Laplace Args: @@ -1208,7 +1314,9 @@ def get_concept_probability_laplace(device, loader, laplace_model, n_ensembles): ensemble_c_prb_1 = [] ensemble_c_prb_2 = [] - ensemble = laplace_model.model.model.get_ensembles(laplace_model, n_ensembles) + ensemble = laplace_model.model.model.get_ensembles( + laplace_model, n_ensembles + ) for model in ensemble: model.eval() @@ -1229,7 +1337,8 @@ def get_concept_probability_laplace(device, loader, laplace_model, n_ensembles): c_prb = out_dict["pCS"].detach().cpu().numpy() else: c_prb = np.concatenate( - [c_prb, out_dict["pCS"].detach().cpu().numpy()], axis=0 + [c_prb, out_dict["pCS"].detach().cpu().numpy()], + axis=0, ) c_prb_1 = c_prb[:, 0, :] @@ -1239,7 +1348,9 @@ def 
get_concept_probability_laplace(device, loader, laplace_model, n_ensembles): ensemble_c_prb_2.append(c_prb_2) # Average for each model in the ensemble - return calculate_mean_pCs(ensemble_c_prb_1, ensemble_c_prb_2, len(ensemble_c_prb_1)) + return calculate_mean_pCs( + ensemble_c_prb_1, ensemble_c_prb_2, len(ensemble_c_prb_1) + ) def ensemble_p_c_x_distance(ensemble: List[nn.Module], loader): @@ -1270,7 +1381,8 @@ def ensemble_p_c_x_distance(ensemble: List[nn.Module], loader): c_prb = out_dict["pCS"].detach().cpu().numpy() else: c_prb = np.concatenate( - [c_prb, out_dict["pCS"].detach().cpu().numpy()], axis=0 + [c_prb, out_dict["pCS"].detach().cpu().numpy()], + axis=0, ) c_prb_1 = c_prb[:, 0, :] @@ -1317,7 +1429,8 @@ def mcdropout_p_c_x_distance( c_prb = out_dict["pCS"].detach().cpu().numpy() else: c_prb = np.concatenate( - [c_prb, out_dict["pCS"].detach().cpu().numpy()], axis=0 + [c_prb, out_dict["pCS"].detach().cpu().numpy()], + axis=0, ) c_prb_1 = c_prb[:, 0, :] @@ -1371,14 +1484,18 @@ def laplace_p_c_x_distance( ) out_tensor = la.model.model(images) out_dict = recover_predictions_from_laplace( - out_tensor, out_tensor.shape[0], output_classes, num_concepts + out_tensor, + out_tensor.shape[0], + output_classes, + num_concepts, ) if i == 0: c_prb = out_dict["pCS"].detach().cpu().numpy() else: c_prb = np.concatenate( - [c_prb, out_dict["pCS"].detach().cpu().numpy()], axis=0 + [c_prb, out_dict["pCS"].detach().cpu().numpy()], + axis=0, ) c_prb_1 = c_prb[:, 0, :] @@ -1430,7 +1547,9 @@ def vector_to_parameters(vec: torch.Tensor, parameters) -> None: # Ensure vec of type Tensor if not isinstance(vec, torch.Tensor): raise TypeError( - "expected torch.Tensor, but got: {}".format(torch.typename(vec)) + "expected torch.Tensor, but got: {}".format( + torch.typename(vec) + ) ) # Pointer for slicing the vector for each parameter @@ -1439,14 +1558,19 @@ def vector_to_parameters(vec: torch.Tensor, parameters) -> None: # The length of the parameter num_param = param.numel() # Slice the vector, reshape it, and replace the old data of the parameter - param.data = vec[pointer : pointer + num_param].view_as(param).data + param.data = ( + vec[pointer : pointer + num_param].view_as(param).data + ) # Increment the pointer pointer += num_param def get_accuracy_and_counter( - n_concepts: int, c_pred: ndarray, c_true: ndarray, map_index: bool = False + n_concepts: int, + c_pred: ndarray, + c_true: ndarray, + map_index: bool = False, ): """Function which counts the occurrece and accuracy @@ -1486,7 +1610,9 @@ def get_accuracy_and_counter( return concept_counter_list, concept_acc_list -def concept_accuracy(c1_prob: ndarray, c2_prob: ndarray, c_true: ndarray): +def concept_accuracy( + c1_prob: ndarray, c2_prob: ndarray, c_true: ndarray +): """Function which computes the concept accuracy Args: @@ -1508,7 +1634,9 @@ def concept_accuracy(c1_prob: ndarray, c2_prob: ndarray, c_true: ndarray): return get_accuracy_and_counter(n_concepts, c_pred, c_true) -def world_accuracy(world_prob: ndarray, world_true: ndarray, n_concepts: int): +def world_accuracy( + world_prob: ndarray, world_true: ndarray, n_concepts: int +): """Function which computes the world accuracy Args: @@ -1527,7 +1655,11 @@ def world_accuracy(world_prob: ndarray, world_true: ndarray, n_concepts: int): unit_values = np.array([x % n_concepts for x in world_pred]) world_pred = np.array( - np.char.add(decimal_values.astype(str), unit_values.astype(str)) + np.char.add( + decimal_values.astype(str), unit_values.astype(str) + ) ).astype(int) - return 
get_accuracy_and_counter(n_world, world_pred, world_true, True) + return get_accuracy_and_counter( + n_world, world_pred, world_true, True + ) diff --git a/XOR_MNIST/utils/preprocess_resnet.py b/XOR_MNIST/utils/preprocess_resnet.py index 934aa65..4bbc44f 100644 --- a/XOR_MNIST/utils/preprocess_resnet.py +++ b/XOR_MNIST/utils/preprocess_resnet.py @@ -1,16 +1,14 @@ # This module contains the preoprocessing operation for Kandisnky using a ResNet -import numpy as np - import os -from utils.wandb_logger import * -from utils.status import progress_bar +import numpy as np from datasets.utils.base_dataset import BaseDataset from models.mnistdpl import MnistDPL -from utils.dpl_loss import ADDMNIST_DPL - from utils import fprint +from utils.dpl_loss import ADDMNIST_DPL +from utils.status import progress_bar +from utils.wandb_logger import * def preprocess(model: MnistDPL, dataset: BaseDataset, args): @@ -54,9 +52,17 @@ def preprocess(model: MnistDPL, dataset: BaseDataset, args): y = labels.detach().cpu().numpy() g = concepts.detach().cpu().numpy() - np.save(f"data/kand-preprocess/train/images/{str(id).zfill(5)}", emb) - np.save(f"data/kand-preprocess/train/labels/{str(id).zfill(5)}", y) - np.save(f"data/kand-preprocess/train/concepts/{str(id).zfill(5)}", g) + np.save( + f"data/kand-preprocess/train/images/{str(id).zfill(5)}", + emb, + ) + np.save( + f"data/kand-preprocess/train/labels/{str(id).zfill(5)}", y + ) + np.save( + f"data/kand-preprocess/train/concepts/{str(id).zfill(5)}", + g, + ) if i % 10 == 0: progress_bar(i, len(train_loader) - 9, 0, 0) @@ -84,9 +90,15 @@ def preprocess(model: MnistDPL, dataset: BaseDataset, args): y = labels.detach().cpu().numpy() g = concepts.detach().cpu().numpy() - np.save(f"data/kand-preprocess/val/images/{str(id).zfill(5)}", emb) - np.save(f"data/kand-preprocess/val/labels/{str(id).zfill(5)}", y) - np.save(f"data/kand-preprocess/val/concepts/{str(id).zfill(5)}", g) + np.save( + f"data/kand-preprocess/val/images/{str(id).zfill(5)}", emb + ) + np.save( + f"data/kand-preprocess/val/labels/{str(id).zfill(5)}", y + ) + np.save( + f"data/kand-preprocess/val/concepts/{str(id).zfill(5)}", g + ) if i % 10 == 0: progress_bar(i, len(val_loader) - 9, 0, 0) @@ -114,9 +126,17 @@ def preprocess(model: MnistDPL, dataset: BaseDataset, args): y = labels.detach().cpu().numpy() g = concepts.detach().cpu().numpy() - np.save(f"data/kand-preprocess/test/images/{str(id).zfill(5)}", emb) - np.save(f"data/kand-preprocess/test/labels/{str(id).zfill(5)}", y) - np.save(f"data/kand-preprocess/test/concepts/{str(id).zfill(5)}", g) + np.save( + f"data/kand-preprocess/test/images/{str(id).zfill(5)}", + emb, + ) + np.save( + f"data/kand-preprocess/test/labels/{str(id).zfill(5)}", y + ) + np.save( + f"data/kand-preprocess/test/concepts/{str(id).zfill(5)}", + g, + ) if i % 10 == 0: progress_bar(i, len(test_loader) - 9, 0, 0) diff --git a/XOR_MNIST/utils/semantic_loss.py b/XOR_MNIST/utils/semantic_loss.py index b3e582c..de7a901 100644 --- a/XOR_MNIST/utils/semantic_loss.py +++ b/XOR_MNIST/utils/semantic_loss.py @@ -14,12 +14,18 @@ def __init__(self, loss, logic, args, pcbm=False) -> None: # Worlds-queries matrix if args.task == "addition": self.n_facts = ( - 10 if not args.dataset in ["halfmnist", "restrictedmnist"] else 5 + 10 + if not args.dataset + in ["halfmnist", "restrictedmnist"] + else 5 ) self.nr_classes = 19 elif args.task == "product": self.n_facts = ( - 10 if not args.dataset in ["halfmnist", "restrictedmnist"] else 5 + 10 + if not args.dataset + in ["halfmnist", "restrictedmnist"] + 
else 5 ) self.nr_classes = 37 elif args.task == "multiop": @@ -78,7 +84,9 @@ def forward(self, out_dict, args): for i in range(self.nr_classes): query = i - query_prob[:, i] = self.compute_query(query, worlds_prob).view(-1) + query_prob[:, i] = self.compute_query( + query, worlds_prob + ).view(-1) # add a small offset query_prob += 1e-5 @@ -86,7 +94,9 @@ def forward(self, out_dict, args): Z = torch.sum(query_prob, dim=-1, keepdim=True) query_prob = query_prob / Z - sl = F.nll_loss(query_prob.log(), Y.to(torch.long), reduction="mean") + sl = F.nll_loss( + query_prob.log(), Y.to(torch.long), reduction="mean" + ) losses.update({"sl": sl.item()}) diff --git a/XOR_MNIST/utils/status.py b/XOR_MNIST/utils/status.py index 5443413..ff89634 100644 --- a/XOR_MNIST/utils/status.py +++ b/XOR_MNIST/utils/status.py @@ -3,16 +3,19 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. -from datetime import datetime -import sys import os -from utils.conf import base_path +import sys +from argparse import Namespace +from datetime import datetime from typing import Any, Dict, Union + from torch import nn -from argparse import Namespace +from utils.conf import base_path -def create_stash(model: nn.Module, args: Namespace, dataset) -> Dict[Any, str]: +def create_stash( + model: nn.Module, args: Namespace, dataset +) -> Dict[Any, str]: """Creates the dictionary where to save the model status. Args: @@ -33,12 +36,17 @@ def create_stash(model: nn.Module, args: Namespace, dataset) -> Dict[Any, str]: model_stash["mean_accs"] = [] model_stash["args"] = args model_stash["backup_folder"] = os.path.join( - base_path(), "backups", dataset.SETTING, model_stash["model_name"] + base_path(), + "backups", + dataset.SETTING, + model_stash["model_name"], ) return model_stash -def create_fake_stash(model: nn.Module, args: Namespace) -> Dict[Any, str]: +def create_fake_stash( + model: nn.Module, args: Namespace +) -> Dict[Any, str]: """Create a fake stash, containing just the model name. This is used in general continual, as it is useless to backup a lightweight MNIST-360 training. Args: @@ -59,7 +67,9 @@ def create_fake_stash(model: nn.Module, args: Namespace) -> Dict[Any, str]: return model_stash -def progress_bar(i: int, max_iter: int, epoch: Union[int, str], loss: float) -> None: +def progress_bar( + i: int, max_iter: int, epoch: Union[int, str], loss: float +) -> None: """Prints out the progress bar on the stderr file. 
Args: @@ -74,7 +84,9 @@ def progress_bar(i: int, max_iter: int, epoch: Union[int, str], loss: float) -> """ # if not (i + 1) % 10 or (i + 1) == max_iter: progress = min(float((i + 1) / max_iter), 1) - progress_bar = ("█" * int(50 * progress)) + ("┈" * (50 - int(50 * progress))) + progress_bar = ("█" * int(50 * progress)) + ( + "┈" * (50 - int(50 * progress)) + ) print( "\r[ {} ] epoch {}: |{}| loss: {}".format( datetime.now().strftime("%m-%d | %H:%M"), diff --git a/XOR_MNIST/utils/test.py b/XOR_MNIST/utils/test.py index 923d4dd..0f3c679 100644 --- a/XOR_MNIST/utils/test.py +++ b/XOR_MNIST/utils/test.py @@ -1,35 +1,35 @@ import os + import numpy as np -from utils.wandb_logger import * +import wandb from datasets.utils.base_dataset import BaseDataset from models.mnistdpl import MnistDPL -import wandb +from utils import fprint +from utils.bayes import ( + activate_dropout, + deep_ensemble, + ensemble_predict, + laplace_approximation, + laplace_prediction, + laplace_single_prediction, + montecarlo_dropout, +) +from utils.checkpoint import get_model_name, load_checkpoint from utils.metrics import ( + concept_accuracy, evaluate_metrics, evaluate_mix, get_concept_probability, - get_concept_probability_mcdropout, get_concept_probability_ensemble, - get_concept_probability_laplace, - get_concept_probability_factorized_mcdropout, get_concept_probability_factorized_ensemble, get_concept_probability_factorized_laplace, - concept_accuracy, + get_concept_probability_factorized_mcdropout, + get_concept_probability_laplace, + get_concept_probability_mcdropout, world_accuracy, ) -from utils.bayes import ( - montecarlo_dropout, - deep_ensemble, - ensemble_predict, - activate_dropout, - laplace_approximation, - laplace_prediction, - laplace_single_prediction, -) -from utils import fprint -from utils.checkpoint import load_checkpoint, get_model_name - from utils.test_utils import * +from utils.wandb_logger import * class IllegalArgumentError(ValueError): @@ -73,7 +73,12 @@ def test(model: MnistDPL, dataset: BaseDataset, args, **kwargs): args.type = TOTAL_METHODS[0] TOTAL_METHODS.pop(0) - print("Doing total evaluation on...", args.type, "remaining: ", TOTAL_METHODS) + print( + "Doing total evaluation on...", + args.type, + "remaining: ", + TOTAL_METHODS, + ) # Wandb if args.wandb is not None: @@ -114,7 +119,10 @@ def test(model: MnistDPL, dataset: BaseDataset, args, **kwargs): if args.model in ["mnistsl", "mnistpcbmsl"]: apply_softmax = True - if args.type == EVALUATION_TYPE.LAPLACE.value and args.skip_laplace: + if ( + args.type == EVALUATION_TYPE.LAPLACE.value + and args.skip_laplace + ): if args.evaluate_all: test(model, dataset, args, **kwargs) else: @@ -123,8 +131,14 @@ def test(model: MnistDPL, dataset: BaseDataset, args, **kwargs): # Retrieve the metrics according to the type of evaluation specified if args.type == EVALUATION_TYPE.NORMAL.value: fprint("## Not Bayesian model ##") - y_true, c_true, y_pred, c_pred, p_cs, p_ys, p_cs_all, _ = evaluate_metrics( - model, test_loader, args, last=True, apply_softmax=apply_softmax + y_true, c_true, y_pred, c_pred, p_cs, p_ys, p_cs_all, _ = ( + evaluate_metrics( + model, + test_loader, + args, + last=True, + apply_softmax=apply_softmax, + ) ) _, c_true_cc, _, c_pred_cc, _, _, _, _ = evaluate_metrics( model, @@ -147,7 +161,9 @@ def test(model: MnistDPL, dataset: BaseDataset, args, **kwargs): p_ys, p_cs_all, _, - ) = montecarlo_dropout(model, test_loader, model.n_facts, 30, apply_softmax) + ) = montecarlo_dropout( + model, test_loader, model.n_facts, 30, 
apply_softmax + ) elif ( args.type == EVALUATION_TYPE.BEARS.value or args.type == EVALUATION_TYPE.ENSEMBLE.value @@ -162,7 +178,9 @@ def test(model: MnistDPL, dataset: BaseDataset, args, **kwargs): fprint("Preparing the ensembles...") ensemble = deep_ensemble( - seeds=[i + args.seed + 1 for i in range(args.n_ensembles)], + seeds=[ + i + args.seed + 1 for i in range(args.n_ensembles) + ], dataset=dataset, num_epochs=args.n_epochs, args=args, @@ -188,7 +206,9 @@ def test(model: MnistDPL, dataset: BaseDataset, args, **kwargs): p_ys, p_cs_all, _, - ) = ensemble_predict(ensemble, test_loader, model.n_facts, apply_softmax) + ) = ensemble_predict( + ensemble, test_loader, model.n_facts, apply_softmax + ) elif args.type == EVALUATION_TYPE.LAPLACE.value: fprint("### Laplace Approximation ###") fprint("Preparing laplace model, please wait...") @@ -236,7 +256,13 @@ def test(model: MnistDPL, dataset: BaseDataset, args, **kwargs): # metrix, h(c|y) and concept confusion matrix mean_h_c, yac, cac, cf1, yf1 = print_metrics( - y_true, y_pred, c_true, c_pred, p_cs_all, model.n_facts, args.type + y_true, + y_pred, + c_true, + c_pred, + p_cs_all, + model.n_facts, + args.type, ) # Log in Wandb @@ -270,17 +296,28 @@ def test(model: MnistDPL, dataset: BaseDataset, args, **kwargs): kwargs["cac_hard"].append(cac) # label and concept ece - ece_y = produce_ece_curve(p_ys, y_pred, y_true, args.type, "labels") + ece_y = produce_ece_curve( + p_ys, y_pred, y_true, args.type, "labels" + ) if args.type == EVALUATION_TYPE.NORMAL.value: - worlds_prob, c_factorized_1, c_factorized_2, worlds_groundtruth = ( - get_concept_probability(model, test_loader) + ( + worlds_prob, + c_factorized_1, + c_factorized_2, + worlds_groundtruth, + ) = get_concept_probability( + model, test_loader ) # 2 arrays of size 256, 10 (concept 1 and concept 2 for all items) c_pred_normal = worlds_prob.argmax(axis=1) p_c_normal = worlds_prob.max(axis=1) ece = produce_ece_curve( - p_c_normal, c_pred_normal, worlds_groundtruth, args.type, "concepts" + p_c_normal, + c_pred_normal, + worlds_groundtruth, + args.type, + "concepts", ) mean_h_c, yac, cac, cf1, yf1 = print_metrics( @@ -297,7 +334,9 @@ def test(model: MnistDPL, dataset: BaseDataset, args, **kwargs): kwargs["cac_hard"].append(cac) else: - ece = produce_ece_curve(p_cs, c_pred, c_true, args.type, "concepts") + ece = produce_ece_curve( + p_cs, c_pred, c_true, args.type, "concepts" + ) if "ece" not in kwargs: kwargs["ece"] = [] @@ -325,7 +364,9 @@ def test(model: MnistDPL, dataset: BaseDataset, args, **kwargs): if args.type == EVALUATION_TYPE.LAPLACE.value: # Get the ensembles for the inner model - ensemble = laplace_model.model.model.get_ensembles(laplace_model, 30) + ensemble = laplace_model.model.model.get_ensembles( + laplace_model, 30 + ) for i, model in enumerate(ensemble): fprint(f"-- Model {i} --") @@ -339,16 +380,22 @@ def test(model: MnistDPL, dataset: BaseDataset, args, **kwargs): p_cs_all_ens, _, ) = evaluate_metrics( - model, test_loader, args, last=True, apply_softmax=apply_softmax - ) - _, c_true_cc_ens, _, c_pred_cc_ens, _, _, _, _ = evaluate_metrics( model, test_loader, args, last=True, - concatenated_concepts=False, apply_softmax=apply_softmax, ) + _, c_true_cc_ens, _, c_pred_cc_ens, _, _, _, _ = ( + evaluate_metrics( + model, + test_loader, + args, + last=True, + concatenated_concepts=False, + apply_softmax=apply_softmax, + ) + ) mean_sh_c, syac, scac, scf1, syf1 = print_metrics( y_true_ens, @@ -421,7 +468,9 @@ def test(model: MnistDPL, dataset: BaseDataset, args, **kwargs): # Obtain 
the factorized probabilities c_factorized_1, c_factorized_2, gt_factorized = ( - get_concept_probability_factorized_ensemble(ensemble, test_loader) + get_concept_probability_factorized_ensemble( + ensemble, test_loader + ) ) elif args.type == EVALUATION_TYPE.LAPLACE.value: worlds_prob = get_concept_probability_laplace( @@ -441,8 +490,13 @@ def test(model: MnistDPL, dataset: BaseDataset, args, **kwargs): ) else: # NORMAL MODE - worlds_prob, c_factorized_1, c_factorized_2, worlds_groundtruth = ( - get_concept_probability(model, test_loader) + ( + worlds_prob, + c_factorized_1, + c_factorized_2, + worlds_groundtruth, + ) = get_concept_probability( + model, test_loader ) # 2 arrays of size 256, 10 (concept 1 and concept 2 for all items) # Change it for the concept factorized entropy and variance @@ -451,7 +505,9 @@ def test(model: MnistDPL, dataset: BaseDataset, args, **kwargs): gt_factorized = c_true # factorized probability concatenated - c_factorized_full = np.concatenate((c_factorized_1, c_factorized_2), axis=0) + c_factorized_full = np.concatenate( + (c_factorized_1, c_factorized_2), axis=0 + ) # maximum element probability for the ECE count c_factorized_max_p = np.max(c_factorized_full, axis=1) # factorized predictions with argmax @@ -490,7 +546,9 @@ def test(model: MnistDPL, dataset: BaseDataset, args, **kwargs): kwargs["cf1"].append(cf1) kwargs["cac"].append(cac) - if not any(key in kwargs for key in ["e_c1", "e_c2", "e_c", "e_(c1, c2)"]): + if not any( + key in kwargs for key in ["e_c1", "e_c2", "e_c", "e_(c1, c2)"] + ): kwargs["e_c1"] = list() kwargs["e_c2"] = list() kwargs["e_c"] = list() @@ -519,10 +577,13 @@ def test(model: MnistDPL, dataset: BaseDataset, args, **kwargs): p_cs_all = worlds_prob c_true = worlds_groundtruth - world_counter_list, world_acc_list = world_accuracy(p_cs_all, c_true, model.n_facts) + world_counter_list, world_acc_list = world_accuracy( + p_cs_all, c_true, model.n_facts + ) if not any( - key in kwargs for key in ["c_acc_count", "c_acc", "w_acc_count", "w_acc"] + key in kwargs + for key in ["c_acc_count", "c_acc", "w_acc_count", "w_acc"] ): kwargs["c_acc_count"] = list() kwargs["c_acc"] = list() @@ -536,7 +597,9 @@ def test(model: MnistDPL, dataset: BaseDataset, args, **kwargs): kwargs["w_acc_count"].append(world_counter_list) kwargs["w_acc"].append(world_acc_list) - e_per_c = compute_entropy_per_concept(c_factorized_full, gt_factorized) + e_per_c = compute_entropy_per_concept( + c_factorized_full, gt_factorized + ) kwargs["c_ova_filtered"].append(e_per_c["c_ova_filtered"]) kwargs["c_all_filtered"].append(e_per_c["c_all_filtered"]) diff --git a/XOR_MNIST/utils/test_utils.py b/XOR_MNIST/utils/test_utils.py index 24226b6..0a52308 100644 --- a/XOR_MNIST/utils/test_utils.py +++ b/XOR_MNIST/utils/test_utils.py @@ -1,37 +1,36 @@ # Debugging stuff for test.py import json -import numpy as np -import torch - import math import time - -from numpy import ndarray -from utils.wandb_logger import * from enum import Enum +from itertools import product from typing import List + +import numpy as np +import torch +from numpy import ndarray +from utils import fprint +from utils.checkpoint import get_model_name from utils.metrics import ( - evaluate_mix, - get_alpha, - get_alpha_single, - expected_calibration_error, - mean_entropy, class_mean_entropy, class_mean_variance, + ensemble_p_c_x_distance, + evaluate_mix, + expected_calibration_error, expected_calibration_error_by_concept, + get_alpha, + get_alpha_single, laplace_p_c_x_distance, - ensemble_p_c_x_distance, 
mcdropout_p_c_x_distance, + mean_entropy, ) from utils.visualization import ( - produce_confusion_matrix, produce_alpha_matrix, - produce_calibration_curve, produce_bar_plot, + produce_calibration_curve, + produce_confusion_matrix, ) -from utils import fprint -from utils.checkpoint import get_model_name -from itertools import product +from utils.wandb_logger import * class ECEMODE(Enum): @@ -70,10 +69,14 @@ def euclidean_distance(w1, w2): Returns: distance: euclidean distance """ - return torch.sqrt(sum(torch.sum((p1 - p2) ** 2) for p1, p2 in zip(w1, w2))) + return torch.sqrt( + sum(torch.sum((p1 - p2) ** 2) for p1, p2 in zip(w1, w2)) + ) -def fprint_weights_distance(original_weights, ensemble, method_1, method_2): +def fprint_weights_distance( + original_weights, ensemble, method_1, method_2 +): """Function which prints euclidean distance between model and elements within the ensemble Args: @@ -87,10 +90,17 @@ def fprint_weights_distance(original_weights, ensemble, method_1, method_2): """ distance = 0 for model in ensemble: - model_weights = [param.data.clone() for param in model.parameters()] - distance += euclidean_distance(original_weights, model_weights) + model_weights = [ + param.data.clone() for param in model.parameters() + ] + distance += euclidean_distance( + original_weights, model_weights + ) distance = distance / len(ensemble) - fprint(f"Euclidean Distance between {method_1} and {method_2}: ", distance.item()) + fprint( + f"Euclidean Distance between {method_1} and {method_2}: ", + distance.item(), + ) def fprint_ensemble_distance(ensemble): @@ -104,11 +114,21 @@ def fprint_ensemble_distance(ensemble): """ distance = 0 for i in range(len(ensemble) - 1): - original_weights = [param.data.clone() for param in ensemble[i].parameters()] + original_weights = [ + param.data.clone() for param in ensemble[i].parameters() + ] for j in range(i + 1, len(ensemble)): - model_weights = [param.data.clone() for param in ensemble[j].parameters()] - distance = euclidean_distance(original_weights, model_weights) - fprint(f"Euclidean Distance between #{i} and #{j}: ", distance.item()) + model_weights = [ + param.data.clone() + for param in ensemble[j].parameters() + ] + distance = euclidean_distance( + original_weights, model_weights + ) + fprint( + f"Euclidean Distance between #{i} and #{j}: ", + distance.item(), + ) def print_p_c_given_x_distance( @@ -151,7 +171,10 @@ def print_p_c_given_x_distance( dist = mcdropout_p_c_x_distance( model, test_loader, activate_dropout, num_ensembles ) - elif type == EVALUATION_TYPE.BEARS.value or type == EVALUATION_TYPE.ENSEMBLE.value: + elif ( + type == EVALUATION_TYPE.BEARS.value + or type == EVALUATION_TYPE.ENSEMBLE.value + ): dist = ensemble_p_c_x_distance(ensemble, test_loader) fprint(f"Mean P(C|X) for {type} distance L2 is {dist}") @@ -200,7 +223,11 @@ def print_metrics( def produce_h_c_given_y( - p_cs_all: ndarray, y_true: ndarray, nr_classes: int, mode: str, suffix: str + p_cs_all: ndarray, + y_true: ndarray, + nr_classes: int, + mode: str, + suffix: str, ) -> None: """Function which produces a bar plot of H(C|Y) @@ -228,7 +255,11 @@ def produce_h_c_given_y( def produce_var_c_given_y( - p_cs_all: ndarray, y_true: ndarray, nr_classes: int, mode: str, suffix: str + p_cs_all: ndarray, + y_true: ndarray, + nr_classes: int, + mode: str, + suffix: str, ) -> None: """Function which produces a bar plot of Var(C|Y) @@ -306,10 +337,13 @@ def ova_entropy(p: ndarray, c: int): c_fact_stacked = np.vstack([c_fact_1, c_fact_2]) for c_fact, key in zip( - [c_fact_1, 
c_fact_2, c_fact_stacked, p_w_x], ["c1", "c2", "c", "(c1, c2)"] + [c_fact_1, c_fact_2, c_fact_stacked, p_w_x], + ["c1", "c2", "c", "(c1, c2)"], ): for c in range(c_fact.shape[1]): - result = np.apply_along_axis(ova_entropy, axis=1, arr=c_fact, c=c) + result = np.apply_along_axis( + ova_entropy, axis=1, arr=c_fact, c=c + ) conditional_entropies[key].append(np.mean(result)) return conditional_entropies @@ -370,17 +404,28 @@ def entropy(p: ndarray): return normalized_entropy - conditional_entropies = {"c_ova_filtered": list(), "c_all_filtered": list()} + conditional_entropies = { + "c_ova_filtered": list(), + "c_all_filtered": list(), + } for c in range(c_fact_stacked.shape[1]): indices = np.where(c_true == c)[0] c_fact_filtered = c_fact_stacked[indices] - result = np.apply_along_axis(ova_entropy, axis=1, arr=c_fact_filtered, c=c) - conditional_entropies["c_ova_filtered"].append(np.mean(result)) + result = np.apply_along_axis( + ova_entropy, axis=1, arr=c_fact_filtered, c=c + ) + conditional_entropies["c_ova_filtered"].append( + np.mean(result) + ) - result = np.apply_along_axis(entropy, axis=1, arr=c_fact_filtered) - conditional_entropies["c_all_filtered"].append(np.mean(result)) + result = np.apply_along_axis( + entropy, axis=1, arr=c_fact_filtered + ) + conditional_entropies["c_all_filtered"].append( + np.mean(result) + ) return conditional_entropies @@ -444,10 +489,13 @@ def bernoulli_std(p: ndarray, c: int): c_fact_stacked = np.vstack([c_fact_1, c_fact_2]) for c_fact, key in zip( - [c_fact_1, c_fact_2, c_fact_stacked, p_w_x], ["c1", "c2", "c", "(c1, c2)"] + [c_fact_1, c_fact_2, c_fact_stacked, p_w_x], + ["c1", "c2", "c", "(c1, c2)"], ): for c in range(c_fact.shape[1]): - result = np.apply_along_axis(bernoulli_std, axis=1, arr=c_fact, c=c) + result = np.apply_along_axis( + bernoulli_std, axis=1, arr=c_fact, c=c + ) conditional_variances[key].append(np.mean(result) ** 2) return conditional_variances @@ -481,13 +529,18 @@ def produce_ece_curve( ece = None if ece_mode == ECEMODE.FILTERED_BY_CONCEPT: - ece_data = expected_calibration_error_by_concept(p, pred, true, concept) + ece_data = expected_calibration_error_by_concept( + p, pred, true, concept + ) else: ece_data = expected_calibration_error(p, pred, true) if ece_data: ece, ece_bins = ece_data - fprint(f"Expected Calibration Error (ECE) {exp_mode} on {purpose}", ece) + fprint( + f"Expected Calibration Error (ECE) {exp_mode} on {purpose}", + ece, + ) concept_flag = True if purpose != "labels" else False produce_calibration_curve( ece_bins, @@ -514,10 +567,18 @@ def generate_concept_labels(concept_labels: List[str]): # Generate all the product with repetition of size two of the concept labels (which indeed are all the possible words) - concept_labels_full = ["".join(comb) for comb in product(concept_labels, repeat=2)] - concept_labels_single = ["".join(comb) for comb in product(concept_labels)] - sklearn_concept_labels = [str(int(el)) for el in concept_labels_full] - sklearn_concept_labels_single = [str(int(el)) for el in concept_labels_single] + concept_labels_full = [ + "".join(comb) for comb in product(concept_labels, repeat=2) + ] + concept_labels_single = [ + "".join(comb) for comb in product(concept_labels) + ] + sklearn_concept_labels = [ + str(int(el)) for el in concept_labels_full + ] + sklearn_concept_labels_single = [ + str(int(el)) for el in concept_labels_single + ] return ( concept_labels_full, @@ -605,13 +666,17 @@ def produce_confusion_matrices( from itertools import product concept_labels = [ - "".join(comb) for comb in 
product([str(el) for el in range(10)], repeat=2) + "".join(comb) + for comb in product([str(el) for el in range(10)], repeat=2) ] concept_labels_single = [ - "".join(comb) for comb in product([str(el) for el in range(10)]) + "".join(comb) + for comb in product([str(el) for el in range(10)]) ] sklearn_concept_labels = [str(int(el)) for el in concept_labels] - sklearn_concept_labels_single = [str(int(el)) for el in concept_labels_single] + sklearn_concept_labels_single = [ + str(int(el)) for el in concept_labels_single + ] # extend them in order to have a single element: e.g. 03 means that the first element was associated to 0 while the second with 3 c_extended_true = np.array( @@ -682,12 +747,18 @@ def produce_alpha( alpha_M, _ = get_alpha(worlds_prob, c_true_cc, n_facts=n_facts) produce_alpha_matrix( - alpha_M, "p((C1, C2)| (G1, G2))", concept_labels, f"alpha_plot_{mode}", n_facts + alpha_M, + "p((C1, C2)| (G1, G2))", + concept_labels, + f"alpha_plot_{mode}", + n_facts, ) # Only the single model produces the single ALPHA if type == EVALUATION_TYPE.NORMAL.value: - words_prob_single_concept = np.concatenate((c_prb_1, c_prb_2), axis=0) + words_prob_single_concept = np.concatenate( + (c_prb_1, c_prb_2), axis=0 + ) alpha_M_single, _ = get_alpha_single( words_prob_single_concept, c_true, n_facts=n_facts ) @@ -740,7 +811,9 @@ def save_csv( """ import csv - gt_factorized = np.reshape(gt_factorized, (int(gt_factorized.shape[0] / 2), 2)) + gt_factorized = np.reshape( + gt_factorized, (int(gt_factorized.shape[0] / 2), 2) + ) with open(file_path, "a", newline="") as csvfile: csv_writer = csv.writer(csvfile) diff --git a/XOR_MNIST/utils/train.py b/XOR_MNIST/utils/train.py index dbd5ead..21a7369 100644 --- a/XOR_MNIST/utils/train.py +++ b/XOR_MNIST/utils/train.py @@ -1,31 +1,26 @@ # Module which contains the code for training a model and the active learning setup -import torch -import numpy as np - -import wandb +import math import os import sys -import math - -from torchvision.utils import make_grid -from utils.wandb_logger import * -from utils.status import progress_bar +import numpy as np +import torch +import wandb from datasets.utils.base_dataset import BaseDataset from models.mnistdpl import MnistDPL +from torchvision.utils import make_grid +from utils import fprint +from utils.bayes import deep_ensemble_active from utils.dpl_loss import ADDMNIST_DPL +from utils.generative import conditional_gen, recon_visaulization from utils.metrics import ( evaluate_metrics, - evaluate_mix, evaluate_metrics_ensemble, + evaluate_mix, mean_entropy, ) -from utils.generative import conditional_gen, recon_visaulization -from utils.bayes import ( - deep_ensemble_active, -) -from utils import fprint - +from utils.status import progress_bar +from utils.wandb_logger import * from warmup_scheduler import GradualWarmupScheduler @@ -40,7 +35,9 @@ def active_start(model, seed, ensemble=[]): Returns: None: This function does not return a value. 
""" - model_filename = f"data/ckpts/minikandinsky-kanddpl-dis-{seed}-end.pt" + model_filename = ( + f"data/ckpts/minikandinsky-kanddpl-dis-{seed}-end.pt" + ) if os.path.exists(model_filename): state_dict = torch.load(model_filename) @@ -84,7 +81,9 @@ def return_metrics( return yac, cac, cf1, yf1 -def train_active(model: MnistDPL, dataset: BaseDataset, _loss: ADDMNIST_DPL, args): +def train_active( + model: MnistDPL, dataset: BaseDataset, _loss: ADDMNIST_DPL, args +): """Active learning Args: @@ -180,7 +179,9 @@ def train_active(model: MnistDPL, dataset: BaseDataset, _loss: ADDMNIST_DPL, arg # temporal lists containing initial superivion + chosen supervision tmp_data_idx = starting_img_idx + chosen_samples[0] - tmp_figure_idx = starting_figure_idx + chosen_samples[1] + tmp_figure_idx = ( + starting_figure_idx + chosen_samples[1] + ) tmp_obj_idx = starting_obj_idx + chosen_samples[2] # counter of selected elements @@ -188,7 +189,9 @@ def train_active(model: MnistDPL, dataset: BaseDataset, _loss: ADDMNIST_DPL, arg # loop through indices, figures, objects # reminders: indices, figures, objects are the sorted vectors according to highest entropy samples - for indice, figure, oggetto in zip(indices, figures, objects): + for indice, figure, oggetto in zip( + indices, figures, objects + ): if random_selection: # random supervision indice, figure, oggetto = ( @@ -232,11 +235,15 @@ def train_active(model: MnistDPL, dataset: BaseDataset, _loss: ADDMNIST_DPL, arg # Default Setting for Training model.to(model.device) - train_loader, val_loader, test_loader = dataset.get_data_loaders() + train_loader, val_loader, test_loader = ( + dataset.get_data_loaders() + ) scheduler = torch.optim.lr_scheduler.ExponentialLR( model.opt, args.exp_decay ) - w_scheduler = GradualWarmupScheduler(model.opt, 1.0, args.warmup_steps) + w_scheduler = GradualWarmupScheduler( + model.opt, 1.0, args.warmup_steps + ) # train loader without shuffling train_loader_as_val = dataset.get_train_loader_as_val() @@ -255,7 +262,10 @@ def train_active(model: MnistDPL, dataset: BaseDataset, _loss: ADDMNIST_DPL, arg if biretta: # training ensembles ensemble = deep_ensemble_active( - seeds=[i + args.seed + 1 for i in range(n_ensembles - 1)], + seeds=[ + i + args.seed + 1 + for i in range(n_ensembles - 1) + ], base_model=model, dataset=dataset, num_epochs=args.n_epochs, @@ -289,7 +299,11 @@ def train_active(model: MnistDPL, dataset: BaseDataset, _loss: ADDMNIST_DPL, arg out_dict = model(images) out_dict.update( - {"INPUTS": images, "LABELS": labels, "CONCEPTS": concepts} + { + "INPUTS": images, + "LABELS": labels, + "CONCEPTS": concepts, + } ) model.opt.zero_grad() @@ -308,10 +322,17 @@ def train_active(model: MnistDPL, dataset: BaseDataset, _loss: ADDMNIST_DPL, arg ) if i % 10 == 0: - progress_bar(i, len(train_loader) - 9, epoch, loss.item()) + progress_bar( + i, + len(train_loader) - 9, + epoch, + loss.item(), + ) model.eval() - tloss, cacc, yacc, f1 = evaluate_metrics(model, val_loader, args) + tloss, cacc, yacc, f1 = evaluate_metrics( + model, val_loader, args + ) # update at end of the epoch if epoch < args.warmup_steps: @@ -322,7 +343,9 @@ def train_active(model: MnistDPL, dataset: BaseDataset, _loss: ADDMNIST_DPL, arg _loss.update_grade(epoch) ### LOGGING ### - fprint("\n ACC C", cacc, " ACC Y", yacc, "F1 Y", f1) + fprint( + "\n ACC C", cacc, " ACC Y", yacc, "F1 Y", f1 + ) # simple early stopper if yacc > 95: @@ -340,16 +363,36 @@ def train_active(model: MnistDPL, dataset: BaseDataset, _loss: ADDMNIST_DPL, arg # Evaluate performances on 
val or test if args.validate: - y_true, c_true, y_pred, c_pred, p_cs, p_ys, p_cs_all, p_ys_all = ( - evaluate_metrics(model, val_loader, args, last=True) + ( + y_true, + c_true, + y_pred, + c_pred, + p_cs, + p_ys, + p_cs_all, + p_ys_all, + ) = evaluate_metrics( + model, val_loader, args, last=True ) else: - y_true, c_true, y_pred, c_pred, p_cs, p_ys, p_cs_all, p_ys_all = ( - evaluate_metrics(model, test_loader, args, last=True) + ( + y_true, + c_true, + y_pred, + c_pred, + p_cs, + p_ys, + p_cs_all, + p_ys_all, + ) = evaluate_metrics( + model, test_loader, args, last=True ) if not biretta: - tloss, cacc, yacc, f1 = evaluate_metrics(model, val_loader, args) + tloss, cacc, yacc, f1 = evaluate_metrics( + model, val_loader, args + ) else: tloss, cacc, yacc, f1 = evaluate_metrics_ensemble( ensemble, val_loader, args @@ -361,26 +404,41 @@ def train_active(model: MnistDPL, dataset: BaseDataset, _loss: ADDMNIST_DPL, arg acc_c[int(biretta)].append(cacc) acc_y[int(biretta)].append(yacc) - entropy_array, figures, objects, indices = concept_supervision_selection( - train_loader_as_val, model, ensemble + entropy_array, figures, objects, indices = ( + concept_supervision_selection( + train_loader_as_val, model, ensemble + ) ) # print all, and print as many as finetuning says np.set_printoptions(threshold=sys.maxsize) fprint( - "Most confused elements figures", figures[: args.finetuning].astype(int) + "Most confused elements figures", + figures[: args.finetuning].astype(int), ) fprint( - "Most confused elements objects", objects[: args.finetuning].astype(int) + "Most confused elements objects", + objects[: args.finetuning].astype(int), ) - fprint("Most confused elements entropy", entropy_array[: args.finetuning]) fprint( - "Most confused elements indices", indices[: args.finetuning].astype(int) + "Most confused elements entropy", + entropy_array[: args.finetuning], + ) + fprint( + "Most confused elements indices", + indices[: args.finetuning].astype(int), + ) + fprint( + "Max entropy", + max(entropy_array), + "min entropy", + min(entropy_array), ) - fprint("Max entropy", max(entropy_array), "min entropy", min(entropy_array)) - ensemble_string = "" if len(ensemble) == 0 else "-ensemble" + ensemble_string = ( + "" if len(ensemble) == 0 else "-ensemble" + ) if not biretta: # save numpy values for debugging @@ -464,8 +522,12 @@ def train_active(model: MnistDPL, dataset: BaseDataset, _loss: ADDMNIST_DPL, arg tag = ["shapes", "colors"] for i in range(2): - all_c = np.concatenate((cs[i][:, 0], cs[i][:, 1], cs[i][:, 2])) - all_g = np.concatenate((gs[i][:, 0], gs[i][:, 1], gs[i][:, 2])) + all_c = np.concatenate( + (cs[i][:, 0], cs[i][:, 1], cs[i][:, 2]) + ) + all_g = np.concatenate( + (gs[i][:, 0], gs[i][:, 1], gs[i][:, 2]) + ) wandb.log( { @@ -478,11 +540,19 @@ def train_active(model: MnistDPL, dataset: BaseDataset, _loss: ADDMNIST_DPL, arg cs = np.split(c, 2, axis=1) gs = np.split(g, 2, axis=1) - shapes_pred = np.concatenate((cs[0][:, 0], cs[0][:, 1], cs[0][:, 2])) - shapes_true = np.concatenate((gs[0][:, 0], gs[0][:, 1], gs[0][:, 2])) + shapes_pred = np.concatenate( + (cs[0][:, 0], cs[0][:, 1], cs[0][:, 2]) + ) + shapes_true = np.concatenate( + (gs[0][:, 0], gs[0][:, 1], gs[0][:, 2]) + ) - colors_pred = np.concatenate((cs[1][:, 0], cs[1][:, 1], cs[1][:, 2])) - colors_true = np.concatenate((gs[1][:, 0], gs[1][:, 1], gs[1][:, 2])) + colors_pred = np.concatenate( + (cs[1][:, 0], cs[1][:, 1], cs[1][:, 2]) + ) + colors_true = np.concatenate( + (gs[1][:, 0], gs[1][:, 1], gs[1][:, 2]) + ) all_c = shapes_pred * 3 + 
colors_pred all_g = shapes_true * 3 + colors_true @@ -504,8 +574,12 @@ def train_active(model: MnistDPL, dataset: BaseDataset, _loss: ADDMNIST_DPL, arg suffix = "_biretta" - np.save(f"data/kand-analysis/acc_c_{args.seed}{suffix}.npy", acc_c) - np.save(f"data/kand-analysis/acc_y_{args.seed}{suffix}.npy", acc_y) + np.save( + f"data/kand-analysis/acc_c_{args.seed}{suffix}.npy", acc_c + ) + np.save( + f"data/kand-analysis/acc_y_{args.seed}{suffix}.npy", acc_y + ) wandb.finish() @@ -548,14 +622,18 @@ def ova_entropy_per_object_ensemble(p_list, n_figure, n_object): shape_worlds = np.asarray(list(p[n_figure, n_object, 3:])) # compute the outer product - world_prob.append(np.outer(color_worlds, shape_worlds).flatten()) + world_prob.append( + np.outer(color_worlds, shape_worlds).flatten() + ) # mean in the ensemble world_prob = np.stack(world_prob, axis=0) world_prob = np.mean(world_prob, axis=0) # entropy - entropy_value = -np.sum(world_prob * np.log(world_prob)) / (math.log(9)) + entropy_value = -np.sum(world_prob * np.log(world_prob)) / ( + math.log(9) + ) return entropy_value @@ -577,7 +655,9 @@ def ova_entropy_per_object(p, n_figure, n_object): shape_worlds = np.asarray(list(p[n_figure, n_object, 3:])) world_prob = np.outer(color_worlds, shape_worlds).flatten() - entropy_value = -np.sum(world_prob * np.log(world_prob)) / (math.log(9)) + entropy_value = -np.sum(world_prob * np.log(world_prob)) / ( + math.log(9) + ) return entropy_value @@ -599,11 +679,15 @@ def ova_entropy_per_object(p, n_figure, n_object): logits = model(images, activate_simple_concepts=True) logits = list(torch.split(logits, 3, dim=-1)) for i in range(len(logits)): - logits[i] = torch.nn.functional.softmax(logits[i], dim=-1) + logits[i] = torch.nn.functional.softmax( + logits[i], dim=-1 + ) logits = torch.cat(logits, dim=-1) probs.append(logits.detach().cpu().numpy()) - probs = np.concatenate(probs, axis=0) # data, images, objects, probabilities + probs = np.concatenate( + probs, axis=0 + ) # data, images, objects, probabilities else: @@ -621,7 +705,9 @@ def ova_entropy_per_object(p, n_figure, n_object): logits = m(images, activate_simple_concepts=True) logits = list(torch.split(logits, 3, dim=-1)) for i in range(len(logits)): - logits[i] = torch.nn.functional.softmax(logits[i], dim=-1) + logits[i] = torch.nn.functional.softmax( + logits[i], dim=-1 + ) logits = torch.cat(logits, dim=-1) probs.append(logits.detach().cpu().numpy()) @@ -655,9 +741,15 @@ def ova_entropy_per_object(p, n_figure, n_object): n_figure=n_figure, n_object=n_object, ) - entropy_array = np.concatenate((entropy_array, entropies), axis=None) - figures_array = np.concatenate((figures_array, n_figure), axis=None) - objects_array = np.concatenate((objects_array, n_object), axis=None) + entropy_array = np.concatenate( + (entropy_array, entropies), axis=None + ) + figures_array = np.concatenate( + (figures_array, n_figure), axis=None + ) + objects_array = np.concatenate( + (objects_array, n_object), axis=None + ) indices = np.concatenate((indices, i), axis=None) # rank indices in descending order according to entropy values @@ -672,7 +764,9 @@ def ova_entropy_per_object(p, n_figure, n_object): ) -def train(model: MnistDPL, dataset: BaseDataset, _loss: ADDMNIST_DPL, args): +def train( + model: MnistDPL, dataset: BaseDataset, _loss: ADDMNIST_DPL, args +): """TRAINING Args: @@ -689,8 +783,12 @@ def train(model: MnistDPL, dataset: BaseDataset, _loss: ADDMNIST_DPL, args): model.to(model.device) train_loader, val_loader, test_loader = dataset.get_data_loaders() 
dataset.print_stats() - scheduler = torch.optim.lr_scheduler.ExponentialLR(model.opt, args.exp_decay) - w_scheduler = GradualWarmupScheduler(model.opt, 1.0, args.warmup_steps) + scheduler = torch.optim.lr_scheduler.ExponentialLR( + model.opt, args.exp_decay + ) + w_scheduler = GradualWarmupScheduler( + model.opt, 1.0, args.warmup_steps + ) if args.wandb is not None: fprint("\n---wandb on\n") @@ -722,7 +820,13 @@ def train(model: MnistDPL, dataset: BaseDataset, _loss: ADDMNIST_DPL, args): out_dict = model(images) - out_dict.update({"INPUTS": images, "LABELS": labels, "CONCEPTS": concepts}) + out_dict.update( + { + "INPUTS": images, + "LABELS": labels, + "CONCEPTS": concepts, + } + ) model.opt.zero_grad() loss, losses = _loss(out_dict, args) @@ -737,13 +841,17 @@ def train(model: MnistDPL, dataset: BaseDataset, _loss: ADDMNIST_DPL, args): y_true = out_dict["LABELS"] else: ys = torch.concatenate((ys, out_dict["YS"]), dim=0) - y_true = torch.concatenate((y_true, out_dict["LABELS"]), dim=0) + y_true = torch.concatenate( + (y_true, out_dict["LABELS"]), dim=0 + ) if args.wandb is not None: wandb_log_step(i, epoch, loss.item(), losses) if i % 10 == 0: - progress_bar(i, len(train_loader) - 9, epoch, loss.item()) + progress_bar( + i, len(train_loader) - 9, epoch, loss.item() + ) y_pred = torch.argmax(ys, dim=-1) @@ -755,7 +863,9 @@ def train(model: MnistDPL, dataset: BaseDataset, _loss: ADDMNIST_DPL, args): ) model.eval() - tloss, cacc, yacc, f1 = evaluate_metrics(model, val_loader, args) + tloss, cacc, yacc, f1 = evaluate_metrics( + model, val_loader, args + ) # update at end of the epoch if epoch < args.warmup_steps: @@ -779,13 +889,27 @@ def train(model: MnistDPL, dataset: BaseDataset, _loss: ADDMNIST_DPL, args): # Evaluate performances on val or test if args.validate: - y_true, c_true, y_pred, c_pred, p_cs, p_ys, p_cs_all, p_ys_all = ( - evaluate_metrics(model, val_loader, args, last=True) - ) + ( + y_true, + c_true, + y_pred, + c_pred, + p_cs, + p_ys, + p_cs_all, + p_ys_all, + ) = evaluate_metrics(model, val_loader, args, last=True) else: - y_true, c_true, y_pred, c_pred, p_cs, p_ys, p_cs_all, p_ys_all = ( - evaluate_metrics(model, test_loader, args, last=True) - ) + ( + y_true, + c_true, + y_pred, + c_pred, + p_cs, + p_ys, + p_cs_all, + p_ys_all, + ) = evaluate_metrics(model, test_loader, args, last=True) yac, yf1 = evaluate_mix(y_true, y_pred) cac, cf1 = evaluate_mix(c_true, c_pred) @@ -804,7 +928,10 @@ def train(model: MnistDPL, dataset: BaseDataset, _loss: ADDMNIST_DPL, args): wandb.log( { "cf-labels": wandb.plot.confusion_matrix( - None, y_true, y_pred, class_names=[str(i) for i in range(K + 1)] + None, + y_true, + y_pred, + class_names=[str(i) for i in range(K + 1)], ), } ) @@ -812,7 +939,10 @@ def train(model: MnistDPL, dataset: BaseDataset, _loss: ADDMNIST_DPL, args): wandb.log( { "cf-concepts": wandb.plot.confusion_matrix( - None, c_true, c_pred, class_names=[str(i) for i in range(K + 1)] + None, + c_true, + c_pred, + class_names=[str(i) for i in range(K + 1)], ), } ) @@ -822,11 +952,17 @@ def train(model: MnistDPL, dataset: BaseDataset, _loss: ADDMNIST_DPL, args): conditional_gen(model), nrow=8, ) - images = wandb.Image(list_images, caption="Generated samples") + images = wandb.Image( + list_images, caption="Generated samples" + ) wandb.log({"Conditional Gen": images}) - list_images = make_grid(recon_visaulization(out_dict), nrow=8) - images = wandb.Image(list_images, caption="Reconstructed samples") + list_images = make_grid( + recon_visaulization(out_dict), nrow=8 + ) + images = 
wandb.Image( + list_images, caption="Reconstructed samples" + ) wandb.log({"Reconstruction": images}) wandb.finish() diff --git a/XOR_MNIST/utils/visualization.py b/XOR_MNIST/utils/visualization.py index fc8559a..8d6b775 100644 --- a/XOR_MNIST/utils/visualization.py +++ b/XOR_MNIST/utils/visualization.py @@ -2,12 +2,12 @@ import os -from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay -from numpy import ndarray -from typing import List, Dict +from typing import Dict, List -import numpy as np import matplotlib.pyplot as plt +import numpy as np +from numpy import ndarray +from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix OUTPUT_FOLDER = "./plots" # PLOTS FOLDER @@ -57,7 +57,9 @@ def produce_confusion_matrix( ground_truth = y_true.astype(str) predictions = y_pred.astype(str) - all_labels = np.union1d(np.unique(ground_truth), np.unique(predictions)) + all_labels = np.union1d( + np.unique(ground_truth), np.unique(predictions) + ) label_encoder.fit(all_labels) ground_truth = label_encoder.transform(ground_truth) @@ -91,12 +93,20 @@ def produce_confusion_matrix( x_labels = ax.get_xticklabels() tick_positions = np.arange(len(x_labels)) - tick_labels = [x_labels[i] if i % ntimes == 0 else "" for i in range(len(x_labels))] - ax.set_xticks(tick_positions, tick_labels, rotation=90, fontsize=10) + tick_labels = [ + x_labels[i] if i % ntimes == 0 else "" + for i in range(len(x_labels)) + ] + ax.set_xticks( + tick_positions, tick_labels, rotation=90, fontsize=10 + ) y_labels = ax.get_yticklabels() tick_positions = np.arange(len(y_labels)) - tick_labels = [y_labels[i] if i % ntimes == 0 else "" for i in range(len(y_labels))] + tick_labels = [ + y_labels[i] if i % ntimes == 0 else "" + for i in range(len(y_labels)) + ] ax.set_yticks(tick_positions, tick_labels, fontsize=10) # Set title and color bar label @@ -137,7 +147,9 @@ def produce_world_probability_table( ax.axis("off") # Create the table with blue color - table_data = [[key, f"{value:.2f}"] for key, value in zip(keys, values)] + table_data = [ + [key, f"{value:.2f}"] for key, value in zip(keys, values) + ] table = ax.table( cellText=table_data, colLabels=[key_string, key_value], @@ -151,7 +163,9 @@ def produce_world_probability_table( table.set_fontsize(10) table.scale(1.5, 1.5) - ax.set_title(title, fontweight="bold", fontsize=16, color="#3366cc") + ax.set_title( + title, fontweight="bold", fontsize=16, color="#3366cc" + ) # Specify the file path and name where you want to save the image file_path = f"{OUTPUT_FOLDER}/mean_probability_table.png" @@ -200,7 +214,12 @@ def produce_alpha_matrix( # tick_positions = np.arange(len(concept_labels)) # tick_labels = [concept_labels[i] if i % ntimes == 0 else '' for i in range(len(concept_labels))] # plt.xticks(tick_positions, tick_labels, rotation=90, fontsize=10) - plt.xticks(np.arange(len(concept_labels)), concept_labels, rotation=90, fontsize=10) + plt.xticks( + np.arange(len(concept_labels)), + concept_labels, + rotation=90, + fontsize=10, + ) # Set y-axis ticks and labels # tick_positions = np.arange(len(keys)) @@ -251,13 +270,24 @@ def produce_scatter_multi_class( None: This function does not return a value. 
""" if colors is None: - colors = ["blue", "red", "green", "orange", "purple", "brown", "pink", "gray"] + colors = [ + "blue", + "red", + "green", + "orange", + "purple", + "brown", + "pink", + "gray", + ] if markers is None: markers = ["o", "s", "^", "D", "v", ">", "<", "p"] # Create a scatter plot for each class - for i, (x_values, y_values) in enumerate(zip(x_values_list, y_values_list)): + for i, (x_values, y_values) in enumerate( + zip(x_values_list, y_values_list) + ): plt.scatter( x_values, y_values, @@ -324,7 +354,9 @@ def plot_grouped_entropies( index = np.arange(num_categories, dtype=float) # Adjust the index positions to create separation between groups - total_group_width = num_categories * bar_width + (num_categories - 1) * group_gap + total_group_width = ( + num_categories * bar_width + (num_categories - 1) * group_gap + ) linspace_values = np.linspace( -total_group_width / 2, total_group_width / 2, num_categories ) @@ -366,7 +398,9 @@ def plot_grouped_entropies( ax.legend() if save: - file_path = f"{OUTPUT_FOLDER}/{dataset}_{prefix}_hc_bar_plot.png" + file_path = ( + f"{OUTPUT_FOLDER}/{dataset}_{prefix}_hc_bar_plot.png" + ) fig.tight_layout() plt.savefig(file_path, dpi=150) @@ -393,15 +427,21 @@ def produce_calibration_curve( num_bins = len(bin_info) # Extract relevant information from bin_info - bin_confidence = [bin_info[i]["BIN_CONF"] for i in range(num_bins)] + bin_confidence = [ + bin_info[i]["BIN_CONF"] for i in range(num_bins) + ] bin_accuracy = [bin_info[i]["BIN_ACC"] for i in range(num_bins)] bin_counts = [bin_info[i]["COUNT"] for i in range(num_bins)] # Calculate the center of each bin - bin_centers = np.linspace(1 / (2 * num_bins), 1 - 1 / (2 * num_bins), num_bins) + bin_centers = np.linspace( + 1 / (2 * num_bins), 1 - 1 / (2 * num_bins), num_bins + ) # Create a subplot with two plots (2 rows, 1 column) - fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 10), sharex=True) + fig, (ax1, ax2) = plt.subplots( + 2, 1, figsize=(10, 10), sharex=True + ) txt = "Concept" if concept else "Label" fig.suptitle(f"{txt} Calibration Curve ECE: {ece:.2f}") @@ -419,13 +459,23 @@ def produce_calibration_curve( ) # Plot grey dashed vertical lines for weighted average confidence and accuracy - avg_confidence = np.sum(np.array(bin_confidence) * bin_counts) / np.sum(bin_counts) - avg_accuracy = np.sum(np.array(bin_accuracy) * bin_counts) / np.sum(bin_counts) + avg_confidence = np.sum( + np.array(bin_confidence) * bin_counts + ) / np.sum(bin_counts) + avg_accuracy = np.sum( + np.array(bin_accuracy) * bin_counts + ) / np.sum(bin_counts) ax1.axvline( - x=avg_confidence, color="red", linestyle="--", label="Weighted Avg. Confidence" + x=avg_confidence, + color="red", + linestyle="--", + label="Weighted Avg. Confidence", ) ax1.axvline( - x=avg_accuracy, color="black", linestyle="--", label="Weighted Avg. Accuracy" + x=avg_accuracy, + color="black", + linestyle="--", + label="Weighted Avg. 
Accuracy", ) # Customize the second plot @@ -459,7 +509,13 @@ def produce_calibration_curve( ) # Plot the ideal line (diagonal) - ax2.plot([0, 1], [0, 1], linestyle="--", color="gray", label="Perfect Calibration") + ax2.plot( + [0, 1], + [0, 1], + linestyle="--", + color="gray", + label="Perfect Calibration", + ) # Customize the first plot ax2.set_xlabel("Mean Predicted Probability (Confidence)") @@ -499,7 +555,9 @@ def produce_bar_plot( indices = np.arange(len(data)) # Create the bar plot with improved styling - plt.bar(indices, data, color="blue", edgecolor="black", linewidth=1.2) + plt.bar( + indices, data, color="blue", edgecolor="black", linewidth=1.2 + ) # Adding labels and title plt.xlabel(xlabel, fontsize=12) diff --git a/XOR_MNIST/utils/wandb_logger.py b/XOR_MNIST/utils/wandb_logger.py index ed18380..c0f0dc3 100644 --- a/XOR_MNIST/utils/wandb_logger.py +++ b/XOR_MNIST/utils/wandb_logger.py @@ -65,7 +65,13 @@ def wandb_log_step_prefix(prefix, i, epoch, loss, losses=None): Returns: None: This function does not return a value. """ - wandb.log({f"{prefix}_loss": loss, f"{prefix}_epoch": epoch, f"{prefix}_step": i}) + wandb.log( + { + f"{prefix}_loss": loss, + f"{prefix}_epoch": epoch, + f"{prefix}_step": i, + } + ) if losses is not None: wandb.log(losses) @@ -88,7 +94,11 @@ def wandb_log_epoch_prefix(prefix, **kwargs): acc = kwargs["acc"] c_acc = kwargs["cacc"] wandb.log( - {f"{prefix}_acc": acc, f"{prefix}_c-acc": c_acc, f"{prefix}_epoch": epoch} + { + f"{prefix}_acc": acc, + f"{prefix}_c-acc": c_acc, + f"{prefix}_epoch": epoch, + } ) lr = kwargs["lr"]