diff --git a/modelskill/comparison/_collection.py b/modelskill/comparison/_collection.py index 23e9bf860..6393251c7 100644 --- a/modelskill/comparison/_collection.py +++ b/modelskill/comparison/_collection.py @@ -817,6 +817,9 @@ def weighted_mean(x): agg[metric.__name__] = weighted_mean # type: ignore res = skilldf.groupby(by).agg(agg) + # TODO is this correct? + res.index.name = "model" + # output res = cmp._add_as_col_if_not_in_index(df, res, fields=["model", "variable"]) return SkillTable(res.astype({"n": int})) @@ -1031,14 +1034,7 @@ def score( metric_name = metric if isinstance(metric, str) else metric.__name__ - # TODO dict comprehension? - score = {} - for model in models: - mtr_val = df.loc[model][metric_name] - if not np.isscalar(mtr_val): - # e.g. mean over different variables! - mtr_val = mtr_val.values.mean() - score[model] = mtr_val + score = df[metric_name].to_dict() return score