diff --git a/src/torchcvnn/datasets/mstar/dataset.py b/src/torchcvnn/datasets/mstar/dataset.py index 71c5f0c..ddde31d 100644 --- a/src/torchcvnn/datasets/mstar/dataset.py +++ b/src/torchcvnn/datasets/mstar/dataset.py @@ -149,7 +149,6 @@ def gather_mstar_datafiles(rootdir: pathlib.Path, target_name_depth: int = 1) -> logging.debug(f"Successfully parsed {filename} as a {target_name} sample.") data_files[target_name].append(filename) - return data_files @@ -165,6 +164,7 @@ class MSTARTargets(Dataset): This dataset object expects all the datasets to be unpacked in the same directory. We can parse the following : - MSTAR_PUBLIC_T_72_VARIANTS_CD1 : https://www.sdms.afrl.af.mil/index.php?collection=mstar&page=variants + - MSTAR_PUBLIC_T_72_VARIANTS_CD2 : https://www.sdms.afrl.af.mil/index.php?collection=mstar&page=variants - MSTAR_PUBLIC_MIXED_TARGETS_CD1 : https://www.sdms.afrl.af.mil/index.php?collection=mstar&page=mixed - MSTAR_PUBLIC_MIXED_TARGETS_CD2 : https://www.sdms.afrl.af.mil/index.php?collection=mstar&page=mixed - MSTAR_PUBLIC_TARGETS_CHIPS_T72_BMP2_BTR70_SLICY : @@ -214,6 +214,7 @@ def __init__(self, rootdir: str, transform=None): # with respect to a datafile sub_datasets = { "MSTAR_PUBLIC_T_72_VARIANTS_CD1": 2, + "MSTAR_PUBLIC_T_72_VARIANTS_CD2": 2, "MSTAR_PUBLIC_MIXED_TARGETS_CD1": 2, "MSTAR_PUBLIC_MIXED_TARGETS_CD2": 2, "MSTAR_PUBLIC_TARGETS_CHIPS_T72_BMP2_BTR70_SLICY": 3, @@ -226,7 +227,13 @@ def __init__(self, rootdir: str, transform=None): if not sub_dir.exists(): logging.warning(f"Directory {sub_dir} does not exist.") continue - self.data_files.update(gather_mstar_datafiles(sub_dir, target_name_depth)) + # Append the data files from the sub-dataset + for key, value in gather_mstar_datafiles( + sub_dir, target_name_depth + ).items(): + if key not in self.data_files: + self.data_files[key] = [] + self.data_files[key].extend(value) self.class_names = list(self.data_files.keys()) # We then count how many samples have been loaded for all the classes