Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Python 3 compatibility and t-batch caching. #9

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions evaluate_interaction_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
if args.train_proportion > 0.8:
sys.exit('Training sequence proportion cannot be greater than 0.8.')
if args.network == "mooc":
print "No interaction prediction for %s" % args.network
print("No interaction prediction for %s" % args.network)
sys.exit(0)

# SET GPU
Expand All @@ -43,7 +43,7 @@
for l in f:
l = l.strip()
if search_string in l:
print "Output file already has results of epoch %d" % args.epoch
print("Output file already has results of epoch %d" % args.epoch)
sys.exit(0)
f.close()

Expand All @@ -58,7 +58,7 @@
num_users = len(user2id)
num_items = len(item2id) + 1
true_labels_ratio = len(y_true)/(sum(y_true)+1)
print "*** Network statistics:\n %d users\n %d items\n %d interactions\n %d/%d true labels ***\n\n" % (num_users, num_items, num_interactions, sum(y_true), len(y_true))
print("*** Network statistics:\n %d users\n %d items\n %d interactions\n %d/%d true labels ***\n\n" % (num_users, num_items, num_interactions, sum(y_true), len(y_true)))

# SET TRAIN, VALIDATION, AND TEST BOUNDARIES
train_end_idx = validation_start_idx = int(num_interactions * args.train_proportion)
Expand Down Expand Up @@ -121,7 +121,7 @@
tbatch_start_time = None
loss = 0
# FORWARD PASS
print "*** Making interaction predictions by forward pass (no t-batching) ***"
print("*** Making interaction predictions by forward pass (no t-batching) ***")
with trange(train_end_idx, test_end_idx) as progress_bar:
for j in progress_bar:
progress_bar.set_description('%dth interaction for validation and testing' % j)
Expand Down Expand Up @@ -218,13 +218,13 @@
fw = open(output_fname, "a")
metrics = ['Mean Reciprocal Rank', 'Recall@10']

print '\n\n*** Validation performance of epoch %d ***' % args.epoch
print('\n\n*** Validation performance of epoch %d ***' % args.epoch)
fw.write('\n\n*** Validation performance of epoch %d ***\n' % args.epoch)
for i in xrange(len(metrics)):
print(metrics[i] + ': ' + str(performance_dict['validation'][i]))
fw.write("Validation: " + metrics[i] + ': ' + str(performance_dict['validation'][i]) + "\n")

print '\n\n*** Test performance of epoch %d ***' % args.epoch
print('\n\n*** Test performance of epoch %d ***' % args.epoch)
fw.write('\n\n*** Test performance of epoch %d ***\n' % args.epoch)
for i in xrange(len(metrics)):
print(metrics[i] + ': ' + str(performance_dict['test'][i]))
Expand Down
12 changes: 6 additions & 6 deletions evaluate_state_change_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
if args.train_proportion > 0.8:
sys.exit('Training sequence proportion cannot be greater than 0.8.')
if args.network == "lastfm":
print "No state change prediction for %s" % args.network
print("No state change prediction for %s" % args.network)
sys.exit(0)

# SET GPU
Expand All @@ -44,7 +44,7 @@
for l in f:
l = l.strip()
if search_string in l:
print "Output file already has results of epoch %d" % args.epoch
print("Output file already has results of epoch %d" % args.epoch)
sys.exit(0)
f.close()

Expand All @@ -59,7 +59,7 @@
num_users = len(user2id)
num_items = len(item2id) + 1
true_labels_ratio = len(y_true)/(sum(y_true)+1)
print "*** Network statistics:\n %d users\n %d items\n %d interactions\n %d/%d true labels ***\n\n" % (num_users, num_items, num_interactions, sum(y_true), len(y_true))
print("*** Network statistics:\n %d users\n %d items\n %d interactions\n %d/%d true labels ***\n\n" % (num_users, num_items, num_interactions, sum(y_true), len(y_true)))

# SET TRAIN, VALIDATION, AND TEST BOUNDARIES
train_end_idx = validation_start_idx = int(num_interactions * args.train_proportion)
Expand Down Expand Up @@ -124,7 +124,7 @@
tbatch_start_time = None
loss = 0
# FORWARD PASS
print "*** Making state change predictions by forward pass (no t-batching) ***"
print("*** Making state change predictions by forward pass (no t-batching) ***")
with trange(train_end_idx, test_end_idx) as progress_bar:
for j in progress_bar:
progress_bar.set_description('%dth interaction for validation and testing' % j)
Expand Down Expand Up @@ -218,14 +218,14 @@
fw = open(output_fname, "a")
metrics = ['AUC']

print '\n\n*** Validation performance of epoch %d ***' % args.epoch
print('\n\n*** Validation performance of epoch %d ***' % args.epoch)
fw.write('\n\n*** Validation performance of epoch %d ***\n' % args.epoch)

for i in xrange(len(metrics)):
print(metrics[i] + ': ' + str(performance_dict['validation'][i]))
fw.write("Validation: " + metrics[i] + ': ' + str(performance_dict['validation'][i]) + "\n")

print '\n\n*** Test performance of epoch %d ***' % args.epoch
print('\n\n*** Test performance of epoch %d ***' % args.epoch)
fw.write('\n\n*** Test performance of epoch %d ***\n' % args.epoch)
for i in xrange(len(metrics)):
print(metrics[i] + ': ' + str(performance_dict['test'][i]))
Expand Down
8 changes: 4 additions & 4 deletions get_final_performance_numbers.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,13 @@
else:
metrics = ['AUC']

print '\n\n*** For file: %s ***' % fname
print('\n\n*** For file: %s ***' % fname)
best_val_idx = np.argmax(validation_performances[:,1])
print "Best validation epoch: %d" % best_val_idx
print '\n\n*** Best validation performance (epoch %d) ***' % best_val_idx
print("Best validation epoch: %d" % best_val_idx)
print('\n\n*** Best validation performance (epoch %d) ***' % best_val_idx)
for i in xrange(len(metrics)):
print(metrics[i] + ': ' + str(validation_performances[best_val_idx][i+1]))

print '\n\n*** Final model performance on the test set, i.e., in epoch %d ***' % best_val_idx
print('\n\n*** Final model performance on the test set, i.e., in epoch %d ***' % best_val_idx)
for i in xrange(len(metrics)):
print(metrics[i] + ': ' + str(test_performances[best_val_idx][i+1]))
59 changes: 37 additions & 22 deletions jodie.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@
parser.add_argument('--epochs', default=50, type=int, help='Number of epochs to train the model')
parser.add_argument('--embedding_dim', default=128, type=int, help='Number of dimensions of the dynamic embedding')
parser.add_argument('--train_proportion', default=0.8, type=float, help='Fraction of interactions (from the beginning) that are used for training.The next 10% are used for validation and the next 10% for testing')
parser.add_argument('--state_change', default=True, type=bool, help='True if training with state change of users along with interaction prediction. False otherwise. By default, set to True.')
parser.add_argument('--state_change', default=True, type=bool, help='True if training with state change of users along with interaction prediction. False otherwise. By default, set to True.')
parser.add_argument('--cache_tbatches', default=False, type=bool, help='If True, will load pre-computed tbatches from data/${network}/tbatches. By default, set to False.')
args = parser.parse_args()

args.datapath = "data/%s.csv" % args.network
args.datapath = "data/%s.csv" % args.network
if args.train_proportion > 0.8:
sys.exit('Training sequence proportion cannot be greater than 0.8.')

Expand All @@ -42,7 +43,13 @@
num_items = len(item2id) + 1 # one extra item for "none-of-these"
num_features = len(feature_sequence[0])
true_labels_ratio = len(y_true)/(1.0+sum(y_true)) # +1 in denominator in case there are no state change labels, which will throw an error.
print "*** Network statistics:\n %d users\n %d items\n %d interactions\n %d/%d true labels ***\n\n" % (num_users, num_items, num_interactions, sum(y_true), len(y_true))
print("*** Network statistics:\n %d users\n %d items\n %d interactions\n %d/%d true labels ***\n\n" % (num_users, num_items, num_interactions, sum(y_true), len(y_true)))

# ENSURE TBATCH DIRECTORY EXISTS IF NEEDED
if args.cache_tbatches:
args.tbatchdir = "data/%s_tbatches" % args.network
if not os.path.isdir(args.tbatchdir):
os.mkdir(args.tbatchdir)

# SET TRAINING, VALIDATION, TESTING, and TBATCH BOUNDARIES
train_end_idx = validation_start_idx = int(num_interactions * args.train_proportion)
Expand Down Expand Up @@ -85,7 +92,7 @@
'''
THE MODEL IS TRAINED FOR SEVERAL EPOCHS. IN EACH EPOCH, JODIES USES THE TRAINING SET OF INTERACTIONS TO UPDATE ITS PARAMETERS.
'''
print "*** Training the JODIE model for %d epochs ***" % args.epochs
print("*** Training the JODIE model for %d epochs ***" % args.epochs)
with trange(args.epochs) as progress_bar1:
for ep in progress_bar1:
progress_bar1.set_description('Epoch %d of %d' % (ep, args.epochs))
Expand All @@ -102,8 +109,8 @@
tbatch_to_insert = -1
tbatch_full = False

# TRAIN TILL THE END OF TRAINING INTERACTION IDX
with trange(train_end_idx) as progress_bar2:
# TRAIN TILL THE END OF TRAINING INTERACTION IDX
with trange(train_end_idx) as progress_bar2:
for j in progress_bar2:
progress_bar2.set_description('Processed %dth interactions' % j)

Expand All @@ -114,18 +121,21 @@
user_timediff = user_timediffs_sequence[j]
item_timediff = item_timediffs_sequence[j]

# CREATE T-BATCHES: ADD INTERACTION J TO THE CORRECT T-BATCH
tbatch_to_insert = max(lib.tbatchid_user[userid], lib.tbatchid_item[itemid]) + 1
lib.tbatchid_user[userid] = tbatch_to_insert
lib.tbatchid_item[itemid] = tbatch_to_insert

lib.current_tbatches_user[tbatch_to_insert].append(userid)
lib.current_tbatches_item[tbatch_to_insert].append(itemid)
lib.current_tbatches_feature[tbatch_to_insert].append(feature)
lib.current_tbatches_interactionids[tbatch_to_insert].append(j)
lib.current_tbatches_user_timediffs[tbatch_to_insert].append(user_timediff)
lib.current_tbatches_item_timediffs[tbatch_to_insert].append(item_timediff)
lib.current_tbatches_previous_item[tbatch_to_insert].append(user_previous_itemid_sequence[j])
if not args.cache_tbatches or len(os.listdir(args.tbatchdir)) == 0:
# CREATE T-BATCHES: ADD INTERACTION J TO THE CORRECT T-BATCH
tbatch_to_insert = max(lib.tbatchid_user[userid], lib.tbatchid_item[itemid]) + 1
lib.tbatchid_user[userid] = tbatch_to_insert
lib.tbatchid_item[itemid] = tbatch_to_insert

lib.current_tbatches_user[tbatch_to_insert].append(userid)
lib.current_tbatches_item[tbatch_to_insert].append(itemid)
lib.current_tbatches_feature[tbatch_to_insert].append(feature)
lib.current_tbatches_interactionids[tbatch_to_insert].append(j)
lib.current_tbatches_user_timediffs[tbatch_to_insert].append(user_timediff)
lib.current_tbatches_item_timediffs[tbatch_to_insert].append(item_timediff)
lib.current_tbatches_previous_item[tbatch_to_insert].append(user_previous_itemid_sequence[j])
if args.cache_tbatches and len(os.listdir(args.tbatchdir)) > 0:
lib.load_tbatches()

timestamp = timestamp_sequence[j]
if tbatch_start_time is None:
Expand Down Expand Up @@ -196,20 +206,25 @@
user_embeddings_timeseries.detach_()

# REINITIALIZE
reinitialize_tbatches()
tbatch_to_insert = -1
if not args.cache_tbatches:
reinitialize_tbatches()
tbatch_to_insert = -1

# END OF ONE EPOCH
print "\n\nTotal loss in this epoch = %f" % (total_loss)
print("\n\nTotal loss in this epoch = %f" % (total_loss))
item_embeddings_dystat = torch.cat([item_embeddings, item_embedding_static], dim=1)
user_embeddings_dystat = torch.cat([user_embeddings, user_embedding_static], dim=1)
# SAVE CURRENT MODEL TO DISK TO BE USED IN EVALUATION.
save_model(model, optimizer, args, ep, user_embeddings_dystat, item_embeddings_dystat, train_end_idx, user_embeddings_timeseries, item_embeddings_timeseries)

# SAVE TBATCHES IF NECESSARY
if args.cache_tbatches and len(os.listdir(args.tbatchdir)) == 0:
lib.save_tbatches()

user_embeddings = initial_user_embedding.repeat(num_users, 1)
item_embeddings = initial_item_embedding.repeat(num_items, 1)

# END OF ALL EPOCHS. SAVE FINAL MODEL DISK TO BE USED IN EVALUATION.
print "\n\n*** Training complete. Saving final model. ***\n\n"
print("\n\n*** Training complete. Saving final model. ***\n\n")
save_model(model, optimizer, args, ep, user_embeddings_dystat, item_embeddings_dystat, train_end_idx, user_embeddings_timeseries, item_embeddings_timeseries)

13 changes: 6 additions & 7 deletions library_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import copy
from collections import defaultdict
import os, re
import cPickle
import argparse
from sklearn.preprocessing import scale

Expand Down Expand Up @@ -40,7 +39,7 @@ def load_network(args, time_scaling=True):
start_timestamp = None
y_true_labels = []

print "\n\n**** Loading %s network from file: %s ****" % (network, datapath)
print("\n\n**** Loading %s network from file: %s ****" % (network, datapath))
f = open(datapath,"r")
f.readline()
for cnt, l in enumerate(f):
Expand All @@ -52,14 +51,14 @@ def load_network(args, time_scaling=True):
start_timestamp = float(ls[2])
timestamp_sequence.append(float(ls[2]) - start_timestamp)
y_true_labels.append(int(ls[3])) # label = 1 at state change, 0 otherwise
feature_sequence.append(map(float,ls[4:]))
feature_sequence.append(list(map(float,ls[4:])))
f.close()

user_sequence = np.array(user_sequence)
item_sequence = np.array(item_sequence)
timestamp_sequence = np.array(timestamp_sequence)

print "Formating item sequence"
print("Formating item sequence")
nodeid = 0
item2id = {}
item_timedifference_sequence = []
Expand All @@ -74,7 +73,7 @@ def load_network(args, time_scaling=True):
num_items = len(item2id)
item_sequence_id = [item2id[item] for item in item_sequence]

print "Formating user sequence"
print("Formating user sequence")
nodeid = 0
user2id = {}
user_timedifference_sequence = []
Expand All @@ -94,11 +93,11 @@ def load_network(args, time_scaling=True):
user_sequence_id = [user2id[user] for user in user_sequence]

if time_scaling:
print "Scaling timestamps"
print("Scaling timestamps")
user_timedifference_sequence = scale(np.array(user_timedifference_sequence) + 1)
item_timedifference_sequence = scale(np.array(item_timedifference_sequence) + 1)

print "*** Network loading completed ***\n\n"
print("*** Network loading completed ***\n\n")
return [user2id, user_sequence_id, user_timedifference_sequence, user_previous_itemid_sequence, \
item2id, item_sequence_id, item_timedifference_sequence, \
timestamp_sequence, \
Expand Down
Loading