Skip to content

Commit

Permalink
sped up and updated the test
Browse files Browse the repository at this point in the history
  • Loading branch information
stewarthe6 committed Dec 16, 2024
1 parent d1b25d8 commit 74c1953
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,10 @@
"system": "LC",
"transformers": "True",
"model_type": "NN",
"featurizer": "computed_descriptors",
"descriptor_type": "rdkit_raw",
"featurizer": "ecfp",
"weight_transform_type": "balancing",
"learning_rate": ".0007",
"layer_sizes": "512,128",
"layer_sizes": "20,10",
"dropouts": "0.3,0.3",
"save_results": "False",
"max_epochs": "2",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
"model_type": "NN",
"featurizer": "ecfp",
"learning_rate": ".0007",
"layer_sizes": "512,128",
"layer_sizes": "20,10",
"dropouts": "0.3,0.3",
"save_results": "False",
"max_epochs": "2",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,10 +93,11 @@ def make_pipeline(params):

def make_pipeline_and_get_weights(params):
model_pipeline = make_pipeline(params)
model_wrapper = model_pipeline.model_wrapper
train_dataset = model_pipeline.data.train_valid_dsets[0][0]
transformed_data = model_wrapper.transform_dataset(train_dataset, fold=0)

print(model_pipeline.model_wrapper.transformers_w)
print(np.unique(model_pipeline.data.train_valid_dsets[0][0].y, return_counts=True))
return model_pipeline.data.train_valid_dsets[0][0].w
return transformed_data.w

def make_relative_to_file(relative_path):
script_path = os.path.dirname(os.path.realpath(__file__))
Expand Down Expand Up @@ -131,5 +132,5 @@ def params_w_balan(dset_key, res_dir):
return params

if __name__ == '__main__':
test_all_transformers()
#test_balancing_transformer()
#test_all_transformers()
test_balancing_transformer()

0 comments on commit 74c1953

Please sign in to comment.