-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmain.py
59 lines (45 loc) · 2.16 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from MMNDB.Model.pipeline import Pipeline
from MMNDB.Data.data import CustomCocoDataset
from MMNDB.Utils.utils import write_report, precision_recall_f1, macro_metrics, write_report_processor
import pytorch_lightning as pl
import hydra
import logging
# Module-level logger for this entry-point script; main() reports progress through it.
logger = logging.getLogger(__name__)
@hydra.main(config_path="conf", config_name="config")
def main(opt):
    """Run the pipeline over one or all object classes and write metric reports.

    ``opt`` is the configuration composed by Hydra from ``conf/config``.
    Fields read here:

    * ``opt.seed`` -- seed passed to ``pl.seed_everything``.
    * ``opt.data.data_path`` / ``opt.data.split`` -- dataset location and split
      handed to ``CustomCocoDataset``.
    * ``opt.data.obj_id`` -- ``0`` evaluates every object id found in the
      dataset's ``dict_obj_id_name``; any other value restricts the run to
      that single object id.
    * ``opt.task.task_type`` -- forwarded to ``Pipeline.pipeline`` as ``mode``.

    For each object id, ``pipe.pipeline`` returns a
    ``(processor_scores, retriever_scores)`` pair. Either element may be
    ``None``; ``None`` entries are excluded from aggregation. Retriever scores
    are reduced with ``precision_recall_f1`` and written via ``write_report``.
    Processor scores are reduced with ``macro_metrics`` and written via
    ``write_report_processor``. Each report is written only if at least one
    non-``None`` score of that kind was produced.
    """
    pl.seed_everything(opt.seed)
    data = CustomCocoDataset(path=opt.data.data_path, split=opt.data.split)
    pipe = Pipeline(config=opt)

    # obj_id == 0 is the sentinel for "evaluate every object class".
    evaluate_all = opt.data.obj_id == 0
    object_keys = (
        list(data.dict_obj_id_name) if evaluate_all else [opt.data.obj_id]
    )

    retriever_obj_metrics = []
    processor_obj_metrics = []
    for obj_id in object_keys:
        processor_scores, retriever_scores = pipe.pipeline(
            obj_id, mode=opt.task.task_type
        )
        # Drop missing results up front instead of filtering afterwards.
        if processor_scores is not None:
            processor_obj_metrics.append(processor_scores)
        if retriever_scores is not None:
            retriever_obj_metrics.append(retriever_scores)
        # NOTE(review): this message says "Processor mode only" regardless of
        # opt.task.task_type -- confirm whether it should reflect the mode.
        logger.info("Processor mode only, OBJECT ID: %s", obj_id)

    if retriever_obj_metrics:
        metrics_macro = precision_recall_f1(retriever_obj_metrics, mode="macro")
        metrics_micro = precision_recall_f1(retriever_obj_metrics, mode="micro")
        # Bootstrapped metrics are only computed when every object class was run.
        bootstrapped_metrics = (
            precision_recall_f1(retriever_obj_metrics, mode="bootstrap")
            if evaluate_all
            else None
        )
        write_report(
            pipe.model_name, opt, metrics_micro, metrics_macro, bootstrapped_metrics
        )

    if processor_obj_metrics:
        processor_macro_metrics = macro_metrics(processor_obj_metrics)
        write_report_processor(pipe.model_name, opt, processor_macro_metrics)

    logger.info("Broken classes: %s", pipe.retriever.broken_classes)
    logger.info("Done!")
if __name__ == "__main__":
    # No arguments: the @hydra.main decorator on main() builds and injects `opt`.
    main()