Skip to content

Commit

Permalink
Merge pull request #76 from oceanmodeling/feature/ROC_curve
Browse files Browse the repository at this point in the history
Feature/roc curve
  • Loading branch information
SorooshMani-NOAA authored Sep 5, 2024
2 parents 7045f62 + 21e8f9b commit 7974f4f
Show file tree
Hide file tree
Showing 10 changed files with 132 additions and 58 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -99,4 +99,4 @@ download_data = "stormworkflow.prep.download_data:cli"
setup_ensemble = "stormworkflow.prep.setup_ensemble:cli"
combine_ensemble = "stormworkflow.post.combine_ensemble:cli"
analyze_ensemble = "stormworkflow.post.analyze_ensemble:cli"
storm_roc_curve = "stormworkflow.post.ROC_single_run:cli"
storm_roc_curve = "stormworkflow.post.storm_roc_curve:cli"
51 changes: 31 additions & 20 deletions stormworkflow/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,32 @@

_logger = logging.getLogger(__file__)

CUR_INPUT_VER = Version('0.0.3')
CUR_INPUT_VER = Version('0.0.4')
VER_UPDATE_FUNCS = []


def _handle_input_v0_0_1_to_v0_0_2(inout_conf):
def _input_version(prev, curr):
    """Build a decorator that registers a config-upgrade step.

    The decorated ``handler`` receives the configuration dictionary and
    mutates it in place. The wrapper that replaces it only runs the handler
    when the configuration's ``input_version`` equals ``prev``.

    Wrappers are appended to ``VER_UPDATE_FUNCS`` at decoration time, so the
    list order is the order in which the handlers are defined in the module.

    Parameters
    ----------
    prev : str
        Version string the handler upgrades from.
    curr : str
        Version string the configuration is at after the handler ran.

    Returns
    -------
    callable
        Decorator turning ``handler(inout_conf)`` into
        ``wrapper(inout_conf) -> Version``.
    """
    # Local import: the module's import block is not part of this view.
    import functools

    def decorator(handler):
        # Keep the handler's name/docstring on the registered wrapper so
        # entries of VER_UPDATE_FUNCS stay identifiable when debugging.
        @functools.wraps(handler)
        def wrapper(inout_conf):
            ver = Version(inout_conf['input_version'])

            # Only update config if specified version matches the
            # assumed one
            if ver != Version(prev):
                return ver

            # TODO: Need return values?
            # The handler's return value is ignored; it is expected to
            # modify inout_conf in place.
            handler(inout_conf)

            return Version(curr)

        # append() mutates the module-level list in place, so no ``global``
        # declaration is required here.
        VER_UPDATE_FUNCS.append(wrapper)
        return wrapper

    return decorator


@_input_version('0.0.1', '0.0.2')
def _handle_input_v0_0_1_to_v0_0_2(inout_conf):

_logger.info(
"Adding perturbation variables for persistent RMW perturbation"
Expand All @@ -40,24 +55,23 @@ def _handle_input_v0_0_1_to_v0_0_2(inout_conf):
'max_sustained_wind_speed',
]

return Version('0.0.2')


@_input_version('0.0.2', '0.0.3')
def _handle_input_v0_0_2_to_v0_0_3(inout_conf):
    """Upgrade a v0.0.2 configuration to v0.0.3.

    Sets ``rmw_fill_method`` to ``'persistent'`` in ``inout_conf`` (in place).
    The version check and the resulting version number are supplied by the
    ``_input_version`` decorator.
    """
    _logger.info("Adding RMW fill method default to persistent")
    inout_conf['rmw_fill_method'] = 'persistent'

@_input_version('0.0.3', '0.0.4')
def _handle_input_v0_0_3_to_v0_0_4(inout_conf):
    """Upgrade a v0.0.3 configuration to v0.0.4.

    Adds the ``NHC_OBS`` entry (path to the observations file) to
    ``inout_conf`` in place, defaulting to an empty string. The version check
    and the resulting version number are supplied by the ``_input_version``
    decorator.
    """
    _logger.info(
        "Adding NHC_OBS (path to observations) with empty default"
    )
    # setdefault keeps a path the user may already have put in the file
    # instead of silently resetting it to ''.
    inout_conf.setdefault('NHC_OBS', '')


def handle_input_version(inout_conf):
Expand All @@ -77,10 +91,7 @@ def handle_input_version(inout_conf):
f"Input version not supported! Max version supported is {CUR_INPUT_VER}"
)

for fn in [
_handle_input_v0_0_1_to_v0_0_2,
_handle_input_v0_0_2_to_v0_0_3,
]:
for fn in VER_UPDATE_FUNCS:
ver = fn(inout_conf)
inout_conf['input_version'] = str(ver)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from pathlib import Path
from cartopy.feature import NaturalEarthFeature

import geodatasets

os.environ['USE_PYGEOS'] = '0'
import geopandas as gpd

Expand Down Expand Up @@ -117,12 +119,17 @@ def calculate_POD_FAR(hit, miss, false_alarm, correct_neg):


def main(args):
storm_name = args.storm_name.capitalize()
storm_year = args.storm_year
storm = args.storm.capitalize()
year = args.year
leadtime = args.leadtime
prob_nc_path = Path(args.prob_nc_path)
obs_df_path = Path(args.obs_df_path)
save_dir = args.save_dir
ensemble_dir = Path(args.ensemble_dir)

output_directory = ensemble_dir / 'analyze/linear_k1_p1_n0.025'
prob_nc_path = output_directory / 'probabilities.nc'

if leadtime == -1:
leadtime = 48

# *.nc file coordinates
thresholds_ft = [3, 6, 9] # in ft
Expand Down Expand Up @@ -150,7 +157,7 @@ def main(args):

# Load obs file, extract storm obs points and coordinates
df_obs = pd.read_csv(obs_df_path)
Event_name = f'{storm_name}_{storm_year}'
Event_name = f'{storm}_{year}'
df_obs_storm = df_obs[df_obs.Event == Event_name]
obs_coordinates = stack_station_coordinates(
df_obs_storm.Longitude.values, df_obs_storm.Latitude.values
Expand All @@ -159,10 +166,12 @@ def main(args):
# Load probabilities.nc file
ds_prob = xr.open_dataset(prob_nc_path)

gdf_countries = gpd.GeoSeries(
NaturalEarthFeature(category='physical', scale='10m', name='land',).geometries(),
crs=4326,
)
gdf_countries = gpd.read_file(geodatasets.get_path('naturalearth land'))

# gdf_countries = gpd.GeoSeries(
# NaturalEarthFeature(category='physical', scale='10m', name='land',).geometries(),
# crs=4326,
# )

# Loop through thresholds and sources and find corresponding values from probabilities.nc
threshold_count = -1
Expand All @@ -186,10 +195,10 @@ def main(args):
df_obs_storm,
f'{source}_prob',
gdf_countries,
f'Probability of {source} exceeding {thresholds_ft[threshold_count]} ft \n {storm_name}, {storm_year}, {leadtime}-hr leadtime',
f'Probability of {source} exceeding {thresholds_ft[threshold_count]} ft \n {storm}, {year}, {leadtime}-hr leadtime',
os.path.join(
save_dir,
f'prob_{source}_above_{thresholds_ft[threshold_count]}ft_{storm_name}_{storm_year}_{leadtime}-hr.png',
output_directory,
f'prob_{source}_above_{thresholds_ft[threshold_count]}ft_{storm}_{year}_{leadtime}-hr.png',
),
)

Expand All @@ -212,7 +221,7 @@ def main(args):
ds_ROC = xr.Dataset(
coords=dict(
threshold=thresholds_ft,
storm=[storm_name],
storm=[storm],
leadtime=[leadtime],
source=sources,
prob=probabilities,
Expand All @@ -232,9 +241,7 @@ def main(args):
FAR=(['threshold', 'storm', 'leadtime', 'source', 'prob'], FAR_arr),
),
)
ds_ROC.to_netcdf(
os.path.join(save_dir, f'{storm_name}_{storm_year}_{leadtime}hr_leadtime_POD_FAR.nc')
)
ds_ROC.to_netcdf(os.path.join(output_directory, f'{storm}_{year}_{leadtime}hr_POD_FAR.nc'))

# plot ROC curves
marker_list = ['s', 'x']
Expand Down Expand Up @@ -262,12 +269,11 @@ def main(args):
plt.xlabel('False Alarm Rate')
plt.ylabel('Probability of Detection')

plt.title(
f'{storm_name}_{storm_year}, {leadtime}-hr leadtime, {threshold} ft threshold'
)
plt.title(f'{storm}_{year}, {leadtime}-hr leadtime, {threshold} ft threshold')
plt.savefig(
os.path.join(
save_dir, f'ROC_{storm_name}_{leadtime}hr_leadtime_{threshold}_ft.png'
output_directory,
f'ROC_{storm}_{year}_{leadtime}hr_leadtime_{threshold}_ft.png',
)
)
plt.close()
Expand All @@ -276,20 +282,11 @@ def main(args):
def cli():
    """Command-line entry point: parse the arguments and run ``main``.

    ``--storm``, ``--year``, ``--obs_df_path`` and ``--ensemble-dir`` are
    mandatory; without them ``main`` would receive ``None`` and fail on
    ``args.storm.capitalize()`` / ``Path(...)`` with an unhelpful traceback.
    ``--leadtime`` defaults to -1, the value ``main`` replaces with 48.

    Both spellings of the path options (``_`` and ``-``) are accepted; the
    parsed attribute names (``obs_df_path``, ``ensemble_dir``) are unchanged.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('--storm', help='name of the storm', type=str, required=True)
    parser.add_argument('--year', help='year of the storm', type=int, required=True)
    parser.add_argument(
        '--leadtime', help='OFCL track leadtime hr', type=int, default=-1
    )
    # dest is derived from the first long option, so the aliases below do not
    # change the attribute names read by main().
    parser.add_argument(
        '--obs_df_path',
        '--obs-df-path',
        help='path to NHC obs data',
        type=str,
        required=True,
    )
    parser.add_argument(
        '--ensemble-dir',
        '--ensemble_dir',
        help='path to ensemble.dir',
        type=str,
        required=True,
    )

    main(parser.parse_args())

Expand Down
3 changes: 2 additions & 1 deletion stormworkflow/refs/input.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
---
input_version: 0.0.3
input_version: 0.0.4

storm: "florence"
year: 2018
Expand Down Expand Up @@ -38,6 +38,7 @@ L_DEM_LO: ""
L_MESH_HI: ""
L_MESH_LO: ""
L_SHP_DIR: ""
NHC_OBS: ""

TMPDIR: "/tmp"
PATH_APPEND: ""
Expand Down
2 changes: 1 addition & 1 deletion stormworkflow/slurm/mesh.sbatch
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/bin/bash
#SBATCH --parsable
#SBATCH --exclusive
#SBATCH --time=00:30:00
#SBATCH --time=01:00:00
#SBATCH --nodes=1

set -ex
Expand Down
7 changes: 7 additions & 0 deletions stormworkflow/slurm/post.sbatch
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,10 @@ combine_ensemble \
analyze_ensemble \
--ensemble-dir $ENSEMBLE_DIR \
--tracks-dir $ENSEMBLE_DIR/track_files

# The ROC analysis needs an observations file. With NHC_OBS empty (the value
# shipped in the reference input) an unquoted expansion would vanish from the
# command line and argparse would abort, so skip the step in that case.
if [ -n "${NHC_OBS:-}" ]; then
    storm_roc_curve \
        --storm "${storm}" \
        --year "${year}" \
        --leadtime "${hr_prelandfall}" \
        --obs_df_path "${NHC_OBS}" \
        --ensemble-dir "$ENSEMBLE_DIR"
else
    echo "NHC_OBS is not set; skipping storm_roc_curve"
fi
2 changes: 1 addition & 1 deletion stormworkflow/slurm/prep.sbatch
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/bin/bash
#SBATCH --parsable
#SBATCH --exclusive
#SBATCH --time=00:30:00
#SBATCH --time=01:00:00
#SBATCH --nodes=1

set -ex
Expand Down
6 changes: 5 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@


refs = files('stormworkflow.refs')
test_refs = files('tests.data.refs')
test_refs = files('data.refs')
input_v0_0_1 = test_refs.joinpath('input_v0.0.1.yaml')
input_v0_0_2 = test_refs.joinpath('input_v0.0.2.yaml')
input_v0_0_3 = test_refs.joinpath('input_v0.0.3.yaml')
input_v0_0_4 = test_refs.joinpath('input_v0.0.4.yaml')
input_latest = refs.joinpath('input.yaml')


Expand All @@ -33,6 +34,9 @@ def conf_v0_0_2():
def conf_v0_0_3():
return read_conf(input_v0_0_3)

@pytest.fixture
def conf_v0_0_4():
    """Configuration loaded from the v0.0.4 reference input file."""
    conf = read_conf(input_v0_0_4)
    return conf

@pytest.fixture
def conf_latest():
Expand Down
49 changes: 49 additions & 0 deletions tests/data/refs/input_v0.0.4.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
---
input_version: 0.0.4

storm: "florence"
year: 2018
suffix: ""
subset_mesh: 1
hr_prelandfall: -1
past_forecast: 1
hydrology: 0
use_wwm: 0
pahm_model: "gahm"
num_perturb: 2
sample_rule: "korobov"
perturb_vars:
- "cross_track"
- "along_track"
# - "radius_of_maximum_winds"
- "radius_of_maximum_winds_persistent"
- "max_sustained_wind_speed"
rmw_fill_method: "persistent"

spinup_exec: "pschism_PAHM_TVD-VL"
hotstart_exec: "pschism_PAHM_TVD-VL"

hpc_solver_nnodes: 3
hpc_solver_ntasks: 108
hpc_account: ""
hpc_partition: ""

RUN_OUT: ""
L_NWM_DATASET: ""
L_TPXO_DATASET: ""
L_LEADTIMES_DATASET: ""
L_TRACK_DIR: ""
L_DEM_HI: ""
L_DEM_LO: ""
L_MESH_HI: ""
L_MESH_LO: ""
L_SHP_DIR: ""
NHC_OBS: ""

TMPDIR: "/tmp"
PATH_APPEND: ""

L_SOLVE_MODULES:
- "intel/2022.1.2"
- "impi/2022.1.2"
- "netcdf"
5 changes: 5 additions & 0 deletions tests/test_input_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,8 @@ def test_v0_0_2_to_latest(conf_v0_0_2, conf_latest):
def test_v0_0_3_to_latest(conf_v0_0_3, conf_latest):
handle_input_version(conf_v0_0_3)
assert conf_latest == conf_v0_0_3


def test_v0_0_4_to_latest(conf_v0_0_4, conf_latest):
    """A v0.0.4 config must equal the latest reference input after upgrading."""
    # handle_input_version updates the fixture dictionary in place
    handle_input_version(conf_v0_0_4)
    assert conf_v0_0_4 == conf_latest

0 comments on commit 7974f4f

Please sign in to comment.