Skip to content

Commit

Permalink
Updated iterative mean std script by adding per channel image version
Browse files Browse the repository at this point in the history
  • Loading branch information
tlarcher committed Jun 14, 2024
1 parent 5e74539 commit 2dd8617
Showing 1 changed file with 63 additions and 5 deletions.
68 changes: 63 additions & 5 deletions scripts/compute_mean_std_iteratively_from_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,11 +108,57 @@ def iterative_mean_std(fps: list,
print(f'Numpy mean: {INFO}{np.mean(data)}{RESET}, Numpy std: {INFO}{np.std(data)}{RESET}')
return mean, np.sqrt(var)

def iterative_mean_std_img_per_channel(fps: list,
load_fun: Callable,
compare_numpy: bool = False):
"""Compute the mean and std of a dataset iteratively per channel.
This method is intended to compute the mean and std for each
individual channel of 3D matrices.
Parameters
----------
fps : str
list of paths to the dataset file.
load_fun : callable
loading function to load the data from the file (depends on the
type of file).
compare_numpy : bool, optional
if True, computes the numpy mean and std by loading all dataset
files content in a single numpy array to compare with the
iteratively computed values.
By default False
Returns
-------
(tuple)
tuple of iterative mean and std of the dataset as float values.
"""
mean, mean2, var = (np.zeros(1),) * 3
data = []
for k, fp in tqdm(enumerate(fps), total=len(fps)):
try:
x = load_fun(fp) # Giving a large type is important to avoid value overflow with mean squared
assert len(x.shape) >= 3
if k == 0:
mean = np.zeros(x.shape[-1]).astype(np.float32)
mean2 = np.zeros(x.shape[-1]).astype(np.float32)
if compare_numpy:
data.append(x)
mean += (np.nanmean(x, axis=(0,1)) - mean) / (k + 1)
mean2 += (np.nanmean(x**2, axis=(0,1)) - mean2) / (k + 1)
except AssertionError:
print(f'File {fp} does not contain 3D data. Passing...')
var = mean2 - mean**2
if compare_numpy:
print(f'Numpy mean: {INFO}{np.nanmean(data, axis=(0,1))}{RESET}, Numpy std: {INFO}{np.nanstd(data, axis=(0,1))}{RESET}')
return mean.tolist(), np.sqrt(var).tolist()

def main(paths_file: str,
output: str = None,
data_type: str = 'image',
data_type: str = 'img',
max_items: int = None,
per_channel: bool = False,
compare_numpy: bool = False):
"""Run the main function.
Expand All @@ -132,6 +178,9 @@ def main(paths_file: str,
max_items : _type_, optional
maximum number of items to compute the mean/std on.
By default None
per_channel : bool, optionnal
if True, calls computaiton of mean / std seperately for each
data channel
compare_numpy : bool, optional
if True, the numpy mean and std will also be computed for
comparison.
Expand All @@ -146,21 +195,26 @@ def main(paths_file: str,
with open(paths_file, 'r', encoding="utf-8") as f:
fps = f.read().splitlines()
fps = fps[:max_items]
ims = iterative_mean_std
if per_channel and data_type == 'img':
ims = iterative_mean_std_img_per_channel

if data_type == 'img':
it_mean, it_std = iterative_mean_std(fps, load_img, compare_numpy)
it_mean, it_std = ims(fps, load_img, compare_numpy)
print(f'Processed {INFO}{len(fps)}{RESET} images. Iterative mean: {INFO}{it_mean}{RESET}, Iterative std: {INFO}{it_std}{RESET} in {LINK}{(time() - t1):.3f}{RESET}s')
elif data_type == 'csv':
it_mean, it_std = iterative_mean_std(fps, load_csv, compare_numpy)
it_mean, it_std = ims(fps, load_csv, compare_numpy)
print(f'Processed {INFO}{len(fps)}{RESET} csv pre-extracted obs files. Iterative mean: {INFO}{it_mean}{RESET}, Iterative std: {INFO}{it_std}{RESET} in {LINK}{(time() - t1):.3f}{RESET}s')
elif data_type == 'pt':
it_mean, it_std = iterative_mean_std(fps, load_pt, compare_numpy)
it_mean, it_std = ims(fps, load_pt, compare_numpy)
print(f'Processed {INFO}{len(fps)}{RESET} pytorch cubes. Iterative mean: {INFO}{it_mean}{RESET}, Iterative std: {INFO}{it_std}{RESET} in {LINK}{(time() - t1):.3f}{RESET}s')
else:
raise ValueError(f"Type {data_type} not recognized.")

if output:
df = pd.DataFrame({'mean': [it_mean], 'std': [it_std]})
it_mean = [it_mean] if not isinstance(it_mean, list) else it_mean
it_std = [it_std] if not isinstance(it_std, list) else it_std
df = pd.DataFrame({'mean': it_mean, 'std': it_std})
df.to_csv(output, index=False, sep=',')
print(f'Stats saved to {INFO}{output}{RESET}')

Expand All @@ -183,9 +237,13 @@ def main(paths_file: str,
help="Type of files to process.",
choices=['img', 'csv', 'pt'],
type=str)
parser.add_argument("--per_channel",
help="Compute mean/std over each channel seperately.",
action='store_true')
parser.add_argument("--compare_numpy",
help="If true, computes the Numpy mean and std for comparison. WARNING: this will load all the items in memory, only use with a reasonable value of --max_items.",
action='store_true')
args = parser.parse_args()
main(args.paths_file, args.output, data_type=args.type, max_items=args.max_items,
per_channel=True,
compare_numpy=args.compare_numpy)

0 comments on commit 2dd8617

Please sign in to comment.