Skip to content

Commit

Permalink
Fix tests part two - somehow not broken in CI, but formerly broken on…
Browse files Browse the repository at this point in the history
… my local machine
  • Loading branch information
esoteric-ephemera committed Jan 9, 2024
1 parent 2b8eb89 commit 5fcdc8b
Showing 1 changed file with 33 additions and 19 deletions.
52 changes: 33 additions & 19 deletions tests/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -727,8 +727,16 @@ def test_nscf_kpoints_checks(test_dir, object_name):

# Explicit kpoints for NSCF calc check (this should not raise any flags for NSCF calcs)
temp_task_doc = copy.deepcopy(task_doc)
temp_task_doc.calcs_reversed[0].input.kpoints = Kpoints(
kpts=[[0, 0, 0], [0, 0, 0.5]], num_kpts=2, kpts_weights=[0.5, 0.5], labels=["Gamma", "X"], style="line_mode"
_update_kpoints_for_test(
temp_task_doc,
{
"kpoints": [[0, 0, 0], [0, 0, 0.5]],
"nkpoints": 2,
"kpts_weights": [0.5, 0.5],
"labels": ["Gamma", "X"],
"style": "line_mode",
"generation_style": "line_mode",
},
)
temp_validation_doc = ValidationDoc.from_task_doc(temp_task_doc)
assert not any(["INPUT SETTINGS --> KPOINTS: explicitly" in reason for reason in temp_validation_doc.reasons])
Expand Down Expand Up @@ -830,6 +838,15 @@ def test_common_error_checks(test_dir, object_name):
assert any(["COMPOSITION" in reason for reason in temp_validation_doc.reasons])


def _update_kpoints_for_test(task_doc: TaskDoc, kpoints_updates: dict):
if isinstance(task_doc.calcs_reversed[0].input.kpoints, Kpoints):
kpoints = task_doc.calcs_reversed[0].input.kpoints.as_dict()
elif isinstance(task_doc.calcs_reversed[0].input.kpoints, dict):
kpoints = task_doc.calcs_reversed[0].input.kpoints.copy()
kpoints.update(kpoints_updates)
task_doc.calcs_reversed[0].input.kpoints = Kpoints.from_dict(kpoints)


@pytest.mark.parametrize(
"object_name",
[
Expand All @@ -849,9 +866,7 @@ def test_kpoints_checks(test_dir, object_name):
coords=[[0, 0, 0], [0.333333333333333, -0.333333333333333, 0.5]],
species=["H", "H"],
) # HCP structure
kpoints = temp_task_doc.calcs_reversed[0].input.kpoints.as_dict()
kpoints["generation_style"] = "monkhorst"
temp_task_doc.calcs_reversed[0].input.kpoints = Kpoints.from_dict(kpoints)
_update_kpoints_for_test(temp_task_doc, {"generation_style": "monkhorst"})
temp_validation_doc = ValidationDoc.from_task_doc(temp_task_doc)
assert any(
["INPUT SETTINGS --> KPOINTS or KGAMMA: monkhorst-pack" in reason for reason in temp_validation_doc.reasons]
Expand All @@ -862,9 +877,7 @@ def test_kpoints_checks(test_dir, object_name):
temp_task_doc.calcs_reversed[0].input.structure = Structure(
lattice=[[0.0, 0.5, 0.5], [0.5, 0.0, 0.5], [0.5, 0.5, 0.0]], coords=[[0, 0, 0]], species=["H"]
) # FCC structure
kpoints = temp_task_doc.calcs_reversed[0].input.kpoints.as_dict()
kpoints["generation_style"] = "monkhorst"
temp_task_doc.calcs_reversed[0].input.kpoints = kpoints
_update_kpoints_for_test(temp_task_doc, {"generation_style": "monkhorst"})
temp_validation_doc = ValidationDoc.from_task_doc(temp_task_doc)
assert any(
["INPUT SETTINGS --> KPOINTS or KGAMMA: monkhorst-pack" in reason for reason in temp_validation_doc.reasons]
Expand All @@ -875,35 +888,36 @@ def test_kpoints_checks(test_dir, object_name):
temp_task_doc.calcs_reversed[0].input.structure = Structure(
lattice=[[2.9, 0, 0], [0, 2.9, 0], [0, 0, 2.9]], species=["H", "H"], coords=[[0, 0, 0], [0.5, 0.5, 0.5]]
) # BCC structure
kpoints = temp_task_doc.calcs_reversed[0].input.kpoints.as_dict()
kpoints["generation_style"] = "monkhorst"
temp_task_doc.calcs_reversed[0].input.kpoints = Kpoints.from_dict(kpoints)
_update_kpoints_for_test(temp_task_doc, {"generation_style": "monkhorst"})
temp_validation_doc = ValidationDoc.from_task_doc(temp_task_doc)
assert not any(
["INPUT SETTINGS --> KPOINTS or KGAMMA: monkhorst-pack" in reason for reason in temp_validation_doc.reasons]
)

# Too few kpoints check
temp_task_doc = copy.deepcopy(task_doc)
kpoints = temp_task_doc.calcs_reversed[0].input.kpoints.as_dict()
kpoints["kpoints"] = [[3, 3, 3]]
temp_task_doc.calcs_reversed[0].input.kpoints = Kpoints.from_dict(kpoints)
_update_kpoints_for_test(temp_task_doc, {"kpoints": [[3, 3, 3]]})
temp_validation_doc = ValidationDoc.from_task_doc(temp_task_doc)
assert any(["INPUT SETTINGS --> KPOINTS or KSPACING:" in reason for reason in temp_validation_doc.reasons])

# Explicit kpoints for SCF calc check
temp_task_doc = copy.deepcopy(task_doc)
temp_task_doc.calcs_reversed[0].input.kpoints = Kpoints(
kpts=[[0, 0, 0], [0, 0, 0.5]], num_kpts=2, kpts_weights=[0.5, 0.5], style="reciprocal"
_update_kpoints_for_test(
temp_task_doc,
{
"kpoints": [[0, 0, 0], [0, 0, 0.5]],
"nkpoints": 2,
"kpts_weights": [0.5, 0.5],
"style": "reciprocal",
"generation_style": "Reciprocal",
},
)
temp_validation_doc = ValidationDoc.from_task_doc(temp_task_doc)
assert any(["INPUT SETTINGS --> KPOINTS: explicitly" in reason for reason in temp_validation_doc.reasons])

# Shifting kpoints for SCF calc check
temp_task_doc = copy.deepcopy(task_doc)
kpoints = temp_task_doc.calcs_reversed[0].input.kpoints.as_dict()
kpoints["usershift"] = [0.5, 0, 0]
temp_task_doc.calcs_reversed[0].input.kpoints = Kpoints.from_dict(kpoints)
_update_kpoints_for_test(temp_task_doc, {"usershift": [0.5, 0, 0]})
temp_validation_doc = ValidationDoc.from_task_doc(temp_task_doc)
assert any(["INPUT SETTINGS --> KPOINTS: shifting" in reason for reason in temp_validation_doc.reasons])

Expand Down

0 comments on commit 5fcdc8b

Please sign in to comment.