Skip to content

Commit

Permalink
Added verde to python virtualenv packages list.
Browse files Browse the repository at this point in the history
- split_obs_per_species_frequency.py: changed variable names
- split_obs_spatially.py: changed variable names
- malpolon.data.utils: added the 2 previously mentioned scripts as modules to be called within malpolon's API and changed variable names
- malpolon.data.datasets.geolifeclef2024: updated docstring of JpegPatchProvider()
  • Loading branch information
tlarcher committed Jun 3, 2024
1 parent 2ade4e6 commit 98dab49
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 13 deletions.
5 changes: 5 additions & 0 deletions malpolon/data/datasets/geolifeclef2024.py
Original file line number Diff line number Diff line change
Expand Up @@ -721,6 +721,11 @@ class JpegPatchProvider(PatchProvider):
Provides tensors of multi-modal patches from JPEG patch files
of rasters of the GLC23 challenge.
Image patches are expected to be named by a patch ID and arranged
in folders and sub-folders in the following way:
root_path/YZ/WX/patch_id.jpeg with patch_id being the value
ABCDWXYZ.
Attributes:
(PatchProvider): inherits PatchProvider.
"""
Expand Down
13 changes: 9 additions & 4 deletions malpolon/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,16 +163,16 @@ def get_files_path_recursively(path, *args, suffix='') -> list:
return result


def split_obs_spatially(input_name: str,
def split_obs_spatially(input_path: str,
spacing: float = 10 / 60,
plot: bool = False,
val_size: float = 0.15):
"""Perform a spatial train/val split on the input csv file.
Parameters
----------
input_name : str
obs CSV input file's name without the .csv extension.
input_path : str
obs CSV input file's path
spacing : float, optional
size of the spatial split in degrees (or whatever unit the coordinates are in),
by default 10/60
Expand All @@ -182,6 +182,7 @@ def split_obs_spatially(input_name: str,
val_size : float, optional
size of the validaiton split, by default 0.15
"""
input_name = input_path[:-4] if input_path.endswith(".csv") else input_path
df = pd.read_csv(f'{input_name}.csv')
coords, data = {}, {}
for col in df.columns:
Expand All @@ -203,14 +204,17 @@ def split_obs_spatially(input_name: str,
df_train_val = pd.concat([df_train, df_val])

df_train_val.to_csv(f'{input_name}_train_val-{spacing*60}min.csv', index=False)
print(f'Done: {input_name}_train_val-{spacing*60}min.csv')
df_train.to_csv(f'{input_name}_train-{spacing*60}min.csv', index=False)
print(f'Done: {input_name}_train-{spacing*60}min.csv')
df_val.to_csv(f'{input_name}_val-{spacing*60}min.csv', index=False)
print(f'Done: {input_name}_val-{spacing*60}min.csv')

if plot:
plot_od(df=df_train_val, show_map=True)


def split_obs_per_species_frequency(input_name: str,
def split_obs_per_species_frequency(input_path: str,
output_name: str,
val_ratio: float = 0.05):
"""Split an obs csv in val/train.
Expand All @@ -225,6 +229,7 @@ def split_obs_per_species_frequency(input_name: str,
Input csv is expected to have at least the following columns:
['speciesId']
"""
input_name = input_path[:-4] if input_path.endswith(".csv") else input_path
pa_train = pd.read_csv(f'{input_name}.csv')
pa_train['subset'] = ['train'] * len(pa_train)
pa_train_uniques = np.unique(pa_train['speciesId'], return_counts=True)
Expand Down
1 change: 1 addition & 0 deletions requirements_python3.10.txt
Original file line number Diff line number Diff line change
Expand Up @@ -47,4 +47,5 @@ torchgeo==0.5.0
torchmetrics==1.2.0
torchvision==0.16.0
tqdm==4.66.3
verde==1.8.0
yarl==1.9.2
7 changes: 4 additions & 3 deletions scripts/split_obs_per_species_frequency.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from tqdm import tqdm


def main(input_name: str,
def main(input_path: str,
output_name: str,
val_ratio: float = 0.05):
"""Split an obs csv in val/train.
Expand All @@ -25,6 +25,7 @@ def main(input_name: str,
Input csv is expected to have at least the following columns:
['speciesId']
"""
input_name = input_path[:-4] if input_path.endswith(".csv") else input_path
pa_train = pd.read_csv(f'{input_name}.csv')
pa_train['subset'] = ['train'] * len(pa_train)
pa_train_uniques = np.unique(pa_train['speciesId'], return_counts=True)
Expand Down Expand Up @@ -52,6 +53,6 @@ def main(input_name: str,


if __name__ == '__main__':
INPUT_NAME = 'sample_obs'
INPUT_PATH = 'sample_obs.csv'
OUTPUT_NAME = 'sample_obs'
main(INPUT_NAME, OUTPUT_NAME)
main(INPUT_PATH, OUTPUT_NAME)
13 changes: 7 additions & 6 deletions scripts/split_obs_spatially.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,16 @@
from malpolon.plot.map import plot_observation_dataset as plot_od


def main(input_name: str,
def main(input_path: str,
spacing: float = 10 / 60,
plot: bool = False,
val_size: float = 0.15):
"""Perform a spatial train/val split on the input csv file.
Parameters
----------
input_name : str
obs CSV input file's name without the .csv extension.
obs_path : str
obs CSV input file's path.
spacing : float, optional
size of the spatial split in degrees (or whatever unit the coordinates are in),
by default 10/60
Expand All @@ -31,6 +31,7 @@ def main(input_name: str,
val_size : float, optional
size of the validaiton split, by default 0.15
"""
input_name = input_path[:-4] if input_path.endswith(".csv") else input_path
df = pd.read_csv(f'{input_name}.csv')
coords, data = {}, {}
for col in df.columns:
Expand Down Expand Up @@ -61,8 +62,8 @@ def main(input_name: str,

if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("-i", "--input_name",
help="Name of the input csv obs file without the .csv extension",
parser.add_argument("-i", "--input_path",
help="Path to the input csv obs file.",
default='GLC24_PA_metadata_train',
type=str)
parser.add_argument("-s", "--spacing",
Expand All @@ -77,4 +78,4 @@ def main(input_name: str,
help="If true, plot the train/val split at the end of the script.",
action='store_true')
args = parser.parse_args()
main(args.input_name, args.spacing, plot=args.plot, val_size=args.val_size)
main(args.input_path, args.spacing, plot=args.plot, val_size=args.val_size)

0 comments on commit 98dab49

Please sign in to comment.