From 12bd3bda9de474aa90adfd59fd54f362ae0def78 Mon Sep 17 00:00:00 2001 From: John Palowitch Date: Wed, 25 Mar 2020 16:07:30 -0700 Subject: [PATCH 1/4] Change all non-paren print statements to Python3 print. --- evaluate_interaction_prediction.py | 12 ++++++------ evaluate_state_change_prediction.py | 12 ++++++------ get_final_performance_numbers.py | 8 ++++---- jodie.py | 8 ++++---- library_data.py | 10 +++++----- library_models.py | 16 ++++++++-------- tbatch.py | 16 ++++++++-------- 7 files changed, 41 insertions(+), 41 deletions(-) diff --git a/evaluate_interaction_prediction.py b/evaluate_interaction_prediction.py index 4f706ba..a58f68d 100644 --- a/evaluate_interaction_prediction.py +++ b/evaluate_interaction_prediction.py @@ -27,7 +27,7 @@ if args.train_proportion > 0.8: sys.exit('Training sequence proportion cannot be greater than 0.8.') if args.network == "mooc": - print "No interaction prediction for %s" % args.network + print("No interaction prediction for %s" % args.network) sys.exit(0) # SET GPU @@ -43,7 +43,7 @@ for l in f: l = l.strip() if search_string in l: - print "Output file already has results of epoch %d" % args.epoch + print("Output file already has results of epoch %d" % args.epoch) sys.exit(0) f.close() @@ -58,7 +58,7 @@ num_users = len(user2id) num_items = len(item2id) + 1 true_labels_ratio = len(y_true)/(sum(y_true)+1) -print "*** Network statistics:\n %d users\n %d items\n %d interactions\n %d/%d true labels ***\n\n" % (num_users, num_items, num_interactions, sum(y_true), len(y_true)) +print("*** Network statistics:\n %d users\n %d items\n %d interactions\n %d/%d true labels ***\n\n" % (num_users, num_items, num_interactions, sum(y_true), len(y_true))) # SET TRAIN, VALIDATION, AND TEST BOUNDARIES train_end_idx = validation_start_idx = int(num_interactions * args.train_proportion) @@ -121,7 +121,7 @@ tbatch_start_time = None loss = 0 # FORWARD PASS -print "*** Making interaction predictions by forward pass (no t-batching) ***" +print("*** Making interaction predictions by forward pass (no t-batching) ***") with trange(train_end_idx, test_end_idx) as progress_bar: for j in progress_bar: progress_bar.set_description('%dth interaction for validation and testing' % j) @@ -218,13 +218,13 @@ fw = open(output_fname, "a") metrics = ['Mean Reciprocal Rank', 'Recall@10'] -print '\n\n*** Validation performance of epoch %d ***' % args.epoch +print('\n\n*** Validation performance of epoch %d ***' % args.epoch) fw.write('\n\n*** Validation performance of epoch %d ***\n' % args.epoch) for i in xrange(len(metrics)): print(metrics[i] + ': ' + str(performance_dict['validation'][i])) fw.write("Validation: " + metrics[i] + ': ' + str(performance_dict['validation'][i]) + "\n") -print '\n\n*** Test performance of epoch %d ***' % args.epoch +print('\n\n*** Test performance of epoch %d ***' % args.epoch) fw.write('\n\n*** Test performance of epoch %d ***\n' % args.epoch) for i in xrange(len(metrics)): print(metrics[i] + ': ' + str(performance_dict['test'][i])) diff --git a/evaluate_state_change_prediction.py b/evaluate_state_change_prediction.py index cc3356f..4a781c6 100644 --- a/evaluate_state_change_prediction.py +++ b/evaluate_state_change_prediction.py @@ -28,7 +28,7 @@ if args.train_proportion > 0.8: sys.exit('Training sequence proportion cannot be greater than 0.8.') if args.network == "lastfm": - print "No state change prediction for %s" % args.network + print("No state change prediction for %s" % args.network) sys.exit(0) # SET GPU @@ -44,7 +44,7 @@ for l in f: l = l.strip() if search_string in l: - print "Output file already has results of epoch %d" % args.epoch + print("Output file already has results of epoch %d" % args.epoch) sys.exit(0) f.close() @@ -59,7 +59,7 @@ num_users = len(user2id) num_items = len(item2id) + 1 true_labels_ratio = len(y_true)/(sum(y_true)+1) -print "*** Network statistics:\n %d users\n %d items\n %d interactions\n %d/%d true labels ***\n\n" % (num_users, num_items, num_interactions, sum(y_true), len(y_true)) +print("*** Network statistics:\n %d users\n %d items\n %d interactions\n %d/%d true labels ***\n\n" % (num_users, num_items, num_interactions, sum(y_true), len(y_true))) # SET TRAIN, VALIDATION, AND TEST BOUNDARIES train_end_idx = validation_start_idx = int(num_interactions * args.train_proportion) @@ -124,7 +124,7 @@ tbatch_start_time = None loss = 0 # FORWARD PASS -print "*** Making state change predictions by forward pass (no t-batching) ***" +print("*** Making state change predictions by forward pass (no t-batching) ***") with trange(train_end_idx, test_end_idx) as progress_bar: for j in progress_bar: progress_bar.set_description('%dth interaction for validation and testing' % j) @@ -218,14 +218,14 @@ fw = open(output_fname, "a") metrics = ['AUC'] -print '\n\n*** Validation performance of epoch %d ***' % args.epoch +print('\n\n*** Validation performance of epoch %d ***' % args.epoch) fw.write('\n\n*** Validation performance of epoch %d ***\n' % args.epoch) for i in xrange(len(metrics)): print(metrics[i] + ': ' + str(performance_dict['validation'][i])) fw.write("Validation: " + metrics[i] + ': ' + str(performance_dict['validation'][i]) + "\n") -print '\n\n*** Test performance of epoch %d ***' % args.epoch +print('\n\n*** Test performance of epoch %d ***' % args.epoch) fw.write('\n\n*** Test performance of epoch %d ***\n' % args.epoch) for i in xrange(len(metrics)): print(metrics[i] + ': ' + str(performance_dict['test'][i])) diff --git a/get_final_performance_numbers.py b/get_final_performance_numbers.py index 689304e..ad50563 100644 --- a/get_final_performance_numbers.py +++ b/get_final_performance_numbers.py @@ -43,13 +43,13 @@ else: metrics = ['AUC'] -print '\n\n*** For file: %s ***' % fname +print('\n\n*** For file: %s ***' % fname) best_val_idx = np.argmax(validation_performances[:,1]) -print "Best validation epoch: %d" % best_val_idx -print '\n\n*** Best validation performance (epoch %d) ***' % best_val_idx +print("Best validation epoch: %d" % best_val_idx) +print('\n\n*** Best validation performance (epoch %d) ***' % best_val_idx) for i in xrange(len(metrics)): print(metrics[i] + ': ' + str(validation_performances[best_val_idx][i+1])) -print '\n\n*** Final model performance on the test set, i.e., in epoch %d ***' % best_val_idx +print('\n\n*** Final model performance on the test set, i.e., in epoch %d ***' % best_val_idx) for i in xrange(len(metrics)): print(metrics[i] + ': ' + str(test_performances[best_val_idx][i+1])) diff --git a/jodie.py b/jodie.py index 9b729f3..f1eee71 100644 --- a/jodie.py +++ b/jodie.py @@ -42,7 +42,7 @@ num_items = len(item2id) + 1 # one extra item for "none-of-these" num_features = len(feature_sequence[0]) true_labels_ratio = len(y_true)/(1.0+sum(y_true)) # +1 in denominator in case there are no state change labels, which will throw an error. -print "*** Network statistics:\n %d users\n %d items\n %d interactions\n %d/%d true labels ***\n\n" % (num_users, num_items, num_interactions, sum(y_true), len(y_true)) +print("*** Network statistics:\n %d users\n %d items\n %d interactions\n %d/%d true labels ***\n\n" % (num_users, num_items, num_interactions, sum(y_true), len(y_true))) # SET TRAINING, VALIDATION, TESTING, and TBATCH BOUNDARIES train_end_idx = validation_start_idx = int(num_interactions * args.train_proportion) @@ -85,7 +85,7 @@ ''' THE MODEL IS TRAINED FOR SEVERAL EPOCHS. IN EACH EPOCH, JODIES USES THE TRAINING SET OF INTERACTIONS TO UPDATE ITS PARAMETERS. ''' -print "*** Training the JODIE model for %d epochs ***" % args.epochs +print("*** Training the JODIE model for %d epochs ***" % args.epochs) with trange(args.epochs) as progress_bar1: for ep in progress_bar1: progress_bar1.set_description('Epoch %d of %d' % (ep, args.epochs)) @@ -200,7 +200,7 @@ tbatch_to_insert = -1 # END OF ONE EPOCH - print "\n\nTotal loss in this epoch = %f" % (total_loss) + print("\n\nTotal loss in this epoch = %f" % (total_loss)) item_embeddings_dystat = torch.cat([item_embeddings, item_embedding_static], dim=1) user_embeddings_dystat = torch.cat([user_embeddings, user_embedding_static], dim=1) # SAVE CURRENT MODEL TO DISK TO BE USED IN EVALUATION. @@ -210,6 +210,6 @@ item_embeddings = initial_item_embedding.repeat(num_items, 1) # END OF ALL EPOCHS. SAVE FINAL MODEL DISK TO BE USED IN EVALUATION. -print "\n\n*** Training complete. Saving final model. ***\n\n" +print("\n\n*** Training complete. Saving final model. ***\n\n") save_model(model, optimizer, args, ep, user_embeddings_dystat, item_embeddings_dystat, train_end_idx, user_embeddings_timeseries, item_embeddings_timeseries) diff --git a/library_data.py b/library_data.py index 13c18b7..842df88 100644 --- a/library_data.py +++ b/library_data.py @@ -40,7 +40,7 @@ def load_network(args, time_scaling=True): start_timestamp = None y_true_labels = [] - print "\n\n**** Loading %s network from file: %s ****" % (network, datapath) + print("\n\n**** Loading %s network from file: %s ****" % (network, datapath)) f = open(datapath,"r") f.readline() for cnt, l in enumerate(f): @@ -59,7 +59,7 @@ def load_network(args, time_scaling=True): item_sequence = np.array(item_sequence) timestamp_sequence = np.array(timestamp_sequence) - print "Formating item sequence" + print("Formating item sequence") nodeid = 0 item2id = {} item_timedifference_sequence = [] @@ -74,7 +74,7 @@ def load_network(args, time_scaling=True): num_items = len(item2id) item_sequence_id = [item2id[item] for item in item_sequence] - print "Formating user sequence" + print("Formating user sequence") nodeid = 0 user2id = {} user_timedifference_sequence = [] @@ -94,11 +94,11 @@ def load_network(args, time_scaling=True): user_sequence_id = [user2id[user] for user in user_sequence] if time_scaling: - print "Scaling timestamps" + print("Scaling timestamps") user_timedifference_sequence = scale(np.array(user_timedifference_sequence) + 1) item_timedifference_sequence = scale(np.array(item_timedifference_sequence) + 1) - print "*** Network loading completed ***\n\n" + print("*** Network loading completed ***\n\n") return [user2id, user_sequence_id, user_timedifference_sequence, user_previous_itemid_sequence, \ item2id, item_sequence_id, item_timedifference_sequence, \ timestamp_sequence, \ diff --git a/library_models.py b/library_models.py index 9c7fd7f..165bb5e 100644 --- a/library_models.py +++ b/library_models.py @@ -46,7 +46,7 @@ class JODIE(nn.Module): def __init__(self, args, num_features, num_users, num_items): super(JODIE,self).__init__() - print "*** Initializing the JODIE model ***" + print("*** Initializing the JODIE model ***") self.modelname = args.model self.embedding_dim = args.embedding_dim self.num_users = num_users @@ -54,22 +54,22 @@ def __init__(self, args, num_features, num_users, num_items): self.user_static_embedding_size = num_users self.item_static_embedding_size = num_items - print "Initializing user and item embeddings" + print("Initializing user and item embeddings") self.initial_user_embedding = nn.Parameter(torch.Tensor(args.embedding_dim)) self.initial_item_embedding = nn.Parameter(torch.Tensor(args.embedding_dim)) rnn_input_size_items = rnn_input_size_users = self.embedding_dim + 1 + num_features - print "Initializing user and item RNNs" + print("Initializing user and item RNNs") self.item_rnn = nn.RNNCell(rnn_input_size_users, self.embedding_dim) self.user_rnn = nn.RNNCell(rnn_input_size_items, self.embedding_dim) - print "Initializing linear layers" + print("Initializing linear layers") self.linear_layer1 = nn.Linear(self.embedding_dim, 50) self.linear_layer2 = nn.Linear(50, 2) self.prediction_layer = nn.Linear(self.user_static_embedding_size + self.item_static_embedding_size + self.embedding_dim * 2, self.item_static_embedding_size + self.embedding_dim) self.embedding_layer = NormalLinear(1, self.embedding_dim) - print "*** JODIE initialization complete ***\n\n" + print("*** JODIE initialization complete ***\n\n") def forward(self, user_embeddings, item_embeddings, timediffs=None, features=None, select=None): if select == 'item_update': @@ -141,7 +141,7 @@ def calculate_state_prediction_loss(model, tbatch_interactionids, user_embedding # SAVE TRAINED MODEL TO DISK def save_model(model, optimizer, args, epoch, user_embeddings, item_embeddings, train_end_idx, user_embeddings_time_series=None, item_embeddings_time_series=None, path=PATH): - print "*** Saving embeddings and model ***" + print("*** Saving embeddings and model ***") state = { 'user_embeddings': user_embeddings.data.cpu().numpy(), 'item_embeddings': item_embeddings.data.cpu().numpy(), @@ -161,7 +161,7 @@ def save_model(model, optimizer, args, epoch, user_embeddings, item_embeddings, filename = os.path.join(directory, "checkpoint.%s.ep%d.tp%.1f.pth.tar" % (args.model, epoch, args.train_proportion)) torch.save(state, filename) - print "*** Saved embeddings and model to file: %s ***\n\n" % filename + print("*** Saved embeddings and model to file: %s ***\n\n" % filename) # LOAD PREVIOUSLY TRAINED AND SAVED MODEL @@ -169,7 +169,7 @@ def load_model(model, optimizer, args, epoch): modelname = args.model filename = PATH + "saved_models/%s/checkpoint.%s.ep%d.tp%.1f.pth.tar" % (args.network, modelname, epoch, args.train_proportion) checkpoint = torch.load(filename) - print "Loading saved embeddings and model: %s" % filename + print("Loading saved embeddings and model: %s" % filename) args.start_epoch = checkpoint['epoch'] user_embeddings = Variable(torch.from_numpy(checkpoint['user_embeddings']).cuda()) item_embeddings = Variable(torch.from_numpy(checkpoint['item_embeddings']).cuda()) diff --git a/tbatch.py b/tbatch.py index 9d2ce6a..4b947ee 100644 --- a/tbatch.py +++ b/tbatch.py @@ -28,7 +28,7 @@ num_items = len(item2id) + 1 # one extra item for "none-of-these" num_features = len(feature_sequence[0]) true_labels_ratio = len(y_true)/(1.0+sum(y_true)) # +1 in denominator in case there are no state change labels, which will throw an error. -print "*** Network statistics:\n %d users\n %d items\n %d interactions\n %d/%d true labels ***\n\n" % (num_users, num_items, num_interactions, sum(y_true), len(y_true)) +print("*** Network statistics:\n %d users\n %d items\n %d interactions\n %d/%d true labels ***\n\n" % (num_users, num_items, num_interactions, sum(y_true), len(y_true))) # OUTPUT FILE FOR THE BATCHES output_fname = "results/batches_%s.txt" % args.network @@ -52,7 +52,7 @@ tbatch_timespan = timespan / 500 # CREATE THE TBATCHES -print "*** Creating T-batches from %d interactions ***" % train_end_idx +print("*** Creating T-batches from %d interactions ***" % train_end_idx) # INITIALIZE TBATCH PARAMETERS tbatch_start_time = None tbatch_to_insert = -1 @@ -98,7 +98,7 @@ # AFTER PROCESSING ALL INTERACTIONS IN A TIMESPAN if timestamp - tbatch_start_time > tbatch_timespan or j == num_interactions - 1: # AFTER ALL INTERACTIONS IN THE TIME WINDOW ARE CONVERTED TO T-BATCHES, SAVE THEM TO FILE. - print 'Read till interaction %d. This timespan had %d interactions and created %d T-batches.' % (j, tbatch_interaction_count, len(lib.current_tbatches_user)) + print('Read till interaction %d. This timespan had %d interactions and created %d T-batches.' % (j, tbatch_interaction_count, len(lib.current_tbatches_user))) total_tbatches_count += len(lib.current_tbatches_user) total_interactions_count += tbatch_interaction_count @@ -127,8 +127,8 @@ tbatch_to_insert = -1 fout.close() -print "=======================" -print "T-batching complete. Output file: %s." % output_fname -print "%d interactions were processed, which created %d t-batches." % (total_interactions_count, total_tbatches_count) -print "This is a %.3f%% compression." % ((total_interactions_count - total_tbatches_count)*100.0/total_interactions_count) -print "=======================" +print("=======================") +print("T-batching complete. Output file: %s." % output_fname) +print("%d interactions were processed, which created %d t-batches." % (total_interactions_count, total_tbatches_count)) +print("This is a %.3f%% compression." % ((total_interactions_count - total_tbatches_count)*100.0/total_interactions_count)) +print("=======================") From 0700918e97de6cd6b7327b1f38bb5e6a477a28b1 Mon Sep 17 00:00:00 2001 From: John Palowitch Date: Wed, 25 Mar 2020 16:40:34 -0700 Subject: [PATCH 2/4] Remove cPickle imports --- library_data.py | 1 - library_models.py | 1 - 2 files changed, 2 deletions(-) diff --git a/library_data.py b/library_data.py index 842df88..a8f48b0 100644 --- a/library_data.py +++ b/library_data.py @@ -12,7 +12,6 @@ import copy from collections import defaultdict import os, re -import cPickle import argparse from sklearn.preprocessing import scale diff --git a/library_models.py b/library_models.py index 165bb5e..b51ab78 100644 --- a/library_models.py +++ b/library_models.py @@ -15,7 +15,6 @@ import sys from collections import defaultdict import os -import cPickle import gpustat from itertools import chain from tqdm import tqdm, trange, tqdm_notebook, tnrange From 77c1e8edc3eebccd7e92a1d6cd6df4fbc249e010 Mon Sep 17 00:00:00 2001 From: John Palowitch Date: Mon, 30 Mar 2020 17:15:10 -0700 Subject: [PATCH 3/4] Use list() to unroll yields --- library_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library_data.py b/library_data.py index a8f48b0..1574e40 100644 --- a/library_data.py +++ b/library_data.py @@ -51,7 +51,7 @@ def load_network(args, time_scaling=True): start_timestamp = float(ls[2]) timestamp_sequence.append(float(ls[2]) - start_timestamp) y_true_labels.append(int(ls[3])) # label = 1 at state change, 0 otherwise - feature_sequence.append(map(float,ls[4:])) + feature_sequence.append(list(map(float,ls[4:]))) f.close() user_sequence = np.array(user_sequence) From 18d341b794204cbbf7bc26ff8c46d56000ef3b53 Mon Sep 17 00:00:00 2001 From: John Palowitch Date: Mon, 30 Mar 2020 23:55:56 -0700 Subject: [PATCH 4/4] A start on caching tbatches --- jodie.py | 51 ++++++++++++++++++++++++++++++----------------- library_models.py | 34 +++++++++++++++++++++++++++++++ 2 files changed, 67 insertions(+), 18 deletions(-) diff --git a/jodie.py b/jodie.py index f1eee71..408be48 100644 --- a/jodie.py +++ b/jodie.py @@ -20,10 +20,11 @@ parser.add_argument('--epochs', default=50, type=int, help='Number of epochs to train the model') parser.add_argument('--embedding_dim', default=128, type=int, help='Number of dimensions of the dynamic embedding') parser.add_argument('--train_proportion', default=0.8, type=float, help='Fraction of interactions (from the beginning) that are used for training.The next 10% are used for validation and the next 10% for testing') -parser.add_argument('--state_change', default=True, type=bool, help='True if training with state change of users along with interaction prediction. False otherwise. By default, set to True.') +parser.add_argument('--state_change', default=True, type=bool, help='True if training with state change of users along with interaction prediction. False otherwise. By default, set to True.') +parser.add_argument('--cache_tbatches', default=False, type=bool, help='If True, will load pre-computed tbatches from data/${network}/tbatches. By default, set to False.') args = parser.parse_args() -args.datapath = "data/%s.csv" % args.network +args.datapath = "data/%s.csv" % args.network if args.train_proportion > 0.8: sys.exit('Training sequence proportion cannot be greater than 0.8.') @@ -44,6 +45,12 @@ true_labels_ratio = len(y_true)/(1.0+sum(y_true)) # +1 in denominator in case there are no state change labels, which will throw an error. print("*** Network statistics:\n %d users\n %d items\n %d interactions\n %d/%d true labels ***\n\n" % (num_users, num_items, num_interactions, sum(y_true), len(y_true))) +# ENSURE TBATCH DIRECTORY EXISTS IF NEEDED +if args.cache_tbatches: + args.tbatchdir = "data/%s_tbatches" % args.network + if not os.path.isdir(args.tbatchdir): + os.mkdir(args.tbatchdir) + # SET TRAINING, VALIDATION, TESTING, and TBATCH BOUNDARIES train_end_idx = validation_start_idx = int(num_interactions * args.train_proportion) test_start_idx = int(num_interactions * (args.train_proportion+0.1)) @@ -102,8 +109,8 @@ tbatch_to_insert = -1 tbatch_full = False - # TRAIN TILL THE END OF TRAINING INTERACTION IDX - with trange(train_end_idx) as progress_bar2: + # TRAIN TILL THE END OF TRAINING INTERACTION IDX + with trange(train_end_idx) as progress_bar2: for j in progress_bar2: progress_bar2.set_description('Processed %dth interactions' % j) @@ -114,18 +121,21 @@ user_timediff = user_timediffs_sequence[j] item_timediff = item_timediffs_sequence[j] - # CREATE T-BATCHES: ADD INTERACTION J TO THE CORRECT T-BATCH - tbatch_to_insert = max(lib.tbatchid_user[userid], lib.tbatchid_item[itemid]) + 1 - lib.tbatchid_user[userid] = tbatch_to_insert - lib.tbatchid_item[itemid] = tbatch_to_insert - - lib.current_tbatches_user[tbatch_to_insert].append(userid) - lib.current_tbatches_item[tbatch_to_insert].append(itemid) - lib.current_tbatches_feature[tbatch_to_insert].append(feature) - lib.current_tbatches_interactionids[tbatch_to_insert].append(j) - lib.current_tbatches_user_timediffs[tbatch_to_insert].append(user_timediff) - lib.current_tbatches_item_timediffs[tbatch_to_insert].append(item_timediff) - lib.current_tbatches_previous_item[tbatch_to_insert].append(user_previous_itemid_sequence[j]) + if not args.cache_tbatches or len(os.listdir(args.tbatchdir)) == 0: + # CREATE T-BATCHES: ADD INTERACTION J TO THE CORRECT T-BATCH + tbatch_to_insert = max(lib.tbatchid_user[userid], lib.tbatchid_item[itemid]) + 1 + lib.tbatchid_user[userid] = tbatch_to_insert + lib.tbatchid_item[itemid] = tbatch_to_insert + + lib.current_tbatches_user[tbatch_to_insert].append(userid) + lib.current_tbatches_item[tbatch_to_insert].append(itemid) + lib.current_tbatches_feature[tbatch_to_insert].append(feature) + lib.current_tbatches_interactionids[tbatch_to_insert].append(j) + lib.current_tbatches_user_timediffs[tbatch_to_insert].append(user_timediff) + lib.current_tbatches_item_timediffs[tbatch_to_insert].append(item_timediff) + lib.current_tbatches_previous_item[tbatch_to_insert].append(user_previous_itemid_sequence[j]) + if args.cache_tbatches and len(os.listdir(args.tbatchdir)) > 0: + lib.load_tbatches() timestamp = timestamp_sequence[j] if tbatch_start_time is None: @@ -196,8 +206,9 @@ user_embeddings_timeseries.detach_() # REINITIALIZE - reinitialize_tbatches() - tbatch_to_insert = -1 + if not args.cache_tbatches: + reinitialize_tbatches() + tbatch_to_insert = -1 # END OF ONE EPOCH print("\n\nTotal loss in this epoch = %f" % (total_loss)) @@ -206,6 +217,10 @@ # SAVE CURRENT MODEL TO DISK TO BE USED IN EVALUATION. save_model(model, optimizer, args, ep, user_embeddings_dystat, item_embeddings_dystat, train_end_idx, user_embeddings_timeseries, item_embeddings_timeseries) + # SAVE TBATCHES IF NECESSARY + if args.cache_tbatches and len(os.listdir(args.tbatchdir)) == 0: + lib.save_tbatches() + user_embeddings = initial_user_embedding.repeat(num_users, 1) item_embeddings = initial_item_embedding.repeat(num_items, 1) diff --git a/library_models.py b/library_models.py index b51ab78..7666e4b 100644 --- a/library_models.py +++ b/library_models.py @@ -19,6 +19,7 @@ from itertools import chain from tqdm import tqdm, trange, tqdm_notebook, tnrange import csv +import json PATH = "./" @@ -126,6 +127,39 @@ def reinitialize_tbatches(): global total_reinitialization_count total_reinitialization_count +=1 +# LOAD/SAVE TBATCHES +def save_tbatch_object(filename, tbatch): + with open(filename, 'w') as f: + f.write(json.dumps(tbatch)) + +def load_tbatch_object(filename, tbatch): + with open(filename) as f: + return json.loads(f.read()) + +def save_tbatches(dir): + save_tbatch_object(os.path.join(dir, "interactionids"), current_tbatches_interactionids) + save_tbatch_object(os.path.join(dir, "user"), current_tbatches_user) + save_tbatch_object(os.path.join(dir, "item"), current_tbatches_item) + save_tbatch_object(os.path.join(dir, "timestamp"), current_tbatches_timestamp) + save_tbatch_object(os.path.join(dir, "feature"), current_tbatches_feature) + save_tbatch_object(os.path.join(dir, "label"), current_tbatches_label) + save_tbatch_object(os.path.join(dir, "previous_item"), current_tbatches_previous_item) + save_tbatch_object(os.path.join(dir, "user_timediffs"), current_tbatches_user_timediffs) + save_tbatch_object(os.path.join(dir, "item_timediffs"), current_tbatches_item_timediffs) + save_tbatch_object(os.path.join(dir, "user_timediffs_next"), current_tbatches_user_timediffs_next) + +def load_tbatches(dir): + load_tbatch_object(os.path.join(dir, "interactionids"), current_tbatches_interactionids) + load_tbatch_object(os.path.join(dir, "user"), current_tbatches_user) + load_tbatch_object(os.path.join(dir, "item"), current_tbatches_item) + load_tbatch_object(os.path.join(dir, "timestamp"), current_tbatches_timestamp) + load_tbatch_object(os.path.join(dir, "feature"), current_tbatches_feature) + load_tbatch_object(os.path.join(dir, "label"), current_tbatches_label) + load_tbatch_object(os.path.join(dir, "previous_item"), current_tbatches_previous_item) + load_tbatch_object(os.path.join(dir, "user_timediffs"), current_tbatches_user_timediffs) + load_tbatch_object(os.path.join(dir, "item_timediffs"), current_tbatches_item_timediffs) + load_tbatch_object(os.path.join(dir, "user_timediffs_next"), current_tbatches_user_timediffs_next) + # CALCULATE LOSS FOR THE PREDICTED USER STATE def calculate_state_prediction_loss(model, tbatch_interactionids, user_embeddings_time_series, y_true, loss_function):