Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dir mixup #4

Open
wants to merge 117 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
117 commits
Select commit Hold shift + click to select a range
23df5b5
Debug for windows
seanyeo300 Apr 5, 2024
5d544de
Create new model and training files
seanyeo300 Apr 8, 2024
ab68bed
Added script to save logits
seanyeo300 Apr 8, 2024
fead781
Upload Checkpoint
seanyeo300 Apr 8, 2024
4104761
Create .gitignore
seanyeo300 Apr 9, 2024
50f51ab
ckpt files
seanyeo300 Apr 9, 2024
452a519
Delete dcase24.py
seanyeo300 Apr 9, 2024
f5e25e9
Update models
seanyeo300 Apr 9, 2024
e28709e
Update .gitignore
seanyeo300 Apr 9, 2024
c752444
Resolving version control
seanyeo300 Apr 9, 2024
670303a
Models and helpers
seanyeo300 Apr 9, 2024
baae2dd
Merge branch 'main' of https://github.com/seanyeo300/dcase2024_task1
seanyeo300 Apr 9, 2024
eb7b97b
creating dev scripts
seanyeo300 Apr 9, 2024
66dc199
Fixing version control
seanyeo300 Apr 9, 2024
a7b964a
Fixing version control
seanyeo300 Apr 9, 2024
d5ac480
Updated get_logits
seanyeo300 Apr 9, 2024
cbb0b08
Create test.ipynb
seanyeo300 Apr 9, 2024
a4b8ea4
version control
seanyeo300 Apr 9, 2024
6ba62d3
Merge branch 'main' of https://github.com/seanyeo300/dcase2024_task1
seanyeo300 Apr 9, 2024
900bb1c
Merge
seanyeo300 Apr 9, 2024
470ad84
Upload logits and focusnet script
seanyeo300 Apr 12, 2024
b5e78fd
Merge branch 'main' of https://github.com/seanyeo300/dcase2024_task1
seanyeo300 Apr 12, 2024
bb3d7cd
Update test.ipynb
seanyeo300 Apr 14, 2024
42d9205
Update runs
seanyeo300 Apr 15, 2024
a90253a
Merge branch 'main' of https://github.com/seanyeo300/dcase2024_task1
seanyeo300 Apr 15, 2024
650565b
creating no_aug script
seanyeo300 Apr 15, 2024
7d9be87
Merge branch 'main' of https://github.com/seanyeo300/dcase2024_task1
seanyeo300 Apr 15, 2024
26d7ef1
50 subset
seanyeo300 Apr 15, 2024
41f0d23
Merge branch 'main' of https://github.com/seanyeo300/dcase2024_task1
seanyeo300 Apr 15, 2024
d7945c3
Push ckpt files
seanyeo300 Apr 16, 2024
d9c6032
Merge branch 'main' of https://github.com/seanyeo300/dcase2024_task1
seanyeo300 Apr 16, 2024
9c68a1a
Updated model to get embeddings
seanyeo300 Apr 16, 2024
269deaf
Update run_training_metric.py
seanyeo300 Apr 16, 2024
d413675
DSP lab version of metric version. May conflict with WFH version
seanyeo300 Apr 16, 2024
1d4a7fa
focusnet subset 5 settings
seanyeo300 Apr 16, 2024
8668d86
update graphing script
seanyeo300 Apr 16, 2024
c2b54f8
Merge branch 'main' of https://github.com/seanyeo300/dcase2024_task1
seanyeo300 Apr 16, 2024
5874af2
Upload FocusNet ckpt
seanyeo300 Apr 16, 2024
b43538b
Update run_training_metric.py
seanyeo300 Apr 16, 2024
3be2382
Subset 10
seanyeo300 Apr 17, 2024
c9762ec
set eval model for dcase24_dev
seanyeo300 Apr 17, 2024
6f1d425
Merge branch 'main' of https://github.com/seanyeo300/dcase2024_task1
seanyeo300 Apr 17, 2024
4eb122a
run_training modified to use dcase24_dev.py
seanyeo300 Apr 17, 2024
f6358fa
Upload logits
seanyeo300 Apr 17, 2024
8247647
Removing embedding variables from non-metric scripts
seanyeo300 Apr 17, 2024
23b2039
Merge branch 'main' of https://github.com/seanyeo300/dcase2024_task1
seanyeo300 Apr 17, 2024
869e64c
Updating for logit retrieval
seanyeo300 Apr 17, 2024
e236a39
ckpt files
seanyeo300 Apr 18, 2024
11ea105
uploaded ckpts
seanyeo300 Apr 22, 2024
db30666
updated run_training to save best model, dev_logmel for logmel featur…
seanyeo300 Apr 23, 2024
11f8fe8
ckpt updates
seanyeo300 Apr 23, 2024
a2677e3
Merge branch 'main' of https://github.com/seanyeo300/dcase2024_task1
seanyeo300 Apr 23, 2024
86f8cfb
ckpt files and changing training arguments
seanyeo300 May 8, 2024
87edf05
Corrected FocusNet Imp
seanyeo300 May 8, 2024
0a6f03d
Further correction to FocusNet - softmax y_hat for entropy calc
seanyeo300 May 8, 2024
df8b5e5
update training, focusnet for scheduler scaling
seanyeo300 May 8, 2024
4e2569f
added callbacks for best model
seanyeo300 May 8, 2024
a9946ac
Merge branch 'main' of https://github.com/seanyeo300/dcase2024_task1
seanyeo300 May 8, 2024
a003921
Update run_training.py
seanyeo300 May 8, 2024
127935b
Update test.ipynb
seanyeo300 May 8, 2024
a04ec6a
Merge branch 'main' of https://github.com/seanyeo300/dcase2024_task1
seanyeo300 May 8, 2024
f9eec34
Update test.ipynb
seanyeo300 May 8, 2024
0a3e461
upload model ckpt
seanyeo300 May 9, 2024
9b962ec
Merge branch 'main' of https://github.com/seanyeo300/dcase2024_task1
seanyeo300 May 9, 2024
86bda06
Update test.ipynb
seanyeo300 May 9, 2024
49902fd
Updated for tuned hyperparameters
seanyeo300 May 9, 2024
b19ab5f
Merge branch 'main' of https://github.com/seanyeo300/dcase2024_task1
seanyeo300 May 9, 2024
3a250eb
Update test.ipynb
seanyeo300 May 9, 2024
8b1da23
Merge branch 'main' of https://github.com/seanyeo300/dcase2024_task1
seanyeo300 May 9, 2024
a63e3e6
created GRN
seanyeo300 May 10, 2024
6abef33
update test notebook
seanyeo300 May 10, 2024
fb595c1
Merge branch 'main' of https://github.com/seanyeo300/dcase2024_task1
seanyeo300 May 10, 2024
e42c7dd
Update run_training.py
seanyeo300 May 10, 2024
291f8bd
update grn script
seanyeo300 May 13, 2024
bee18ea
Update dcase24_dev_logmel.py
seanyeo300 May 17, 2024
2bb8189
Update get_logits.py
seanyeo300 May 17, 2024
679d610
Update run_training.py
seanyeo300 May 17, 2024
483231c
Update run_training_no_roll.py
seanyeo300 May 17, 2024
7193167
Update test.ipynb
seanyeo300 May 17, 2024
39a04a9
Added passt logits, updated local load and focusnet dev scripts. Adde…
seanyeo300 May 18, 2024
887f109
updating training scripts
seanyeo300 May 21, 2024
f8d52cf
Updated training scripts
seanyeo300 May 25, 2024
29d8f96
multi-run scripts
seanyeo300 May 25, 2024
4e48421
Update test loop checkpoint behavior
seanyeo300 May 26, 2024
dcbbe7f
grn training
seanyeo300 May 27, 2024
7e2ef24
Merge branch 'main' of https://github.com/seanyeo300/dcase2024_task1
seanyeo300 May 27, 2024
d0457b1
ckpt files
seanyeo300 May 27, 2024
6bb4e5a
Test scripts and half depth model
seanyeo300 May 28, 2024
65fbeb8
Update multi_run.py
seanyeo300 May 28, 2024
51bb3b8
Create baseline_half_depth.py
seanyeo300 May 28, 2024
82bc561
updated dev_script
seanyeo300 May 28, 2024
5e74f8c
DIR stuff
seanyeo300 May 28, 2024
34fcfdd
update half depth files
seanyeo300 May 28, 2024
0b82702
Update ensemble_logits.pt
seanyeo300 May 28, 2024
3fa7a00
run_training_half_depth_channel_exp
seanyeo300 May 29, 2024
b50e652
Merge branch 'main' of https://github.com/seanyeo300/dcase2024_task1
seanyeo300 May 29, 2024
cc4bef4
Name changes
seanyeo300 May 29, 2024
f69b5c6
Merge branch 'main' of https://github.com/seanyeo300/dcase2024_task1
seanyeo300 May 29, 2024
176a357
update training scripts
seanyeo300 May 30, 2024
eb0026b
half channel update
seanyeo300 May 30, 2024
e4d5ccd
mixup training script
seanyeo300 May 31, 2024
bf17a2d
channel scripts
seanyeo300 May 31, 2024
762202f
Update utils and mixup script
seanyeo300 May 31, 2024
1508083
Update test.ipynb
seanyeo300 May 31, 2024
2c276b2
Update test.ipynb
seanyeo300 May 31, 2024
f97afda
training script updates
seanyeo300 May 31, 2024
22cba9f
Merge branch 'main' of https://github.com/seanyeo300/dcase2024_task1
seanyeo300 May 31, 2024
17ba3d6
Update test.ipynb
seanyeo300 Jun 2, 2024
11d6cb1
Update test.ipynb
seanyeo300 Jun 2, 2024
b0472a5
upload baseline model checkpoints and training scripts
seanyeo300 Jun 3, 2024
80cbb9b
Merge branch 'main' of https://github.com/seanyeo300/dcase2024_task1
seanyeo300 Jun 3, 2024
556eb1d
Update run_training_KD_lam1.py
seanyeo300 Jun 3, 2024
aabf4e3
Merge branch 'main' of https://github.com/seanyeo300/dcase2024_task1
seanyeo300 Jun 3, 2024
a7e3f26
delete half depth
seanyeo300 Jun 3, 2024
bf32db6
Update multi_run.py
seanyeo300 Jun 3, 2024
825e503
update path handling for dcase_dev
seanyeo300 Jun 3, 2024
1b555b7
Merge branch 'main' of https://github.com/seanyeo300/dcase2024_task1
seanyeo300 Jun 3, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
dcase24.py
wandb/
*.pyc
Binary file added DCASE24_Task1/06fahajl/checkpoints/last.ckpt
Binary file not shown.
Binary file not shown.
Binary file added DCASE24_Task1/0xcis77x/checkpoints/last.ckpt
Binary file not shown.
Binary file not shown.
Binary file added DCASE24_Task1/1ea864zz/checkpoints/last.ckpt
Binary file not shown.
Binary file not shown.
Binary file added DCASE24_Task1/24r6wnps/checkpoints/last.ckpt
Binary file not shown.
Binary file not shown.
Binary file added DCASE24_Task1/25f9vshv/checkpoints/last.ckpt
Binary file not shown.
Binary file not shown.
Binary file added DCASE24_Task1/2bsxy6hd/checkpoints/last.ckpt
Binary file not shown.
Binary file not shown.
Binary file added DCASE24_Task1/2q7s0l6t/checkpoints/last.ckpt
Binary file not shown.
Binary file not shown.
Binary file added DCASE24_Task1/3lebte40/checkpoints/last.ckpt
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file added DCASE24_Task1/3stswvyj/checkpoints/last.ckpt
Binary file not shown.
Binary file not shown.
Binary file added DCASE24_Task1/3wwh507t/checkpoints/last.ckpt
Binary file not shown.
Binary file not shown.
Binary file added DCASE24_Task1/3xt2jqi2/checkpoints/last.ckpt
Binary file not shown.
Binary file not shown.
Binary file added DCASE24_Task1/4frm3efy/checkpoints/last.ckpt
Binary file not shown.
Binary file not shown.
Binary file added DCASE24_Task1/4mz0jocd/checkpoints/last.ckpt
Binary file not shown.
Binary file not shown.
Binary file added DCASE24_Task1/4uhextrk/checkpoints/last.ckpt
Binary file not shown.
Binary file not shown.
Binary file added DCASE24_Task1/5q8m4d7e/checkpoints/last.ckpt
Binary file not shown.
Binary file not shown.
Binary file added DCASE24_Task1/67g5uuam/checkpoints/last.ckpt
Binary file not shown.
Binary file not shown.
Binary file added DCASE24_Task1/6h2pizmx/checkpoints/last.ckpt
Binary file not shown.
Binary file not shown.
Binary file added DCASE24_Task1/6i4hqicm/checkpoints/last.ckpt
Binary file not shown.
Binary file not shown.
Binary file added DCASE24_Task1/6s1u3qb0/checkpoints/last.ckpt
Binary file not shown.
Binary file not shown.
Binary file added DCASE24_Task1/7ljkz6fx/checkpoints/last.ckpt
Binary file not shown.
Binary file not shown.
Binary file added DCASE24_Task1/904bkzkq/checkpoints/last.ckpt
Binary file not shown.
Binary file not shown.
Binary file added DCASE24_Task1/9hr1m441/checkpoints/last.ckpt
Binary file not shown.
Binary file not shown.
Binary file added DCASE24_Task1/9nuiltlt/checkpoints/last.ckpt
Binary file not shown.
Binary file not shown.
Binary file added DCASE24_Task1/anew3ofc/checkpoints/last.ckpt
Binary file not shown.
Binary file not shown.
Binary file added DCASE24_Task1/bu7agibx/checkpoints/last.ckpt
Binary file not shown.
Binary file not shown.
Binary file added DCASE24_Task1/cly8hv3k/checkpoints/last.ckpt
Binary file not shown.
Binary file not shown.
Binary file added DCASE24_Task1/ctbuh5hi/checkpoints/last.ckpt
Binary file not shown.
Binary file not shown.
Binary file added DCASE24_Task1/d8dscgh8/checkpoints/last.ckpt
Binary file not shown.
Binary file not shown.
Binary file added DCASE24_Task1/diweggfu/checkpoints/last.ckpt
Binary file not shown.
Binary file not shown.
Binary file added DCASE24_Task1/e21znxl1/checkpoints/last.ckpt
Binary file not shown.
Binary file not shown.
Binary file added DCASE24_Task1/e2f0qnud/checkpoints/last.ckpt
Binary file not shown.
Binary file not shown.
Binary file added DCASE24_Task1/elymyq0s/checkpoints/last.ckpt
Binary file not shown.
Binary file not shown.
Binary file added DCASE24_Task1/eqof0i23/checkpoints/last.ckpt
Binary file not shown.
Binary file not shown.
Binary file added DCASE24_Task1/erk0wu35/checkpoints/last.ckpt
Binary file not shown.
Binary file not shown.
Binary file added DCASE24_Task1/f2cm10th/checkpoints/last.ckpt
Binary file not shown.
Binary file not shown.
Binary file added DCASE24_Task1/g829jqbk/checkpoints/last.ckpt
Binary file not shown.
Binary file not shown.
Binary file added DCASE24_Task1/gov6j8yq/checkpoints/last.ckpt
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file added DCASE24_Task1/mxllj5x1/checkpoints/last.ckpt
Binary file not shown.
Binary file not shown.
Binary file added DCASE24_Task1/nhx1d8t1/checkpoints/last.ckpt
Binary file not shown.
Binary file not shown.
Binary file added DCASE24_Task1/ot4j7tem/checkpoints/last.ckpt
Binary file not shown.
Binary file not shown.
Binary file added DCASE24_Task1/qabvkf35/checkpoints/last.ckpt
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file added DCASE24_Task1/r3h3qvet/checkpoints/last.ckpt
Binary file not shown.
Binary file not shown.
Binary file added DCASE24_Task1/r3m0113t/checkpoints/last.ckpt
Binary file not shown.
Binary file not shown.
Binary file added DCASE24_Task1/srq5lv8d/checkpoints/last.ckpt
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file added DCASE24_Task1/tk8uhr60/checkpoints/last.ckpt
Binary file not shown.
Binary file not shown.
Binary file added DCASE24_Task1/urhdivv7/checkpoints/last.ckpt
Binary file not shown.
Binary file not shown.
Binary file added DCASE24_Task1/uykvcs08/checkpoints/last.ckpt
Binary file not shown.
Binary file not shown.
Binary file added DCASE24_Task1/v11h7rui/checkpoints/last.ckpt
Binary file not shown.
Binary file not shown.
Binary file added DCASE24_Task1/v8ayx1ww/checkpoints/last.ckpt
Binary file not shown.
Binary file not shown.
Binary file added DCASE24_Task1/x7fq3iij/checkpoints/last.ckpt
Binary file not shown.
Binary file not shown.
Binary file added DCASE24_Task1/xeao5a09/checkpoints/last.ckpt
Binary file not shown.
Binary file not shown.
Binary file added DCASE24_Task1/zcdxhs2n/checkpoints/last.ckpt
Binary file not shown.
Binary file not shown.
42 changes: 42 additions & 0 deletions dataloader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import torch
import time
from torch.utils.data import DataLoader
from dataset.dcase24_dev import get_training_set
from helpers.init import worker_init_fn
from dataset.dcase24_dev_logmel import get_training_set_log

# Function to obtain datasets
def get_datasets():
dataset1 = get_training_set_log(5, roll=0)
dataset2 = get_training_set(5, roll=0)
return dataset1, dataset2

def measure_dataloader_speed(dataloader, num_batches):
start_time = time.time()
for i, _ in enumerate(dataloader):
if i >= num_batches:
break
end_time = time.time()
elapsed_time = end_time - start_time
average_time_per_batch = elapsed_time / num_batches
return average_time_per_batch

def main():
# Obtain datasets from user-provided script
dataset1, dataset2 = get_datasets()

# Define dataloaders
batch_size = 256
dataloader1 = DataLoader(dataset1, batch_size=batch_size, shuffle=True, num_workers=0, worker_init_fn=worker_init_fn)
dataloader2 = DataLoader(dataset2, batch_size=batch_size, shuffle=True, num_workers=0, worker_init_fn=worker_init_fn)

# Measure the speed of dataloaders
num_batches_to_test = 100
avg_time_dataloader1 = measure_dataloader_speed(dataloader1, num_batches_to_test)
avg_time_dataloader2 = measure_dataloader_speed(dataloader2, num_batches_to_test)

print(f"Average loading time per batch for .pt Dataloader: {avg_time_dataloader1:.4f} seconds")
print(f"Average loading time per batch for .wav to Logmel Dataloader: {avg_time_dataloader2:.4f} seconds")

if __name__ == "__main__":
main()
95 changes: 87 additions & 8 deletions dataset/dcase24.py → dataset/dcase24_dev.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,15 @@
from torch.utils.data import Dataset as TorchDataset
import torch
import torchaudio
import torch.nn.functional as F
from torch.hub import download_url_to_file
import numpy as np
import librosa
from scipy.signal import convolve
import pathlib

dataset_dir = None
# dataset_dir = r"D:\Sean\DCASE\datasets\Extract_to_Folder\TAU-urban-acoustic-scenes-2022-mobile-development" # Alibaba
dataset_dir = r"F:\DCASE\2024\Datasets\TAU-urban-acoustic-scenes-2022-mobile-development" # DSP
assert dataset_dir is not None, "Specify 'TAU Urban Acoustic Scenes 2022 Mobile dataset' location in variable " \
"'dataset_dir'. The dataset can be downloaded from this URL:" \
" https://zenodo.org/record/6337421"
Expand All @@ -18,11 +23,69 @@
"split_path": "split_setup",
"split_url": "https://github.com/CPJKU/dcase2024_task1_baseline/releases/download/files/",
"test_split_csv": "test.csv",
"eval_dir": os.path.join(dataset_dir, "..", "TAU-urban-acoustic-scenes-2024-mobile-evaluation"),
"eval_meta_csv": os.path.join(dataset_dir, "..", "TAU-urban-acoustic-scenes-2024-mobile-evaluation", "meta.csv")
"dirs_path": os.path.join("dataset", "dirs"),
"eval_dir": os.path.join(dataset_dir),
"eval_meta_csv": os.path.join(dataset_dir, "split100.csv"), # to get the full prediction list with index intact
"logits_file": os.path.join("predictions","elymyq0s", "logits.pt") #specifies where the logit and predictions are stored. Still need to provide script with ckpt_id
# "eval_dir": os.path.join(dataset_dir, "TAU-urban-acoustic-scenes-2024-mobile-evaluation"),
# "eval_meta_csv": os.path.join(dataset_dir, "TAU-urban-acoustic-scenes-2024-mobile-evaluation", "meta.csv")
}

class DIRAugmentDataset(TorchDataset):
"""
Augments Waveforms with a Device Impulse Response (DIR)
"""

def __init__(self, ds, dirs, prob):
self.ds = ds
self.dirs = dirs
self.prob = prob

def __getitem__(self, index):
x, file, label, device, city, logits = self.ds[index]

fsplit = file.rsplit("-", 1)
device = fsplit[1][:-4]

if device == 'a' and torch.rand(1) < self.prob:
# choose a DIR at random
dir_idx = int(np.random.randint(0, len(self.dirs)))
dir = self.dirs[dir_idx]

x = convolve(x, dir, 'full')[:, :x.shape[1]]
x = torch.from_numpy(x)
return x, file, label, device, city, logits

def __len__(self):
return len(self.ds)
class AddLogitsDataset(TorchDataset):
"""A dataset that loads and adds teacher logits to audio samples.
"""

def __init__(self, dataset, map_indices, logits_file, temperature=2):
"""
@param dataset: dataset to load data from
@param map_indices: used to get correct indices in list of logits
@param logits_file: logits file to load the teacher logits from
@param temperature: used in Knowledge Distillation, change distribution of predictions
return: x, file name, label, device, city, logits
"""
self.dataset = dataset
if not os.path.isfile(logits_file):
print("Verify existence of teacher predictions.")
raise SystemExit
logits = torch.load(logits_file).float()
self.logits = logits
# self.logits = F.log_softmax(logits / temperature, dim=-1)
self.map_indices = map_indices

def __getitem__(self, index):
x, file, label, device, city = self.dataset[index]
return x, file, label, device, city, self.logits[self.map_indices[index]]

def __len__(self):
return len(self.dataset)

class BasicDCASE24Dataset(TorchDataset):
"""
Basic DCASE'24 Dataset: loads data from files
Expand Down Expand Up @@ -85,15 +148,29 @@ def __init__(self, dataset: TorchDataset, shift_range: int, axis=1):
self.axis = axis

def __getitem__(self, index):
x, file, label, device, city = self.dataset[index]
x, file, label, device, city, logits = self.dataset[index]
sf = int(np.random.random_integers(-self.shift_range, self.shift_range))
return x.roll(sf, self.axis), file, label, device, city
return x.roll(sf, self.axis), file, label, device, city, logits

def __len__(self):
return len(self.dataset)
def load_dirs(dirs_path, resample_rate):
all_paths = [path for path in pathlib.Path(os.path.expanduser(dirs_path)).rglob('*.wav')]
all_paths = sorted(all_paths)
all_paths_name = [str(p).rsplit("/", 1)[-1] for p in all_paths]

print("Augment waveforms with the following device impulse responses:")
for i in range(len(all_paths_name)):
print(i, ": ", all_paths_name[i])

def process_func(dir_file):
sig, _ = librosa.load(dir_file, sr=resample_rate, mono=True)
sig = torch.from_numpy(sig[np.newaxis])
return sig

def get_training_set(split=100, roll=False):
return [process_func(p) for p in all_paths]

def get_training_set(split=100, roll=False, dir_prob=0,resample_rate=44100):
assert str(split) in ("5", "10", "25", "50", "100"), "Parameters 'split' must be in [5, 10, 25, 50, 100]"
os.makedirs(dataset_config['split_path'], exist_ok=True)
subset_fname = f"split{split}.csv"
Expand All @@ -104,6 +181,8 @@ def get_training_set(split=100, roll=False):
print(f"Downloading file: {subset_fname}")
download_url_to_file(subset_csv_url, subset_split_file)
ds = get_base_training_set(dataset_config['meta_csv'], subset_split_file)
if dir_prob > 0:
ds = DIRAugmentDataset(ds, load_dirs(dataset_config['dirs_path'], resample_rate), dir_prob)
if roll:
ds = RollDataset(ds, shift_range=roll)
return ds
Expand All @@ -115,6 +194,7 @@ def get_base_training_set(meta_csv, train_files_csv):
train_subset_indices = list(meta[meta['filename'].isin(train_files)].index)
ds = SimpleSelectionDataset(BasicDCASE24Dataset(meta_csv),
train_subset_indices)
# ds = AddLogitsDataset(ds, train_subset_indices, dataset_config['logits_file'])
return ds


Expand Down Expand Up @@ -169,5 +249,4 @@ def get_eval_set():

def get_base_eval_set(meta_csv, eval_dir):
ds = BasicDCASE24EvalDataset(meta_csv, eval_dir)
return ds

return ds
Loading