forked from v-iashin/video_features
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmain.py
71 lines (55 loc) · 2.42 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
from omegaconf import OmegaConf
from tqdm import tqdm
from pathlib import Path
import torch
from utils.utils import build_cfg_path, form_list_from_user_input, sanity_check
# Default number of intra-op CPU threads handed to torch; override with the
# `num_threads` key (e.g. `num_threads=8` on the command line).
DEFAULT_NUM_THREADS = 16

# Maps each supported `feature_type` to the (module path, class name) of its
# extractor. Modules are imported lazily in `_load_extractor_class` because the
# project uses two conda environments: importing every extractor up front could
# fail with an ImportError in whichever environment lacks a model's dependencies.
FEATURE_EXTRACTORS = {
    "i3d": ("models.i3d.extract_i3d", "ExtractI3D"),
    "r21d": ("models.r21d.extract_r21d", "ExtractR21D"),
    "s3d": ("models.s3d.extract_s3d", "ExtractS3D"),
    "vggish": ("models.vggish.extract_vggish", "ExtractVGGish"),
    "resnet": ("models.resnet.extract_resnet", "ExtractResNet"),
    "raft": ("models.raft.extract_raft", "ExtractRAFT"),
    "pwc": ("models.pwc.extract_pwc", "ExtractPWC"),
    "clip": ("models.clip.extract_clip", "ExtractCLIP"),
}


def _load_extractor_class(feature_type):
    """Import and return the extractor class registered for *feature_type*.

    Raises:
        NotImplementedError: if *feature_type* has no entry in FEATURE_EXTRACTORS.
    """
    import importlib  # local: only needed for the lazy extractor import

    try:
        module_name, class_name = FEATURE_EXTRACTORS[feature_type]
    except KeyError:
        raise NotImplementedError(
            f"Extractor {feature_type} is not implemented."
        ) from None
    return getattr(importlib.import_module(module_name), class_name)


def main(args_cli):
    """Run feature extraction over every video the user specified.

    The YAML config for ``args_cli.feature_type`` is loaded and merged with the
    command-line config (CLI values take priority), validated with
    ``sanity_check``, and then the matching extractor is run on each video path.

    Args:
        args_cli: OmegaConf config built from the command line. It must contain
            ``feature_type``. An optional ``num_threads`` key sets the number of
            torch CPU threads (default ``DEFAULT_NUM_THREADS``).

    Raises:
        NotImplementedError: if ``feature_type`` has no registered extractor.
    """
    # config: the CLI config is merged last, so its values win
    args_yml = OmegaConf.load(build_cfg_path(args_cli.feature_type))
    args = OmegaConf.merge(args_yml, args_cli)

    torch.set_num_threads(args.get("num_threads", DEFAULT_NUM_THREADS))
    sanity_check(args)

    # verbose output via print (TODO: switch to logging)
    print(OmegaConf.to_yaml(args))
    if args.on_extraction in ("save_numpy", "save_pickle"):
        print(f"Saving features to {args.output_path}")
    print("Device:", args.device)

    extractor = _load_extractor_class(args.feature_type)(args)

    # Unify whatever the user specified as paths into a list of paths.
    # NOTE(review): ignore_existing=True together with output_path presumably
    # skips videos whose features already exist; confirm in
    # utils.form_list_from_user_input.
    video_paths = form_list_from_user_input(
        args.video_paths,
        args.file_with_video_paths,
        args.output_path,
        to_shuffle=False,
        ignore_existing=True,
    )
    print(f"The number of specified videos: {len(video_paths)}")

    for video_path in tqdm(video_paths):
        extractor._extract(video_path)  # note the `_` in the method name
if __name__ == "__main__":
    # Build the config from `key=value` command-line arguments and run main.
    main(OmegaConf.from_cli())