Skip to content

Commit

Permalink
Merge commit '0dfa3c5d462552e93e944b21ba98c3249b1c816b' into dev_dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
ouioui199 committed Jan 10, 2025
2 parents bfc1cd0 + 0dfa3c5 commit bbec6a1
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions src/torchcvnn/datasets/mstar/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,6 @@ def gather_mstar_datafiles(rootdir: pathlib.Path, target_name_depth: int = 1) ->

logging.debug(f"Successfully parsed {filename} as a {target_name} sample.")
data_files[target_name].append(filename)

return data_files


Expand All @@ -165,6 +164,7 @@ class MSTARTargets(Dataset):
This dataset object expects all the datasets to be unpacked in the same directory. We can parse the following :
- MSTAR_PUBLIC_T_72_VARIANTS_CD1 : https://www.sdms.afrl.af.mil/index.php?collection=mstar&page=variants
- MSTAR_PUBLIC_T_72_VARIANTS_CD2 : https://www.sdms.afrl.af.mil/index.php?collection=mstar&page=variants
- MSTAR_PUBLIC_MIXED_TARGETS_CD1 : https://www.sdms.afrl.af.mil/index.php?collection=mstar&page=mixed
- MSTAR_PUBLIC_MIXED_TARGETS_CD2 : https://www.sdms.afrl.af.mil/index.php?collection=mstar&page=mixed
- MSTAR_PUBLIC_TARGETS_CHIPS_T72_BMP2_BTR70_SLICY :
Expand Down Expand Up @@ -214,6 +214,7 @@ def __init__(self, rootdir: str, transform=None):
# with respect to a datafile
sub_datasets = {
"MSTAR_PUBLIC_T_72_VARIANTS_CD1": 2,
"MSTAR_PUBLIC_T_72_VARIANTS_CD2": 2,
"MSTAR_PUBLIC_MIXED_TARGETS_CD1": 2,
"MSTAR_PUBLIC_MIXED_TARGETS_CD2": 2,
"MSTAR_PUBLIC_TARGETS_CHIPS_T72_BMP2_BTR70_SLICY": 3,
Expand All @@ -226,7 +227,13 @@ def __init__(self, rootdir: str, transform=None):
if not sub_dir.exists():
logging.warning(f"Directory {sub_dir} does not exist.")
continue
self.data_files.update(gather_mstar_datafiles(sub_dir, target_name_depth))
# Append the data files from the sub-dataset
for key, value in gather_mstar_datafiles(
sub_dir, target_name_depth
).items():
if key not in self.data_files:
self.data_files[key] = []
self.data_files[key].extend(value)
self.class_names = list(self.data_files.keys())

# We then count how many samples have been loaded for all the classes
Expand Down

0 comments on commit bbec6a1

Please sign in to comment.