diff --git a/CIME/XML/env_batch.py b/CIME/XML/env_batch.py
index 580a8e9d434..aea6342b457 100644
--- a/CIME/XML/env_batch.py
+++ b/CIME/XML/env_batch.py
@@ -223,12 +223,35 @@ def get_job_overrides(self, job, case):
overrides["tasks_per_node"] = tasks_per_node
if thread_count:
overrides["thread_count"] = thread_count
+ total_tasks = total_tasks * thread_count
+ else:
+ total_tasks = total_tasks * case.thread_count
else:
- total_tasks = case.get_value("TOTALPES") * int(case.thread_count)
+ # Total PES accounts for threads as well as mpi tasks
+ total_tasks = case.get_value("TOTALPES")
thread_count = case.thread_count
- if int(total_tasks) * int(thread_count) < case.get_value("MAX_TASKS_PER_NODE"):
+ if int(total_tasks) < case.get_value("MAX_TASKS_PER_NODE"):
overrides["max_tasks_per_node"] = int(total_tasks)
+ # when developed this variable was only needed on derecho, but I have tried to
+ # make it general enough that it can be used on other systems by defining MEM_PER_TASK and MAX_MEM_PER_NODE in config_machines.xml
+ # and adding {{ mem_per_node }} in config_batch.xml
+ try:
+ mem_per_task = case.get_value("MEM_PER_TASK")
+ max_mem_per_node = case.get_value("MAX_MEM_PER_NODE")
+ mem_per_node = total_tasks
+
+ if mem_per_node < mem_per_task:
+ mem_per_node = mem_per_task
+ elif mem_per_node > max_mem_per_node:
+ mem_per_node = max_mem_per_node
+ overrides["mem_per_node"] = mem_per_node
+ except TypeError:
+ # ignore this, the variables are not defined for this machine
+ pass
+ except Exception as error:
+ print("An exception occured:", error)
+
overrides["ngpus_per_node"] = ngpus_per_node
overrides["mpirun"] = case.get_mpirun_cmd(job=job, overrides=overrides)
return overrides
diff --git a/CIME/data/config/xml_schemas/config_machines.xsd b/CIME/data/config/xml_schemas/config_machines.xsd
index 6be3efd952e..d6cbe51f1ec 100644
--- a/CIME/data/config/xml_schemas/config_machines.xsd
+++ b/CIME/data/config/xml_schemas/config_machines.xsd
@@ -55,6 +55,8 @@
+
+
@@ -164,6 +166,13 @@
+
+
+
+
+
+
+
+
+
+
+
+
+