diff --git a/scripts/compute_mean_std_iteratively_from_sample.py b/scripts/compute_mean_std_iteratively_from_sample.py index 78e2afc4..cd7aceed 100644 --- a/scripts/compute_mean_std_iteratively_from_sample.py +++ b/scripts/compute_mean_std_iteratively_from_sample.py @@ -5,6 +5,7 @@ import argparse from time import time +from typing import Callable import numpy as np import pandas as pd @@ -16,17 +17,83 @@ RESET = '\033[0m' LINK = '\033[94m' -def load_img(fp): + +def load_img(fp: str): + """Load an image from a file path. + + Parameters + ---------- + fp : str + file path to the image. + + Returns + ------- + (array) + image as a numpy array. + """ return np.array(Image.open(fp)).astype(np.float32) -def load_csv(fp): + +def load_csv(fp: str): + """Load a csv file from a file path. + + The CSV file is expected to contain pre-extracted observations (PA or + PO), from any source (bioclimatic, landsat, sentinel...) with + the data contained in columns 1 to end. Column 0 is skipped. + + Parameters + ---------- + fp : str + file path to the csv file. + + Returns + ------- + (array) + pre-extracted obs as a numpy array. + """ df = pd.read_csv(fp) return df.iloc[:, 1:].values.astype(np.float32) -def load_pt(fp): + +def load_pt(fp: str): + """Load a PyTorch cube from a file path. + + Parameters + ---------- + fp : str + file path to the PyTorch cube. + + Returns + ------- + (array) + numpy array of the PyTorch cube. + """ return torch.load(fp).numpy() -def iterative_mean_std(fps, load_fun, compare_numpy=False): + +def iterative_mean_std(fps: list, + load_fun: Callable, + compare_numpy: bool = False): + """Compute the mean and std of a dataset iteratively. + + Parameters + ---------- + fps : str + list of paths to the dataset file. + load_fun : callable + loading function to load the data from the file (depends on the + type of file). + compare_numpy : bool, optional + if True, computes the numpy mean and std by loading all dataset + files content in a single numpy array to compare with the + iteratively computed values. + By default False + + Returns + ------- + (tuple) + tuple of iterative mean and std of the dataset as float values. + """ mean = 0 mean2 = 0 data = [] @@ -41,23 +108,56 @@ def iterative_mean_std(fps, load_fun, compare_numpy=False): print(f'Numpy mean: {INFO}{np.mean(data)}{RESET}, Numpy std: {INFO}{np.std(data)}{RESET}') return mean, np.sqrt(var) -def main(paths_file, output=None, type='image', max_items=None, compare_numpy=False): + +def main(paths_file: str, + output: str = None, + data_type: str = 'image', + max_items: int = None, + compare_numpy: bool = False): + """Run the main function. + + This method calls the correct loading functions and contrains the + max amount of items to compute the mean/std on. + + Parameters + ---------- + paths_file : str + path to a file containing the paths to the files to process. + output : str, optional + path to the output file to store the mean/std values. This file + is expected to be a CSV of 1 line and 2 columns. + By default None + data_type : str, optional + type of file to process, by default 'image' + max_items : _type_, optional + maximum number of items to compute the mean/std on. + By default None + compare_numpy : bool, optional + if True, the numpy mean and std will also be computed for + comparison. + By default False + + Raises + ------ + ValueError + triggers when the type is not recognized. + """ t1 = time() with open(paths_file, 'r', encoding="utf-8") as f: fps = f.read().splitlines() fps = fps[:max_items] - if type == 'img': + if data_type == 'img': it_mean, it_std = iterative_mean_std(fps, load_img, compare_numpy) print(f'Processed {INFO}{len(fps)}{RESET} images. Iterative mean: {INFO}{it_mean}{RESET}, Iterative std: {INFO}{it_std}{RESET} in {LINK}{(time() - t1):.3f}{RESET}s') - elif type == 'csv': + elif data_type == 'csv': it_mean, it_std = iterative_mean_std(fps, load_csv, compare_numpy) print(f'Processed {INFO}{len(fps)}{RESET} csv pre-extracted obs files. Iterative mean: {INFO}{it_mean}{RESET}, Iterative std: {INFO}{it_std}{RESET} in {LINK}{(time() - t1):.3f}{RESET}s') - elif type == 'pt': + elif data_type == 'pt': it_mean, it_std = iterative_mean_std(fps, load_pt, compare_numpy) print(f'Processed {INFO}{len(fps)}{RESET} pytorch cubes. Iterative mean: {INFO}{it_mean}{RESET}, Iterative std: {INFO}{it_std}{RESET} in {LINK}{(time() - t1):.3f}{RESET}s') else: - raise ValueError(f"Type {type} not recognized.") + raise ValueError(f"Type {data_type} not recognized.") if output: df = pd.DataFrame({'mean': [it_mean], 'std': [it_std]}) @@ -87,5 +187,5 @@ def main(paths_file, output=None, type='image', max_items=None, compare_numpy=Fa help="If true, computes the Numpy mean and std for comparison. WARNING: this will load all the items in memory, only use with a reasonable value of --max_items.", action='store_true') args = parser.parse_args() - main(args.paths_file, args.output, type=args.type, max_items=args.max_items, + main(args.paths_file, args.output, data_type=args.type, max_items=args.max_items, compare_numpy=args.compare_numpy)