diff --git a/README.md b/README.md index 2064bda1..b2868b9b 100644 --- a/README.md +++ b/README.md @@ -42,6 +42,7 @@ The [currently available algorithms](domainbed/algorithms.py) are: * Empirical Quantile Risk Minimization (EQRM, [Eastwood et al., 2022](https://arxiv.org/abs/2207.09944)), contributed by [@cianeastwood](https://github.com/cianeastwood) * Domain Generalisation via Risk Distribution Matching (RDM, [Nguyen et al., 2024](https://arxiv.org/abs/2310.18598)), contributed by [@nktoan](https://github.com/nktoan), [authors' contact email](mailto:ktoan271199@gmail.com) * ADRMX: Additive Disentanglement of Domain Features with Remix Loss (ADRMX, [Demirel et al., 2023](https://arxiv.org/abs/2308.06624)), contributed by [@berkerdemirel](https://github.com/berkerdemirel) +* ERM++: An Improved Baseline for Domain Generalization( ERM++, [Teterwak et. al. 2023](https://arxiv.org/abs/2304.01973), contributed by [@piotr-teterwak](https://cs-people.bu.edu/piotrt/). Send us a PR to add your algorithm! Our implementations use ResNet50 / ResNet18 networks ([He et al., 2015](https://arxiv.org/abs/1512.03385)) and the hyper-parameter grids [described here](domainbed/hparams_registry.py). diff --git a/domainbed/scripts/collect_results.py b/domainbed/scripts/collect_results.py index 073d3dd5..1fadec80 100644 --- a/domainbed/scripts/collect_results.py +++ b/domainbed/scripts/collect_results.py @@ -108,8 +108,9 @@ def print_results_tables(records, selection_method, latex): """Given all records, print a results table for each dataset.""" grouped_records = reporting.get_grouped_records(records) - for r in grouped_records: - r['records'] = merge_records(r['records']) + if selection_method == model_selection.IIDAutoLRAccuracySelectionMethod: + for r in grouped_records: + r['records'] = merge_records(r['records']) grouped_records = grouped_records.map(lambda group: { **group, "sweep_acc": selection_method.sweep_acc(group["records"]) } @@ -192,6 +193,7 @@ def print_results_tables(records, selection_method, latex): description="Domain generalization testbed") parser.add_argument("--input_dir", type=str, required=True) parser.add_argument("--latex", action="store_true") + parser.add_argument("--auto_lr", action="store_true") args = parser.parse_args() results_file = "results.tex" if args.latex else "results.txt" @@ -210,12 +212,14 @@ def print_results_tables(records, selection_method, latex): else: print("Total records:", len(records)) - SELECTION_METHODS = [ - model_selection.IIDAccuracySelectionMethod, - model_selection.IIDAutoLRAccuracySelectionMethod, - model_selection.LeaveOneOutSelectionMethod, - model_selection.OracleSelectionMethod, - ] + if args.auto_lr: + SELECTION_METHODS = [model_selection.IIDAutoLRAccuracySelectionMethod] + else: + SELECTION_METHODS = [ + model_selection.IIDAccuracySelectionMethod, + model_selection.LeaveOneOutSelectionMethod, + model_selection.OracleSelectionMethod, + ] for selection_method in SELECTION_METHODS: if args.latex: