diff --git a/modelskill/comparison/_comparison.py b/modelskill/comparison/_comparison.py index bbd9b0b3..16755825 100644 --- a/modelskill/comparison/_comparison.py +++ b/modelskill/comparison/_comparison.py @@ -928,6 +928,38 @@ def sel( d = d.isel(time=mask) return Comparer.from_matched_data(data=d, raw_mod_data=raw_mod_data) + def drop( + self, + model: Optional[IdxOrNameTypes] = None, + ) -> "Comparer": + """Drop specified model(s) from the Comparer. + + Parameters + ---------- + model : str or int or list of str or list of int, optional + Model name or index. If None, all models are selected. + + Returns + ------- + Comparer + New Comparer excluding specified data. + """ + dropped_cmp = self + + if model is not None: + if isinstance(model, (str, int)): + models = [model] + else: + models = list(model) + models_to_drop: List[str] = [_get_name(m, self.mod_names) for m in models] + models_to_keep: List[str] = [ + m for m in self.mod_names if m not in models_to_drop + ] + + dropped_cmp = dropped_cmp.sel(model=models_to_keep) + + return dropped_cmp + def where( self, cond: Union[bool, np.ndarray, xr.DataArray], diff --git a/tests/test_comparer.py b/tests/test_comparer.py index 2d8691d0..325aa9a4 100644 --- a/tests/test_comparer.py +++ b/tests/test_comparer.py @@ -524,6 +524,49 @@ def test_tc_sel_time_and_area(tc): assert tc2.data.Observation.values.tolist() == [2.0] +def test_pc_drop_model(pc): + pc2 = pc.drop(model="m2") + assert isinstance(pc2, type(pc)) + assert pc2.n_models == pc.n_models - 1 + assert "m2" not in pc2.mod_names + assert "m2" not in pc2.raw_mod_data + assert "m2" not in pc2.data + assert np.all(pc.data.m1 == pc2.data.m1) + assert np.all(pc.raw_mod_data["m1"] == pc2.raw_mod_data["m1"]) + + +def test_pc_drop_model_first(pc): + pc2 = pc.drop(model=0) + assert isinstance(pc2, type(pc)) + assert pc2.n_models == pc.n_models - 1 + assert "m1" not in pc2.mod_names + assert "m1" not in pc2.raw_mod_data + assert "m1" not in pc2.data + assert np.all(pc.data.m2 == pc2.data.m2) + assert np.all(pc.raw_mod_data["m2"] == pc2.raw_mod_data["m2"]) + + +def test_pc_drop_model_last(pc): + pc2 = pc.drop(model=-1) + assert isinstance(pc2, type(pc)) + assert pc2.n_models == pc.n_models - 1 + assert "m2" not in pc2.mod_names + assert "m2" not in pc2.raw_mod_data + assert "m2" not in pc2.data + assert np.all(pc.data.m1 == pc2.data.m1) + assert np.all(pc.raw_mod_data["m1"] == pc2.raw_mod_data["m1"]) + + +def test_tc_drop_model(tc): + tc2 = tc.drop(model="m2") + assert isinstance(tc2, type(tc)) + assert tc2.n_models == tc.n_models - 1 + assert "m2" not in tc2.mod_names + assert "m2" not in tc2.raw_mod_data + assert "m2" not in tc2.data + assert np.all(tc.data.m1 == tc2.data.m1) + + def test_pc_where(pc): pc2 = pc.where(pc.data.Observation > 2.5) assert pc2.n_points == 3