Skip to content

Commit

Permalink
Update collect results for unit tests and README.md
Browse files Browse the repository at this point in the history
  • Loading branch information
Piotr Teterwark committed Jul 26, 2024
1 parent 31d4652 commit dad3ca3
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 8 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ The [currently available algorithms](domainbed/algorithms.py) are:
* Empirical Quantile Risk Minimization (EQRM, [Eastwood et al., 2022](https://arxiv.org/abs/2207.09944)), contributed by [@cianeastwood](https://github.com/cianeastwood)
* Domain Generalisation via Risk Distribution Matching (RDM, [Nguyen et al., 2024](https://arxiv.org/abs/2310.18598)), contributed by [@nktoan](https://github.com/nktoan), [authors' contact email](mailto:[email protected])
* ADRMX: Additive Disentanglement of Domain Features with Remix Loss (ADRMX, [Demirel et al., 2023](https://arxiv.org/abs/2308.06624)), contributed by [@berkerdemirel](https://github.com/berkerdemirel)
* ERM++: An Improved Baseline for Domain Generalization( ERM++, [Teterwak et. al. 2023](https://arxiv.org/abs/2304.01973), contributed by [@piotr-teterwak](https://cs-people.bu.edu/piotrt/).

Send us a PR to add your algorithm! Our implementations use ResNet50 / ResNet18 networks ([He et al., 2015](https://arxiv.org/abs/1512.03385)) and the hyper-parameter grids [described here](domainbed/hparams_registry.py).

Expand Down
20 changes: 12 additions & 8 deletions domainbed/scripts/collect_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,9 @@ def print_results_tables(records, selection_method, latex):
"""Given all records, print a results table for each dataset."""
grouped_records = reporting.get_grouped_records(records)

for r in grouped_records:
r['records'] = merge_records(r['records'])
if selection_method == model_selection.IIDAutoLRAccuracySelectionMethod:
for r in grouped_records:
r['records'] = merge_records(r['records'])

grouped_records = grouped_records.map(lambda group:
{ **group, "sweep_acc": selection_method.sweep_acc(group["records"]) }
Expand Down Expand Up @@ -192,6 +193,7 @@ def print_results_tables(records, selection_method, latex):
description="Domain generalization testbed")
parser.add_argument("--input_dir", type=str, required=True)
parser.add_argument("--latex", action="store_true")
parser.add_argument("--auto_lr", action="store_true")
args = parser.parse_args()

results_file = "results.tex" if args.latex else "results.txt"
Expand All @@ -210,12 +212,14 @@ def print_results_tables(records, selection_method, latex):
else:
print("Total records:", len(records))

SELECTION_METHODS = [
model_selection.IIDAccuracySelectionMethod,
model_selection.IIDAutoLRAccuracySelectionMethod,
model_selection.LeaveOneOutSelectionMethod,
model_selection.OracleSelectionMethod,
]
if args.auto_lr:
SELECTION_METHODS = [model_selection.IIDAutoLRAccuracySelectionMethod]
else:
SELECTION_METHODS = [
model_selection.IIDAccuracySelectionMethod,
model_selection.LeaveOneOutSelectionMethod,
model_selection.OracleSelectionMethod,
]

for selection_method in SELECTION_METHODS:
if args.latex:
Expand Down

0 comments on commit dad3ca3

Please sign in to comment.