diff --git a/CIME/XML/env_batch.py b/CIME/XML/env_batch.py index 580a8e9d434..aea6342b457 100644 --- a/CIME/XML/env_batch.py +++ b/CIME/XML/env_batch.py @@ -223,12 +223,35 @@ def get_job_overrides(self, job, case): overrides["tasks_per_node"] = tasks_per_node if thread_count: overrides["thread_count"] = thread_count + total_tasks = total_tasks * thread_count + else: + total_tasks = total_tasks * case.thread_count else: - total_tasks = case.get_value("TOTALPES") * int(case.thread_count) + # Total PES accounts for threads as well as mpi tasks + total_tasks = case.get_value("TOTALPES") thread_count = case.thread_count - if int(total_tasks) * int(thread_count) < case.get_value("MAX_TASKS_PER_NODE"): + if int(total_tasks) < case.get_value("MAX_TASKS_PER_NODE"): overrides["max_tasks_per_node"] = int(total_tasks) + # when developed this variable was only needed on derecho, but I have tried to + # make it general enough that it can be used on other systems by defining MEM_PER_TASK and MAX_MEM_PER_NODE in config_machines.xml + # and adding {{ mem_per_node }} in config_batch.xml + try: + mem_per_task = case.get_value("MEM_PER_TASK") + max_mem_per_node = case.get_value("MAX_MEM_PER_NODE") + mem_per_node = total_tasks + + if mem_per_node < mem_per_task: + mem_per_node = mem_per_task + elif mem_per_node > max_mem_per_node: + mem_per_node = max_mem_per_node + overrides["mem_per_node"] = mem_per_node + except TypeError: + # ignore this, the variables are not defined for this machine + pass + except Exception as error: + print("An exception occured:", error) + overrides["ngpus_per_node"] = ngpus_per_node overrides["mpirun"] = case.get_mpirun_cmd(job=job, overrides=overrides) return overrides diff --git a/CIME/data/config/xml_schemas/config_machines.xsd b/CIME/data/config/xml_schemas/config_machines.xsd index 6be3efd952e..d6cbe51f1ec 100644 --- a/CIME/data/config/xml_schemas/config_machines.xsd +++ b/CIME/data/config/xml_schemas/config_machines.xsd @@ -55,6 +55,8 @@ + + @@ -164,6 +166,13 @@ + + + + + + + + + + + + +