From 3bcb0d2f4c066e207c659aefc8ba3208a575519f Mon Sep 17 00:00:00 2001
From: Fariborz Daneshvar
Date: Thu, 29 Aug 2024 12:20:01 -0500
Subject: [PATCH 1/9] call roc_curve in post.sbatch

---
 stormworkflow/slurm/post.sbatch | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/stormworkflow/slurm/post.sbatch b/stormworkflow/slurm/post.sbatch
index ea61395..2e74c23 100644
--- a/stormworkflow/slurm/post.sbatch
+++ b/stormworkflow/slurm/post.sbatch
@@ -13,3 +13,11 @@ combine_ensemble \
 analyze_ensemble \
     --ensemble-dir $ENSEMBLE_DIR \
     --tracks-dir $ENSEMBLE_DIR/track_files
+
+storm_roc_curve \
+    --storm_name ${storm} \
+    --storm_year ${year} \
+    --leadtime ${hr_prelandfall} \
+    --prob_nc_path $ENSEMBLE_DIR/analyze/linear_k1_p1_n0.025/probabilities.nc \
+    --obs_df_path ${NHC_OBS} \
+    --save_dir $ENSEMBLE_DIR/analyze/linear_k1_p1_n0.025

From 579ab2cc6d84b2aaad3e16496e2c84cf6bdd61b7 Mon Sep 17 00:00:00 2001
From: Fariborz Daneshvar
Date: Fri, 30 Aug 2024 12:39:39 -0500
Subject: [PATCH 2/9] replace ROC_single_run.py with storm_roc_curve.py in
 [project.scripts]

---
 pyproject.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index 8646a81..d85956d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -99,4 +99,4 @@ download_data = "stormworkflow.prep.download_data:cli"
 setup_ensemble = "stormworkflow.prep.setup_ensemble:cli"
 combine_ensemble = "stormworkflow.post.combine_ensemble:cli"
 analyze_ensemble = "stormworkflow.post.analyze_ensemble:cli"
-storm_roc_curve = "stormworkflow.post.ROC_single_run:cli"
+storm_roc_curve = "stormworkflow.post.storm_roc_curve:cli"

From dac175329b806c87ea3c7fa24db9ff4cdd9fc34a Mon Sep 17 00:00:00 2001
From: Fariborz Daneshvar
Date: Fri, 30 Aug 2024 12:42:40 -0500
Subject: [PATCH 3/9] add storm_roc_curve with args to the script

---
 stormworkflow/slurm/post.sbatch | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/stormworkflow/slurm/post.sbatch b/stormworkflow/slurm/post.sbatch
index 2e74c23..33ef357 100644
--- a/stormworkflow/slurm/post.sbatch
+++ b/stormworkflow/slurm/post.sbatch
@@ -15,9 +15,8 @@ analyze_ensemble \
     --tracks-dir $ENSEMBLE_DIR/track_files
 
 storm_roc_curve \
-    --storm_name ${storm} \
-    --storm_year ${year} \
+    --storm ${storm} \
+    --year ${year} \
     --leadtime ${hr_prelandfall} \
-    --prob_nc_path $ENSEMBLE_DIR/analyze/linear_k1_p1_n0.025/probabilities.nc \
     --obs_df_path ${NHC_OBS} \
-    --save_dir $ENSEMBLE_DIR/analyze/linear_k1_p1_n0.025
+    --ensemble-dir $ENSEMBLE_DIR

From da5d501b4838c5cf7611ff003ec5ee7c29a46651 Mon Sep 17 00:00:00 2001
From: Fariborz Daneshvar
Date: Fri, 30 Aug 2024 12:49:18 -0500
Subject: [PATCH 4/9] make a copy of ROC_single_run.py and format it to be
 consistent with the rest of the workflow, i.e. variable names and args
 definition

---
 stormworkflow/post/storm_roc_curve.py | 293 ++++++++++++++++++++++++++
 1 file changed, 293 insertions(+)
 create mode 100644 stormworkflow/post/storm_roc_curve.py

diff --git a/stormworkflow/post/storm_roc_curve.py b/stormworkflow/post/storm_roc_curve.py
new file mode 100644
index 0000000..4758374
--- /dev/null
+++ b/stormworkflow/post/storm_roc_curve.py
@@ -0,0 +1,293 @@
+import argparse
+import logging
+import os
+import warnings
+import numpy as np
+import pandas as pd
+import xarray as xr
+import scipy as sp
+import matplotlib.pyplot as plt
+from pathlib import Path
+from cartopy.feature import NaturalEarthFeature
+
+os.environ['USE_PYGEOS'] = '0'
+import geopandas as gpd
+
+pd.options.mode.copy_on_write = True
+
+
+def stack_station_coordinates(x, y):
+    """
+    Create numpy.column_stack based on
+    coordinates of observation points
+    """
+    coord_combined = np.column_stack([x, y])
+    return coord_combined
+
+
+def create_search_tree(longitude, latitude):
+    """
+    Create scipy.spatial.CKDTree based on Lat. and Long.
+    """
+    long_lat = np.column_stack((longitude.T.ravel(), latitude.T.ravel()))
+    tree = sp.spatial.cKDTree(long_lat)
+    return tree
+
+
+def find_nearby_prediction(ds, variable, indices):
+    """
+    Reads netcdf file, target variable, and indices
+    Returns max value among corresponding indices for each point
+    """
+    obs_count = indices.shape[0]  # total number of search/observation points
+    max_prediction_index = len(ds.node.values)  # total number of nodes
+
+    prediction_prob = np.zeros(obs_count)  # assuming all are dry (probability of zero)
+
+    for obs_point in range(obs_count):
+        idx_arr = np.delete(
+            indices[obs_point], np.where(indices[obs_point] == max_prediction_index)[0]
+        )  # len is length of surrogate model array
+        val_arr = ds[variable].values[idx_arr]
+        val_arr = np.nan_to_num(val_arr)  # replace nan with zero (dry node)
+
+        # # Pick the nearest non-zero probability (option #1)
+        # for val in val_arr:
+        #     if val > 0.0:
+        #         prediction_prob[obs_point] = round(val,4) #round to 0.1 mm
+        #         break
+
+        # pick the largest value (option #2)
+        if val_arr.size > 0:
+            prediction_prob[obs_point] = val_arr.max()
+    return prediction_prob
+
+
+def plot_probabilities(df, prob_column, gdf_countries, title, save_name):
+    """
+    plot probabilities of exceeding given threshold at obs. points
+    """
+    figure, axis = plt.subplots(1, 1)
+    figure.set_size_inches(10, 10 / 1.6)
+
+    plt.scatter(x=df.Longitude, y=df.Latitude, vmin=0, vmax=1.0, c=df[prob_column])
+    xlim = axis.get_xlim()
+    ylim = axis.get_ylim()
+
+    gdf_countries.plot(color='lightgrey', ax=axis, zorder=-5)
+
+    axis.set_xlim(xlim)
+    axis.set_ylim(ylim)
+    plt.colorbar(shrink=0.75)
+    plt.title(title)
+    plt.savefig(save_name)
+    plt.close()
+
+
+def calculate_hit_miss(df, obs_column, prob_column, threshold, probability):
+    """
+    Reads dataframe with two columns for obs_elev, and probabilities
+    returns hit/miss/... based on user-defined threshold & probability
+    """
+    hit = len(df[(df[obs_column] >= threshold) & (df[prob_column] >= probability)])
+    miss = len(df[(df[obs_column] >= threshold) & (df[prob_column] < probability)])
+    false_alarm = len(df[(df[obs_column] < threshold) & (df[prob_column] >= probability)])
+    correct_neg = len(df[(df[obs_column] < threshold) & (df[prob_column] < probability)])
+
+    return hit, miss, false_alarm, correct_neg
+
+
+def calculate_POD_FAR(hit, miss, false_alarm, correct_neg):
+    """
+    Reads hit, miss, false_alarm, and correct_neg
+    returns POD and FAR
+    default POD and FAR are np.nan
+    """
+    POD = np.nan
+    FAR = np.nan
+    try:
+        POD = round(hit / (hit + miss), 4)  # Probability of Detection
+    except ZeroDivisionError:
+        pass
+    try:
+        FAR = round(false_alarm / (false_alarm + correct_neg), 4)  # False Alarm Rate
+    except ZeroDivisionError:
+        pass
+    return POD, FAR
+
+
+def main(args):
+    storm = args.storm.capitalize()
+    year = args.year
+    leadtime = args.leadtime
+    obs_df_path = Path(args.obs_df_path)
+    ensemble_dir = Path(args.ensemble_dir)
+
+    output_directory = ensemble_dir / 'analyze/linear_k1_p1_n0.025'
+    prob_nc_path = output_directory / 'probabilities.nc'
+
+    if leadtime == -1:
+        leadtime = 48
+
+    # *.nc file coordinates
+    thresholds_ft = [3, 6, 9]  # in ft
+    thresholds_m = [round(i * 0.3048, 4) for i in thresholds_ft]  # convert to meter
+    sources = ['model', 'surrogate']
+    probabilities = [0.0, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
+
+    # attributes of input files
+    prediction_variable = 'probabilities'
+    obs_attribute = 'Elev_m_xGEOID20b'
+
+    # search criteria
+    max_distance = 1000  # [in meters] to set distance_upper_bound
+    max_neighbors = 10  # to set k
+
+    blank_arr = np.empty((len(thresholds_ft), 1, 1, len(sources), len(probabilities)))
+    blank_arr[:] = np.nan
+
+    hit_arr = blank_arr.copy()
+    miss_arr = blank_arr.copy()
+    false_alarm_arr = blank_arr.copy()
+    correct_neg_arr = blank_arr.copy()
+    POD_arr = blank_arr.copy()
+    FAR_arr = blank_arr.copy()
+
+    # Load obs file, extract storm obs points and coordinates
+    df_obs = pd.read_csv(obs_df_path)
+    Event_name = f'{storm}_{year}'
+    df_obs_storm = df_obs[df_obs.Event == Event_name]
+    obs_coordinates = stack_station_coordinates(
+        df_obs_storm.Longitude.values, df_obs_storm.Latitude.values
+    )
+
+    # Load probabilities.nc file
+    ds_prob = xr.open_dataset(prob_nc_path)
+
+    # gdf_countries = gpd.GeoSeries(
+    #     NaturalEarthFeature(category='physical', scale='10m', name='land',).geometries(),
+    #     crs=4326,
+    # )
+
+    # Loop through thresholds and sources and find corresponding values from probabilities.nc
+    threshold_count = -1
+    for threshold in thresholds_m:
+        threshold_count += 1
+        source_count = -1
+        for source in sources:
+            source_count += 1
+            ds_temp = ds_prob.sel(level=threshold, source=source)
+            tree = create_search_tree(ds_temp.x.values, ds_temp.y.values)
+            dist, indices = tree.query(
+                obs_coordinates, k=max_neighbors, distance_upper_bound=max_distance * 1e-5
+            )  # 0.01 is equivalent to 1000 m
+            prediction_prob = find_nearby_prediction(
+                ds=ds_temp, variable=prediction_variable, indices=indices
+            )
+            df_obs_storm[f'{source}_prob'] = prediction_prob
+
+            # # Plot probabilities at obs. points
+            # plot_probabilities(
+            #     df_obs_storm,
+            #     f'{source}_prob',
+            #     gdf_countries,
+            #     f'Probability of {source} exceeding {thresholds_ft[threshold_count]} ft \n {storm}, {year}, {leadtime}-hr leadtime',
+            #     os.path.join(
+            #         output_directory,
+            #         f'prob_{source}_above_{thresholds_ft[threshold_count]}ft_{storm}_{year}_{leadtime}-hr.png',
+            #     ),
+            # )
+
+            # Loop through probabilities: calculate hit/miss/... & POD/FAR
+            prob_count = -1
+            for prob in probabilities:
+                prob_count += 1
+                hit, miss, false_alarm, correct_neg = calculate_hit_miss(
+                    df_obs_storm, obs_attribute, f'{source}_prob', threshold, prob
+                )
+                hit_arr[threshold_count, 0, 0, source_count, prob_count] = hit
+                miss_arr[threshold_count, 0, 0, source_count, prob_count] = miss
+                false_alarm_arr[threshold_count, 0, 0, source_count, prob_count] = false_alarm
+                correct_neg_arr[threshold_count, 0, 0, source_count, prob_count] = correct_neg
+
+                pod, far = calculate_POD_FAR(hit, miss, false_alarm, correct_neg)
+                POD_arr[threshold_count, 0, 0, source_count, prob_count] = pod
+                FAR_arr[threshold_count, 0, 0, source_count, prob_count] = far
+
+    ds_ROC = xr.Dataset(
+        coords=dict(
+            threshold=thresholds_ft,
+            storm=[storm],
+            leadtime=[leadtime],
+            source=sources,
+            prob=probabilities,
+        ),
+        data_vars=dict(
+            hit=(['threshold', 'storm', 'leadtime', 'source', 'prob'], hit_arr),
+            miss=(['threshold', 'storm', 'leadtime', 'source', 'prob'], miss_arr),
+            false_alarm=(
+                ['threshold', 'storm', 'leadtime', 'source', 'prob'],
+                false_alarm_arr,
+            ),
+            correct_neg=(
+                ['threshold', 'storm', 'leadtime', 'source', 'prob'],
+                correct_neg_arr,
+            ),
+            POD=(['threshold', 'storm', 'leadtime', 'source', 'prob'], POD_arr),
+            FAR=(['threshold', 'storm', 'leadtime', 'source', 'prob'], FAR_arr),
+        ),
+    )
+    ds_ROC.to_netcdf(os.path.join(output_directory, f'{storm}_{year}_{leadtime}hr_POD_FAR.nc'))
+
+    # plot ROC curves
+    marker_list = ['s', 'x']
+    linestyle_list = ['dashed', 'dotted']
+    threshold_count = -1
+    for threshold in thresholds_ft:
+        threshold_count += 1
+        fig = plt.figure()
+        ax = fig.add_subplot(111)
+        plt.axline(
+            (0.0, 0.0), (1.0, 1.0), linestyle='--', color='grey', label='random prediction'
+        )
+        source_count = -1
+        for source in sources:
+            source_count += 1
+            plt.plot(
+                FAR_arr[threshold_count, 0, 0, source_count, :],
+                POD_arr[threshold_count, 0, 0, source_count, :],
+                label=f'{source}',
+                marker=marker_list[source_count],
+                linestyle=linestyle_list[source_count],
+                markersize=5,
+            )
+        plt.legend()
+        plt.xlabel('False Alarm Rate')
+        plt.ylabel('Probability of Detection')
+
+        plt.title(f'{storm}_{year}, {leadtime}-hr leadtime, {threshold} ft threshold')
+        plt.savefig(
+            os.path.join(
+                output_directory,
+                f'ROC_{storm}_{year}_{leadtime}hr_leadtime_{threshold}_ft.png',
+            )
+        )
+        plt.close()
+
+
+def cli():
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument('--storm', help='name of the storm', type=str)
+    parser.add_argument('--year', help='year of the storm', type=int)
+    parser.add_argument('--leadtime', help='OFCL track leadtime hr', type=int)
+    parser.add_argument('--obs_df_path', help='path to NHC obs data', type=str)
+    parser.add_argument('--ensemble-dir', help='path to ensemble.dir', type=str)
+
+    main(parser.parse_args())
+
+
+if __name__ == '__main__':
+    warnings.filterwarnings('ignore')
+    # warnings.filterwarnings("ignore", category=DeprecationWarning)
+    cli()

From d3adee26b76c87fc2dac7f457765a789f848a0bd Mon Sep 17 00:00:00 2001
From: Fariborz Daneshvar
Date: Fri, 30 Aug 2024 12:52:30 -0500
Subject: [PATCH 5/9] delete ROC_single_run.py; storm_roc_curve.py is the
 reformatted version of this file

---
 stormworkflow/post/ROC_single_run.py | 300 ---------------------------
 1 file changed, 300 deletions(-)
 delete mode 100644 stormworkflow/post/ROC_single_run.py

diff --git a/stormworkflow/post/ROC_single_run.py b/stormworkflow/post/ROC_single_run.py
deleted file mode 100644
index 26dbd2c..0000000
--- a/stormworkflow/post/ROC_single_run.py
+++ /dev/null
@@ -1,300 +0,0 @@
-import argparse
-import logging
-import os
-import warnings
-import numpy as np
-import pandas as pd
-import xarray as xr
-import scipy as sp
-import matplotlib.pyplot as plt
-from pathlib import Path
-from cartopy.feature import NaturalEarthFeature
-
-os.environ['USE_PYGEOS'] = '0'
-import geopandas as gpd
-
-pd.options.mode.copy_on_write = True
-
-
-def stack_station_coordinates(x, y):
-    """
-    Create numpy.column_stack based on
-    coordinates of observation points
-    """
-    coord_combined = np.column_stack([x, y])
-    return coord_combined
-
-
-def create_search_tree(longitude, latitude):
-    """
-    Create scipy.spatial.CKDTree based on Lat. and Long.
-    """
-    long_lat = np.column_stack((longitude.T.ravel(), latitude.T.ravel()))
-    tree = sp.spatial.cKDTree(long_lat)
-    return tree
-
-
-def find_nearby_prediction(ds, variable, indices):
-    """
-    Reads netcdf file, target variable, and indices
-    Returns max value among corresponding indices for each point
-    """
-    obs_count = indices.shape[0]  # total number of search/observation points
-    max_prediction_index = len(ds.node.values)  # total number of nodes
-
-    prediction_prob = np.zeros(obs_count)  # assuming all are dry (probability of zero)
-
-    for obs_point in range(obs_count):
-        idx_arr = np.delete(
-            indices[obs_point], np.where(indices[obs_point] == max_prediction_index)[0]
-        )  # len is length of surrogate model array
-        val_arr = ds[variable].values[idx_arr]
-        val_arr = np.nan_to_num(val_arr)  # replace nan with zero (dry node)
-
-        # # Pick the nearest non-zero probability (option #1)
-        # for val in val_arr:
-        #     if val > 0.0:
-        #         prediction_prob[obs_point] = round(val,4) #round to 0.1 mm
-        #         break
-
-        # pick the largest value (option #2)
-        if val_arr.size > 0:
-            prediction_prob[obs_point] = val_arr.max()
-    return prediction_prob
-
-
-def plot_probabilities(df, prob_column, gdf_countries, title, save_name):
-    """
-    plot probabilities of exceeding given threshold at obs. points
-    """
-    figure, axis = plt.subplots(1, 1)
-    figure.set_size_inches(10, 10 / 1.6)
-
-    plt.scatter(x=df.Longitude, y=df.Latitude, vmin=0, vmax=1.0, c=df[prob_column])
-    xlim = axis.get_xlim()
-    ylim = axis.get_ylim()
-
-    gdf_countries.plot(color='lightgrey', ax=axis, zorder=-5)
-
-    axis.set_xlim(xlim)
-    axis.set_ylim(ylim)
-    plt.colorbar(shrink=0.75)
-    plt.title(title)
-    plt.savefig(save_name)
-    plt.close()
-
-
-def calculate_hit_miss(df, obs_column, prob_column, threshold, probability):
-    """
-    Reads dataframe with two columns for obs_elev, and probabilities
-    returns hit/miss/... based on user-defined threshold & probability
-    """
-    hit = len(df[(df[obs_column] >= threshold) & (df[prob_column] >= probability)])
-    miss = len(df[(df[obs_column] >= threshold) & (df[prob_column] < probability)])
-    false_alarm = len(df[(df[obs_column] < threshold) & (df[prob_column] >= probability)])
-    correct_neg = len(df[(df[obs_column] < threshold) & (df[prob_column] < probability)])
-
-    return hit, miss, false_alarm, correct_neg
-
-
-def calculate_POD_FAR(hit, miss, false_alarm, correct_neg):
-    """
-    Reads hit, miss, false_alarm, and correct_neg
-    returns POD and FAR
-    default POD and FAR are np.nan
-    """
-    POD = np.nan
-    FAR = np.nan
-    try:
-        POD = round(hit / (hit + miss), 4)  # Probability of Detection
-    except ZeroDivisionError:
-        pass
-    try:
-        FAR = round(false_alarm / (false_alarm + correct_neg), 4)  # False Alarm Rate
-    except ZeroDivisionError:
-        pass
-    return POD, FAR
-
-
-def main(args):
-    storm_name = args.storm_name.capitalize()
-    storm_year = args.storm_year
-    leadtime = args.leadtime
-    prob_nc_path = Path(args.prob_nc_path)
-    obs_df_path = Path(args.obs_df_path)
-    save_dir = args.save_dir
-
-    # *.nc file coordinates
-    thresholds_ft = [3, 6, 9]  # in ft
-    thresholds_m = [round(i * 0.3048, 4) for i in thresholds_ft]  # convert to meter
-    sources = ['model', 'surrogate']
-    probabilities = [0.0, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
-
-    # attributes of input files
-    prediction_variable = 'probabilities'
-    obs_attribute = 'Elev_m_xGEOID20b'
-
-    # search criteria
-    max_distance = 1000  # [in meters] to set distance_upper_bound
-    max_neighbors = 10  # to set k
-
-    blank_arr = np.empty((len(thresholds_ft), 1, 1, len(sources), len(probabilities)))
-    blank_arr[:] = np.nan
-
-    hit_arr = blank_arr.copy()
-    miss_arr = blank_arr.copy()
-    false_alarm_arr = blank_arr.copy()
-    correct_neg_arr = blank_arr.copy()
-    POD_arr = blank_arr.copy()
-    FAR_arr = blank_arr.copy()
-
-    # Load obs file, extract storm obs points and coordinates
-    df_obs = pd.read_csv(obs_df_path)
-    Event_name = f'{storm_name}_{storm_year}'
-    df_obs_storm = df_obs[df_obs.Event == Event_name]
-    obs_coordinates = stack_station_coordinates(
-        df_obs_storm.Longitude.values, df_obs_storm.Latitude.values
-    )
-
-    # Load probabilities.nc file
-    ds_prob = xr.open_dataset(prob_nc_path)
-
-    gdf_countries = gpd.GeoSeries(
-        NaturalEarthFeature(category='physical', scale='10m', name='land',).geometries(),
-        crs=4326,
-    )
-
-    # Loop through thresholds and sources and find corresponding values from probabilities.nc
-    threshold_count = -1
-    for threshold in thresholds_m:
-        threshold_count += 1
-        source_count = -1
-        for source in sources:
-            source_count += 1
-            ds_temp = ds_prob.sel(level=threshold, source=source)
-            tree = create_search_tree(ds_temp.x.values, ds_temp.y.values)
-            dist, indices = tree.query(
-                obs_coordinates, k=max_neighbors, distance_upper_bound=max_distance * 1e-5
-            )  # 0.01 is equivalent to 1000 m
-            prediction_prob = find_nearby_prediction(
-                ds=ds_temp, variable=prediction_variable, indices=indices
-            )
-            df_obs_storm[f'{source}_prob'] = prediction_prob
-
-            # Plot probabilities at obs. points
-            plot_probabilities(
-                df_obs_storm,
-                f'{source}_prob',
-                gdf_countries,
-                f'Probability of {source} exceeding {thresholds_ft[threshold_count]} ft \n {storm_name}, {storm_year}, {leadtime}-hr leadtime',
-                os.path.join(
-                    save_dir,
-                    f'prob_{source}_above_{thresholds_ft[threshold_count]}ft_{storm_name}_{storm_year}_{leadtime}-hr.png',
-                ),
-            )
-
-            # Loop through probabilities: calculate hit/miss/... & POD/FAR
-            prob_count = -1
-            for prob in probabilities:
-                prob_count += 1
-                hit, miss, false_alarm, correct_neg = calculate_hit_miss(
-                    df_obs_storm, obs_attribute, f'{source}_prob', threshold, prob
-                )
-                hit_arr[threshold_count, 0, 0, source_count, prob_count] = hit
-                miss_arr[threshold_count, 0, 0, source_count, prob_count] = miss
-                false_alarm_arr[threshold_count, 0, 0, source_count, prob_count] = false_alarm
-                correct_neg_arr[threshold_count, 0, 0, source_count, prob_count] = correct_neg
-
-                pod, far = calculate_POD_FAR(hit, miss, false_alarm, correct_neg)
-                POD_arr[threshold_count, 0, 0, source_count, prob_count] = pod
-                FAR_arr[threshold_count, 0, 0, source_count, prob_count] = far
-
-    ds_ROC = xr.Dataset(
-        coords=dict(
-            threshold=thresholds_ft,
-            storm=[storm_name],
-            leadtime=[leadtime],
-            source=sources,
-            prob=probabilities,
-        ),
-        data_vars=dict(
-            hit=(['threshold', 'storm', 'leadtime', 'source', 'prob'], hit_arr),
-            miss=(['threshold', 'storm', 'leadtime', 'source', 'prob'], miss_arr),
-            false_alarm=(
-                ['threshold', 'storm', 'leadtime', 'source', 'prob'],
-                false_alarm_arr,
-            ),
-            correct_neg=(
-                ['threshold', 'storm', 'leadtime', 'source', 'prob'],
-                correct_neg_arr,
-            ),
-            POD=(['threshold', 'storm', 'leadtime', 'source', 'prob'], POD_arr),
-            FAR=(['threshold', 'storm', 'leadtime', 'source', 'prob'], FAR_arr),
-        ),
-    )
-    ds_ROC.to_netcdf(
-        os.path.join(save_dir, f'{storm_name}_{storm_year}_{leadtime}hr_leadtime_POD_FAR.nc')
-    )
-
-    # plot ROC curves
-    marker_list = ['s', 'x']
-    linestyle_list = ['dashed', 'dotted']
-    threshold_count = -1
-    for threshold in thresholds_ft:
-        threshold_count += 1
-        fig = plt.figure()
-        ax = fig.add_subplot(111)
-        plt.axline(
-            (0.0, 0.0), (1.0, 1.0), linestyle='--', color='grey', label='random prediction'
-        )
-        source_count = -1
-        for source in sources:
-            source_count += 1
-            plt.plot(
-                FAR_arr[threshold_count, 0, 0, source_count, :],
-                POD_arr[threshold_count, 0, 0, source_count, :],
-                label=f'{source}',
-                marker=marker_list[source_count],
-                linestyle=linestyle_list[source_count],
-                markersize=5,
-            )
-        plt.legend()
-        plt.xlabel('False Alarm Rate')
-        plt.ylabel('Probability of Detection')
-
-        plt.title(
-            f'{storm_name}_{storm_year}, {leadtime}-hr leadtime, {threshold} ft threshold'
-        )
-        plt.savefig(
-            os.path.join(
-                save_dir, f'ROC_{storm_name}_{leadtime}hr_leadtime_{threshold}_ft.png'
-            )
-        )
-        plt.close()
-
-
-def cli():
-    parser = argparse.ArgumentParser()
-
-    parser.add_argument('--storm_name', help='name of the storm', type=str)
-
-    parser.add_argument('--storm_year', help='year of the storm', type=int)
-
-    parser.add_argument('--leadtime', help='OFCL track leadtime hr', type=int)
-
-    parser.add_argument('--prob_nc_path', help='path to probabilities.nc', type=str)
-
-    parser.add_argument('--obs_df_path', help='Path to observations dataframe', type=str)
-
-    # optional
-    parser.add_argument(
-        '--save_dir', help='directory for saving analysis', default=os.getcwd(), type=str
-    )
-
-    main(parser.parse_args())
-
-
-if __name__ == '__main__':
-    warnings.filterwarnings('ignore')
-    # warnings.filterwarnings("ignore", category=DeprecationWarning)
-    cli()

From d172802c17359a3d1918b2e9bbcaa9d8cdce6292 Mon Sep 17 00:00:00 2001
From: Fariborz Daneshvar
Date: Fri, 30 Aug 2024 13:11:04 -0500
Subject: [PATCH 6/9] add tentative path to NHC_OBS for calculation of POD/FAR
 and plotting ROC curves

---
 stormworkflow/refs/input.yaml | 1 +
 1 file changed, 1 insertion(+)

diff --git a/stormworkflow/refs/input.yaml b/stormworkflow/refs/input.yaml
index 02df434..26d3b00 100644
--- a/stormworkflow/refs/input.yaml
+++ b/stormworkflow/refs/input.yaml
@@ -38,6 +38,7 @@ L_DEM_LO: ""
 L_MESH_HI: ""
 L_MESH_LO: ""
 L_SHP_DIR: ""
+NHC_OBS: ""
 
 TMPDIR: "/tmp"
 PATH_APPEND: ""

From 9ae8e0432a04ba40c8c0723a1a8fdaf69f7defe8 Mon Sep 17 00:00:00 2001
From: Fariborz Daneshvar
Date: Tue, 3 Sep 2024 09:36:37 -0500
Subject: [PATCH 7/9] increase timelimits of mesh and prep to 1 hour

---
 stormworkflow/slurm/mesh.sbatch | 2 +-
 stormworkflow/slurm/prep.sbatch | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/stormworkflow/slurm/mesh.sbatch b/stormworkflow/slurm/mesh.sbatch
index 757b9cd..503381a 100644
--- a/stormworkflow/slurm/mesh.sbatch
+++ b/stormworkflow/slurm/mesh.sbatch
@@ -1,7 +1,7 @@
 #!/bin/bash
 #SBATCH --parsable
 #SBATCH --exclusive
-#SBATCH --time=00:30:00
+#SBATCH --time=01:00:00
 #SBATCH --nodes=1
 
 set -ex
diff --git a/stormworkflow/slurm/prep.sbatch b/stormworkflow/slurm/prep.sbatch
index 8beb298..d5dd455 100644
--- a/stormworkflow/slurm/prep.sbatch
+++ b/stormworkflow/slurm/prep.sbatch
@@ -1,7 +1,7 @@
 #!/bin/bash
 #SBATCH --parsable
 #SBATCH --exclusive
-#SBATCH --time=00:30:00
+#SBATCH --time=01:00:00
 #SBATCH --nodes=1
 
 set -ex

From 3111110a04a91e4a50d3b5e7e0cf9a0f99c3518a Mon Sep 17 00:00:00 2001
From: Fariborz Daneshvar
Date: Wed, 4 Sep 2024 14:05:18 -0500
Subject: [PATCH 8/9] use geodatasets and make spatial plot of probabilities
 at obs. points

---
 stormworkflow/post/storm_roc_curve.py | 26 +++++++++++++++-----------
 1 file changed, 15 insertions(+), 11 deletions(-)

diff --git a/stormworkflow/post/storm_roc_curve.py b/stormworkflow/post/storm_roc_curve.py
index 4758374..fe7bde5 100644
--- a/stormworkflow/post/storm_roc_curve.py
+++ b/stormworkflow/post/storm_roc_curve.py
@@ -10,6 +10,8 @@
 from pathlib import Path
 from cartopy.feature import NaturalEarthFeature
 
+import geodatasets
+
 os.environ['USE_PYGEOS'] = '0'
 import geopandas as gpd
 
 pd.options.mode.copy_on_write = True
@@ -164,6 +166,8 @@ def main(args):
     # Load probabilities.nc file
     ds_prob = xr.open_dataset(prob_nc_path)
 
+    gdf_countries = gpd.read_file(geodatasets.get_path('naturalearth land'))
+
     # gdf_countries = gpd.GeoSeries(
     #     NaturalEarthFeature(category='physical', scale='10m', name='land',).geometries(),
     #     crs=4326,
     # )
@@ -186,17 +190,17 @@ def main(args):
             df_obs_storm[f'{source}_prob'] = prediction_prob
 
-            # # Plot probabilities at obs. points
-            # plot_probabilities(
-            #     df_obs_storm,
-            #     f'{source}_prob',
-            #     gdf_countries,
-            #     f'Probability of {source} exceeding {thresholds_ft[threshold_count]} ft \n {storm}, {year}, {leadtime}-hr leadtime',
-            #     os.path.join(
-            #         output_directory,
-            #         f'prob_{source}_above_{thresholds_ft[threshold_count]}ft_{storm}_{year}_{leadtime}-hr.png',
-            #     ),
-            # )
+            # Plot probabilities at obs. points
+            plot_probabilities(
+                df_obs_storm,
+                f'{source}_prob',
+                gdf_countries,
+                f'Probability of {source} exceeding {thresholds_ft[threshold_count]} ft \n {storm}, {year}, {leadtime}-hr leadtime',
+                os.path.join(
+                    output_directory,
+                    f'prob_{source}_above_{thresholds_ft[threshold_count]}ft_{storm}_{year}_{leadtime}-hr.png',
+                ),
+            )
 
             # Loop through probabilities: calculate hit/miss/... & POD/FAR
             prob_count = -1

From 21e8f9b8731c7be47be14c7db61db07e3d5719bd Mon Sep 17 00:00:00 2001
From: SorooshMani-NOAA
Date: Thu, 5 Sep 2024 16:42:25 -0400
Subject: [PATCH 9/9] Update and improve version check

---
 stormworkflow/main.py             | 51 +++++++++++++++++++------------
 stormworkflow/refs/input.yaml     |  2 +-
 tests/conftest.py                 |  6 +++-
 tests/data/refs/input_v0.0.4.yaml | 49 +++++++++++++++++++++++++++++
 tests/test_input_version.py       |  5 +++
 5 files changed, 91 insertions(+), 22 deletions(-)
 create mode 100644 tests/data/refs/input_v0.0.4.yaml

diff --git a/stormworkflow/main.py b/stormworkflow/main.py
index 7889365..244a3e0 100644
--- a/stormworkflow/main.py
+++ b/stormworkflow/main.py
@@ -18,17 +18,32 @@
 
 _logger = logging.getLogger(__file__)
 
-CUR_INPUT_VER = Version('0.0.3')
+CUR_INPUT_VER = Version('0.0.4')
+VER_UPDATE_FUNCS = []
 
 
-def _handle_input_v0_0_1_to_v0_0_2(inout_conf):
+def _input_version(prev, curr):
+    def decorator(handler):
+        def wrapper(inout_conf):
+            ver = Version(inout_conf['input_version'])
 
-    ver = Version(inout_conf['input_version'])
+            # Only update config if specified version matches the
+            # assumed one
+            if ver != Version(prev):
+                return ver
 
-    # Only update config if specified version matches the assumed one
-    if ver != Version('0.0.1'):
-        return ver
+            # TODO: Need return values?
+            handler(inout_conf)
+            return Version(curr)
 
+        global VER_UPDATE_FUNCS
+        VER_UPDATE_FUNCS.append(wrapper)
+        return wrapper
+
+    return decorator
+
+
+@_input_version('0.0.1', '0.0.2')
+def _handle_input_v0_0_1_to_v0_0_2(inout_conf):
 
     _logger.info(
         "Adding perturbation variables for persistent RMW perturbation"
@@ -40,24 +55,23 @@ def _handle_input_v0_0_1_to_v0_0_2(inout_conf):
         'max_sustained_wind_speed',
     ]
 
-    return Version('0.0.2')
-
 
+@_input_version('0.0.2', '0.0.3')
 def _handle_input_v0_0_2_to_v0_0_3(inout_conf):
 
-    ver = Version(inout_conf['input_version'])
-
-    # Only update config if specified version matches the assumed one
-    if ver != Version('0.0.2'):
-        return ver
-
     _logger.info(
         "Adding RMW fill method default to persistent"
    )
 
     inout_conf['rmw_fill_method'] = 'persistent'
 
-    return Version('0.0.3')
+
+@_input_version('0.0.3', '0.0.4')
+def _handle_input_v0_0_3_to_v0_0_4(inout_conf):
+
+    _logger.info(
+        "Path to observations"
+    )
+
+    inout_conf['NHC_OBS'] = ''
 
 
 def handle_input_version(inout_conf):
@@ -77,10 +91,7 @@ def handle_input_version(inout_conf):
             f"Input version not supported! Max version supported is {CUR_INPUT_VER}"
         )
 
-    for fn in [
-        _handle_input_v0_0_1_to_v0_0_2,
-        _handle_input_v0_0_2_to_v0_0_3,
-    ]:
+    for fn in VER_UPDATE_FUNCS:
         ver = fn(inout_conf)
 
     inout_conf['input_version'] = str(ver)
diff --git a/stormworkflow/refs/input.yaml b/stormworkflow/refs/input.yaml
index 26d3b00..30aefa2 100644
--- a/stormworkflow/refs/input.yaml
+++ b/stormworkflow/refs/input.yaml
@@ -1,5 +1,5 @@
 ---
-input_version: 0.0.3
+input_version: 0.0.4
 
 storm: "florence"
 year: 2018
diff --git a/tests/conftest.py b/tests/conftest.py
index edaf46c..4d3c7e1 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -6,10 +6,11 @@
 refs = files('stormworkflow.refs')
-test_refs = files('tests.data.refs')
+test_refs = files('data.refs')
 
 input_v0_0_1 = test_refs.joinpath('input_v0.0.1.yaml')
 input_v0_0_2 = test_refs.joinpath('input_v0.0.2.yaml')
 input_v0_0_3 = test_refs.joinpath('input_v0.0.3.yaml')
+input_v0_0_4 = test_refs.joinpath('input_v0.0.4.yaml')
 input_latest = refs.joinpath('input.yaml')
 
@@ -33,6 +34,9 @@ def conf_v0_0_2():
 def conf_v0_0_3():
     return read_conf(input_v0_0_3)
 
+@pytest.fixture
+def conf_v0_0_4():
+    return read_conf(input_v0_0_4)
 
 @pytest.fixture
 def conf_latest():
diff --git a/tests/data/refs/input_v0.0.4.yaml b/tests/data/refs/input_v0.0.4.yaml
new file mode 100644
index 0000000..30aefa2
--- /dev/null
+++ b/tests/data/refs/input_v0.0.4.yaml
@@ -0,0 +1,49 @@
+---
+input_version: 0.0.4
+
+storm: "florence"
+year: 2018
+suffix: ""
+subset_mesh: 1
+hr_prelandfall: -1
+past_forecast: 1
+hydrology: 0
+use_wwm: 0
+pahm_model: "gahm"
+num_perturb: 2
+sample_rule: "korobov"
+perturb_vars:
+  - "cross_track"
+  - "along_track"
+# - "radius_of_maximum_winds"
+  - "radius_of_maximum_winds_persistent"
+  - "max_sustained_wind_speed"
+rmw_fill_method: "persistent"
+
+spinup_exec: "pschism_PAHM_TVD-VL"
+hotstart_exec: "pschism_PAHM_TVD-VL"
+
+hpc_solver_nnodes: 3
+hpc_solver_ntasks: 108
+hpc_account: ""
+hpc_partition: ""
+
+RUN_OUT: ""
+L_NWM_DATASET: ""
+L_TPXO_DATASET: ""
+L_LEADTIMES_DATASET: ""
+L_TRACK_DIR: ""
+L_DEM_HI: ""
+L_DEM_LO: ""
+L_MESH_HI: ""
+L_MESH_LO: ""
+L_SHP_DIR: ""
+NHC_OBS: ""
+
+TMPDIR: "/tmp"
+PATH_APPEND: ""
+
+L_SOLVE_MODULES:
+  - "intel/2022.1.2"
+  - "impi/2022.1.2"
+  - "netcdf"
diff --git a/tests/test_input_version.py b/tests/test_input_version.py
index 5e2ae38..126b981 100644
--- a/tests/test_input_version.py
+++ b/tests/test_input_version.py
@@ -45,3 +45,8 @@ def test_v0_0_2_to_latest(conf_v0_0_2, conf_latest):
 def test_v0_0_3_to_latest(conf_v0_0_3, conf_latest):
     handle_input_version(conf_v0_0_3)
     assert conf_latest == conf_v0_0_3
+
+
+def test_v0_0_4_to_latest(conf_v0_0_4, conf_latest):
+    handle_input_version(conf_v0_0_4)
+    assert conf_latest == conf_v0_0_4
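
Note on PATCH 9/9: the `_input_version(prev, curr)` decorator registers each migration handler in `VER_UPDATE_FUNCS`, so `handle_input_version` can chain the upgrades 0.0.1 -> 0.0.2 -> 0.0.3 -> 0.0.4 without a hand-maintained list. Below is a minimal, self-contained sketch of that pattern; the names mirror stormworkflow/main.py, but the `_example_add_nhc_obs` handler and the final assertion are illustrative assumptions, not part of the patch.

# Standalone sketch of the version-migration chain from PATCH 9/9.
from packaging.version import Version

CUR_INPUT_VER = Version('0.0.4')
VER_UPDATE_FUNCS = []


def _input_version(prev, curr):
    # Register a handler that upgrades a config dict from `prev` to `curr`.
    def decorator(handler):
        def wrapper(inout_conf):
            ver = Version(inout_conf['input_version'])
            if ver != Version(prev):
                # Not this handler's step; leave the config untouched.
                return ver
            handler(inout_conf)  # mutate the config in place
            return Version(curr)

        VER_UPDATE_FUNCS.append(wrapper)
        return wrapper

    return decorator


@_input_version('0.0.3', '0.0.4')
def _example_add_nhc_obs(inout_conf):  # hypothetical handler for illustration
    inout_conf['NHC_OBS'] = ''


conf = {'input_version': '0.0.3'}
for fn in VER_UPDATE_FUNCS:  # handlers run in registration order
    conf['input_version'] = str(fn(conf))

assert conf == {'input_version': '0.0.4', 'NHC_OBS': ''}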