From 30b29bd2f509d5ee957e48b558bd09f80f6b2771 Mon Sep 17 00:00:00 2001 From: Henrik Andersson Date: Fri, 26 Jan 2024 14:31:17 +0100 Subject: [PATCH 1/3] Accept custom metric funs without registering them --- modelskill/metrics.py | 19 +++++++++++++++---- tests/test_metrics.py | 15 +++++++++++++++ tests/test_multimodelcompare.py | 9 +++------ 3 files changed, 33 insertions(+), 10 deletions(-) diff --git a/modelskill/metrics.py b/modelskill/metrics.py index a053b8e0a..189ac1fe6 100644 --- a/modelskill/metrics.py +++ b/modelskill/metrics.py @@ -65,6 +65,7 @@ 0.39614855570839064 """ from __future__ import annotations +import inspect import sys import warnings @@ -1199,11 +1200,21 @@ def _parse_metric( elif isinstance(metric, Iterable): metrics = list(metric) - for metric in metrics: - if not isinstance(metric, str) and not callable(metric): - raise TypeError(f"metric {metric} must be a string or callable") + parsed_metrics = [] + + for m in metrics: + if isinstance(m, str): + parsed_metrics.append(get_metric(m)) + elif callable(m): + if len(inspect.signature(m).parameters) < 2: + raise ValueError( + "Metrics must have at least two arguments (obs, model)" + ) + parsed_metrics.append(m) + else: + raise TypeError(f"metric {m} must be a string or callable") - return [get_metric(m) for m in metrics] + return parsed_metrics __all__ = list(defined_metrics) diff --git a/tests/test_metrics.py b/tests/test_metrics.py index df4221177..62e8e5b5e 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -241,3 +241,18 @@ def test_add_metric_is_not_a_valid_metric(): def test_get_metric(): rmse = mtr.get_metric("rmse") assert isinstance(rmse, Callable) + + +def test_parse_metric_custom_fun(): + def my_metric(obs, model): + return 1.0 + + assert mtr._parse_metric(my_metric) == [my_metric] + + +def test_parse_bad_metric(): + def not_a_metric(obs): + return 1.0 + + with pytest.raises(ValueError): + mtr._parse_metric(not_a_metric) diff --git a/tests/test_multimodelcompare.py b/tests/test_multimodelcompare.py index 56893054f..2cf4f6473 100644 --- a/tests/test_multimodelcompare.py +++ b/tests/test_multimodelcompare.py @@ -361,16 +361,13 @@ def test_custom_metric_skilltable_mm_scatter(cc): cc.skill(metrics=["cm_1"]) assert sk["cm_1"] is not None - # using a non-registred metric raises an error, even though it is a defined function, but not registered + # using a non-registred metric raises an error, since it cannot be found in the registry with pytest.raises(ValueError) as e_info: cc.skill(metrics=["cm_3"]) - assert "add_metric" in str(e_info.value) - with pytest.raises(ValueError) as e_info: - cc.skill(metrics=[cm_3]) - - assert "add_metric" in str(e_info.value) + # using it as a function directly is ok + cc.skill(metrics=[cm_3]) def test_mm_kde(cc): From 0e30b68490dd40cc6b9de610e7b18607dfeb2ad2 Mon Sep 17 00:00:00 2001 From: Henrik Andersson Date: Mon, 29 Jan 2024 16:07:59 +0100 Subject: [PATCH 2/3] Update metrics notebook --- notebooks/Metrics_custom_metric.ipynb | 301 +++++++++++++------------- tests/testdata/nu_plugin_plot | 1 + 2 files changed, 147 insertions(+), 155 deletions(-) create mode 160000 tests/testdata/nu_plugin_plot diff --git a/notebooks/Metrics_custom_metric.ipynb b/notebooks/Metrics_custom_metric.ipynb index 9e6545b05..b027bbdea 100644 --- a/notebooks/Metrics_custom_metric.ipynb +++ b/notebooks/Metrics_custom_metric.ipynb @@ -89,118 +89,95 @@ { "data": { "text/html": [ - "\n", - "\n", + "
\n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - "
 nbiasrmseurmsemaeccsir2
nbiasrmseurmsemaeccsir2
observation        observation
EPL67-0.070.220.210.190.970.080.93HKNA386-0.3153800.4473110.3172100.3413440.9683230.1021220.847042
HKNA386-0.190.350.290.250.970.090.91EPL67-0.0775200.2279270.2143390.1926890.9694540.0828660.929960
c2113-0.000.350.350.290.970.130.90c2113-0.0047010.3524700.3524390.2947580.9750500.1280100.899121
\n" + "\n", + "" ], "text/plain": [ - "" + " n bias rmse urmse mae cc si \\\n", + "observation \n", + "HKNA 386 -0.315380 0.447311 0.317210 0.341344 0.968323 0.102122 \n", + "EPL 67 -0.077520 0.227927 0.214339 0.192689 0.969454 0.082866 \n", + "c2 113 -0.004701 0.352470 0.352439 0.294758 0.975050 0.128010 \n", + "\n", + " r2 \n", + "observation \n", + "HKNA 0.847042 \n", + "EPL 0.929960 \n", + "c2 0.899121 " ] }, "execution_count": 4, @@ -209,7 +186,7 @@ } ], "source": [ - "cc.skill().style(precision=2)" + "cc.skill()" ] }, { @@ -257,19 +234,19 @@ " \n", " \n", " \n", - " EPL\n", - " 67\n", - " 0.188513\n", - " \n", - " \n", " HKNA\n", " 386\n", - " 0.251839\n", + " 0.341344\n", + " \n", + " \n", + " EPL\n", + " 67\n", + " 0.192689\n", " \n", " \n", " c2\n", " 113\n", - " 0.294585\n", + " 0.294758\n", " \n", " \n", "\n", @@ -278,9 +255,9 @@ "text/plain": [ " n mean_absolute_error\n", "observation \n", - "EPL 67 0.188513\n", - "HKNA 386 0.251839\n", - "c2 113 0.294585" + "HKNA 386 0.341344\n", + "EPL 67 0.192689\n", + "c2 113 0.294758" ] }, "execution_count": 5, @@ -308,47 +285,64 @@ { "data": { "text/html": [ - "\n", - "\n", + "
\n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - "
 nhit_ratio_05_pcthit_ratio_01_pct
nhit_ratio_05_pcthit_ratio_01_pct
observation   observation
EPL679927HKNA38680.05181317.098446
HKNA3868730EPL6798.50746328.358209
c21138617c211385.84070817.699115
\n" + "\n", + "" ], "text/plain": [ - "" + " n hit_ratio_05_pct hit_ratio_01_pct\n", + "observation \n", + "HKNA 386 80.051813 17.098446\n", + "EPL 67 98.507463 28.358209\n", + "c2 113 85.840708 17.699115" ] }, "execution_count": 6, @@ -367,10 +361,7 @@ " return hit_ratio(obs, model, 0.1) * 100\n", "\n", "\n", - "mtr.add_metric(hit_ratio_05_pct)\n", - "mtr.add_metric(hit_ratio_01_pct)\n", - "\n", - "cc.skill(metrics=[hit_ratio_05_pct, hit_ratio_01_pct]).style(precision=0)" + "cc.skill(metrics=[hit_ratio_05_pct, hit_ratio_01_pct])" ] }, { @@ -378,7 +369,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "And you are of course always free to specify your own special metric." + "And you are of course always free to specify your own special metric or import metrics from other libraries, e.g. scikit-learn." ] }, { @@ -408,7 +399,7 @@ " \n", " \n", " n\n", - " my_special_metric\n", + " mcae\n", " \n", " \n", " observation\n", @@ -418,30 +409,30 @@ " \n", " \n", " \n", - " EPL\n", - " 67\n", - " 0.127555\n", - " \n", - " \n", " HKNA\n", " 386\n", - " 0.223049\n", + " 0.328362\n", + " \n", + " \n", + " EPL\n", + " 67\n", + " 0.135104\n", " \n", " \n", " c2\n", " 113\n", - " 0.147897\n", + " 0.149729\n", " \n", " \n", "\n", "" ], "text/plain": [ - " n my_special_metric\n", - "observation \n", - "EPL 67 0.127555\n", - "HKNA 386 0.223049\n", - "c2 113 0.147897" + " n mcae\n", + "observation \n", + "HKNA 386 0.328362\n", + "EPL 67 0.135104\n", + "c2 113 0.149729" ] }, "execution_count": 7, @@ -450,7 +441,7 @@ } ], "source": [ - "def my_special_metric(obs, model):\n", + "def my_special_metric_with_long_descriptive_name(obs, model):\n", "\n", " res = obs - model\n", "\n", @@ -458,10 +449,10 @@ "\n", " return np.mean(np.abs(res_clipped))\n", "\n", + "# short alias to avoid long column names in output\n", + "def mcae(obs, model): return my_special_metric_with_long_descriptive_name(obs, model)\n", "\n", - "mtr.add_metric(my_special_metric)\n", - "\n", - "cc.skill(metrics=my_special_metric)" + "cc.skill(metrics=mcae)" ] } ], @@ -483,7 +474,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.1" + "version": "3.10.12" } }, "nbformat": 4, diff --git a/tests/testdata/nu_plugin_plot b/tests/testdata/nu_plugin_plot new file mode 160000 index 000000000..3e633b4fc --- /dev/null +++ b/tests/testdata/nu_plugin_plot @@ -0,0 +1 @@ +Subproject commit 3e633b4fc74e10a2b992de8d38130d8e7094a4d8 From 1637b6b04eed6854a3584e4cb8b6ff3f962312c9 Mon Sep 17 00:00:00 2001 From: Henrik Andersson Date: Tue, 30 Jan 2024 09:51:23 +0100 Subject: [PATCH 3/3] Not relevant --- tests/testdata/nu_plugin_plot | 1 - 1 file changed, 1 deletion(-) delete mode 160000 tests/testdata/nu_plugin_plot diff --git a/tests/testdata/nu_plugin_plot b/tests/testdata/nu_plugin_plot deleted file mode 160000 index 3e633b4fc..000000000 --- a/tests/testdata/nu_plugin_plot +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 3e633b4fc74e10a2b992de8d38130d8e7094a4d8