Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for saving and loading simulation state to / from files #1227

Merged
merged 42 commits into from
Sep 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
0269ea1
Factor out parts of simulate method
matt-graham Nov 14, 2023
a0a848f
Further refactoring of Simulation
matt-graham Dec 8, 2023
7ea292f
Add methods for saving and loading simulations
matt-graham Dec 8, 2023
3fd5cd3
Add initial test for simulation saving and loading
matt-graham Dec 8, 2023
7e5f666
Factor out and add additional simulation test checks
matt-graham Dec 11, 2023
c2a15c3
Explicitly set logger output file when loading from pickle
matt-graham Dec 11, 2023
921ab2e
Check next date in event queue before popping
matt-graham Dec 11, 2023
7596fc6
Make pytest seed parameter session scoped
matt-graham Dec 11, 2023
369ea88
Don't use next on counter in test check
matt-graham Dec 11, 2023
6c1afd8
Refactor global constants to fixtures in simulation tests + additiona…
matt-graham Dec 11, 2023
e8bd4d8
Move logging configuration out of load_from_pickle
matt-graham Dec 11, 2023
775cac1
Add test for exception when simulation past end date
matt-graham Dec 11, 2023
d3ec718
Add docstrings for new methods
matt-graham Dec 11, 2023
cc71c01
Add errors when running without initialising or initialising multiple…
matt-graham Dec 11, 2023
a5d7289
Add dill to dependencies
matt-graham Dec 11, 2023
2bb4066
Sort imports
matt-graham Dec 11, 2023
fc60e46
Merge branch 'master' into mmg/refactor-simulate
matt-graham Dec 11, 2023
97af3b0
Fix fenceposting error in simulation end date
matt-graham Dec 12, 2023
c81a0f3
Merge branch 'master' into mmg/refactor-simulate
matt-graham Dec 12, 2023
1d84be6
Merge branch 'master' into mmg/refactor-simulate
matt-graham Mar 20, 2024
cd1a310
Merge branch 'master' into mmg/refactor-simulate
matt-graham Apr 8, 2024
e520cca
Merge branch 'master' into mmg/refactor-simulate
matt-graham Jun 17, 2024
05c61dd
Merge branch 'master' into mmg/refactor-simulate
matt-graham Jul 24, 2024
d53da61
Fix explicit comparison to type
matt-graham Jul 24, 2024
40d3eaa
Add option to configure logging when loading from pickle
matt-graham Jul 25, 2024
5205a27
Move check for open log file in close_output_file method
matt-graham Jul 25, 2024
e795670
Tidy up docstrings and type hints
matt-graham Jul 25, 2024
1b1c179
Remove use of configure_logging in test
matt-graham Jul 25, 2024
1604bc2
Update scenario to allow suspending and resuming
matt-graham Jul 31, 2024
dc20983
Add utility function to merge log files
matt-graham Jul 31, 2024
cbadce6
Add test to check equality of parsed log files in suspend-resume
matt-graham Jul 31, 2024
9c139e9
Fix import sort order
matt-graham Jul 31, 2024
1d91b95
Merge branch 'master' into mmg/refactor-simulate
matt-graham Sep 9, 2024
19a2603
Merge branch 'master' into mmg/refactor-simulate
matt-graham Sep 24, 2024
ec10b40
Update pinned dill version to 0.3.8
matt-graham Sep 26, 2024
8d9000a
Adding log message when loading suspended simulation
matt-graham Sep 26, 2024
39f5ce4
Adding log message when saving suspended simulation
matt-graham Sep 26, 2024
60f011c
Increase simulation pop size and duration in test
matt-graham Sep 26, 2024
6f5a76d
Avoid reading in log files to be merged all at once
matt-graham Sep 26, 2024
4eb2ebe
Add tests for merge_log_files function
matt-graham Sep 26, 2024
87e5fa9
Fix import order sorting
matt-graham Sep 26, 2024
f0bb572
Fix import order sorting (second attempt)
matt-graham Sep 26, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,9 @@
'exclude-members': '__dict__, name, rng, sim' # , read_parameters',
}

# Include both class level and __init__ docstring content in class documentation
autoclass_content = 'both'

# The checker can't see private repos
linkcheck_ignore = ['^https://github.com/UCL/TLOmodel.*',
'https://www.who.int/bulletin/volumes/88/8/09-068213/en/nn']
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ dependencies = [
"azure-identity",
"azure-keyvault",
"azure-storage-file-share",
# For saving and loading simulation state
"dill",
]
description = "Thanzi la Onse Epidemiology Model"
dynamic = ["version"]
Expand Down
3 changes: 3 additions & 0 deletions requirements/base.txt
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ cryptography==41.0.3
# pyjwt
cycler==0.11.0
# via matplotlib
dill==0.3.8
# via tlo (pyproject.toml)
et-xmlfile==1.1.0
# via openpyxl
fonttools==4.42.1
Expand Down Expand Up @@ -112,6 +114,7 @@ pyjwt[crypto]==2.8.0
# via
# adal
# msal
# pyjwt
pyparsing==3.1.1
# via matplotlib
pyshp==2.3.1
Expand Down
35 changes: 9 additions & 26 deletions requirements/dev.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# This file is autogenerated by pip-compile with Python 3.8
# This file is autogenerated by pip-compile with Python 3.11
# by the following command:
#
# pip-compile --extra=dev --output-file=requirements/dev.txt
Expand Down Expand Up @@ -61,7 +61,9 @@ colorama==0.4.6
contourpy==1.1.1
# via matplotlib
coverage[toml]==7.3.1
# via pytest-cov
# via
# coverage
# pytest-cov
cryptography==41.0.3
# via
# adal
Expand All @@ -72,14 +74,14 @@ cryptography==41.0.3
# pyjwt
cycler==0.11.0
# via matplotlib
dill==0.3.7
# via pylint
dill==0.3.8
# via
# pylint
# tlo (pyproject.toml)
distlib==0.3.7
# via virtualenv
et-xmlfile==1.1.0
# via openpyxl
exceptiongroup==1.1.3
# via pytest
execnet==2.0.2
# via pytest-xdist
filelock==3.12.4
Expand All @@ -94,10 +96,6 @@ gitpython==3.1.36
# via tlo (pyproject.toml)
idna==3.4
# via requests
importlib-metadata==6.8.0
# via build
importlib-resources==6.1.1
# via matplotlib
iniconfig==2.0.0
# via pytest
isodate==0.6.1
Expand Down Expand Up @@ -172,6 +170,7 @@ pyjwt[crypto]==2.8.0
# via
# adal
# msal
# pyjwt
pylint==3.0.1
# via tlo (pyproject.toml)
pyparsing==3.1.1
Expand Down Expand Up @@ -221,29 +220,17 @@ smmap==5.0.1
# via gitdb
squarify==0.4.3
# via tlo (pyproject.toml)
tomli==2.0.1
# via
# build
# coverage
# pip-tools
# pylint
# pyproject-api
# pyproject-hooks
# pytest
# tox
tomlkit==0.12.1
# via pylint
tox==4.11.3
# via tlo (pyproject.toml)
typing-extensions==4.8.0
# via
# astroid
# azure-core
# azure-keyvault-certificates
# azure-keyvault-keys
# azure-keyvault-secrets
# azure-storage-file-share
# pylint
tzdata==2023.3
# via pandas
urllib3==2.0.4
Expand All @@ -254,10 +241,6 @@ virtualenv==20.24.5
# tox
wheel==0.41.2
# via pip-tools
zipp==3.17.0
# via
# importlib-metadata
# importlib-resources

# The following packages are considered to be unsafe in a requirements file:
# pip
Expand Down
37 changes: 36 additions & 1 deletion src/tlo/analysis/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
General utility functions for TLO analysis
"""
import fileinput
import gzip
import json
import os
Expand Down Expand Up @@ -86,6 +87,40 @@ def parse_log_file(log_filepath, level: int = logging.INFO):
return LogsDict({name: handle.name for name, handle in module_name_to_filehandle.items()}, level)


def merge_log_files(log_path_1: Path, log_path_2: Path, output_path: Path) -> None:
    """Merge two log files, skipping any repeated header lines.

    :param log_path_1: Path to first log file to merge. Records from this log file will
        appear first in merged log file.
    :param log_path_2: Path to second log file to merge. Records from this log file will
        appear after those in log file at `log_path_1` and any header lines in this file
        which are also present in log file at `log_path_1` will be skipped.
    :param output_path: Path to write merged log file to. Must not refer to the same
        file as `log_path_1` or `log_path_2` as data is read from those files while
        writing to this path.
    :raises ValueError: If `output_path` refers to the same file as `log_path_1` or
        `log_path_2`.
    :raises RuntimeError: If two header lines with the same UUID but differing content
        are encountered while merging.
    """
    # Compare resolved paths so that different spellings of the same file (for example
    # a relative versus an absolute path) are also rejected - writing to a file while
    # it is being read from would corrupt it.
    if output_path.resolve() in (log_path_1.resolve(), log_path_2.resolve()):
        msg = "output_path must not be equal to log_path_1 or log_path_2"
        raise ValueError(msg)
    with fileinput.input(files=(log_path_1, log_path_2), mode="r") as log_lines:
        with output_path.open("w") as output_file:
            # Maps header UUID -> exact header line already written; used both to skip
            # duplicate headers and to detect conflicting headers sharing a UUID.
            written_header_lines = {}
            for log_line in log_lines:
                # Every log line is expected to be a single JSON record.
                log_data = json.loads(log_line)
                if log_data.get("type") == "header":
                    previous_header_line = written_header_lines.get(log_data["uuid"])
                    if previous_header_line is not None:
                        if previous_header_line == log_line:
                            # Identical header already written - skip the repeat.
                            continue
                        msg = (
                            "Inconsistent header lines with matching UUIDs found when merging logs:\n"
                            f"{previous_header_line}\n{log_line}\n"
                        )
                        raise RuntimeError(msg)
                    written_header_lines[log_data["uuid"]] = log_line
                output_file.write(log_line)


def write_log_to_excel(filename, log_dataframes):
"""Takes the output of parse_log_file() and creates an Excel file from dataframes"""
metadata = list()
Expand Down Expand Up @@ -1131,7 +1166,7 @@ def get_parameters_for_status_quo() -> Dict:
"equip_availability": "all", # <--- NB. Existing calibration is assuming all equipment is available
},
}

def get_parameters_for_standard_mode2_runs() -> Dict:
"""
Returns a dictionary of parameters and their updated values to indicate
Expand Down
73 changes: 61 additions & 12 deletions src/tlo/scenario.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def draw_parameters(self, draw_number, rng):

from tlo import Date, Simulation, logging
from tlo.analysis.utils import parse_log_file
from tlo.util import str_to_pandas_date

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
Expand Down Expand Up @@ -141,6 +142,16 @@ def parse_arguments(self, extra_arguments: List[str]) -> None:
self.arguments = extra_arguments

parser = argparse.ArgumentParser()
parser.add_argument(
"--resume-simulation",
type=str,
help="Directory containing suspended state files to resume simulation from",
)
parser.add_argument(
"--suspend-date",
type=str_to_pandas_date,
help="Date to suspend the simulation at",
)

# add arguments from the subclass
self.add_arguments(parser)
Expand Down Expand Up @@ -382,20 +393,58 @@ def run_sample_by_number(self, output_directory, draw_number, sample_number):
sample = self.get_sample(draw, sample_number)
log_config = self.scenario.get_log_config(output_directory)

logger.info(key="message", data=f"Running draw {sample['draw_number']}, sample {sample['sample_number']}")

sim = Simulation(
start_date=self.scenario.start_date,
seed=sample["simulation_seed"],
log_config=log_config
logger.info(
key="message",
data=f"Running draw {sample['draw_number']}, sample {sample['sample_number']}",
)
sim.register(*self.scenario.modules())

if sample["parameters"] is not None:
self.override_parameters(sim, sample["parameters"])

sim.make_initial_population(n=self.scenario.pop_size)
sim.simulate(end_date=self.scenario.end_date)
# if user has specified a restore simulation, we load it from a pickle file
if (
hasattr(self.scenario, "resume_simulation")
and self.scenario.resume_simulation is not None
):
suspended_simulation_path = (
Path(self.scenario.resume_simulation)
/ str(draw_number)
/ str(sample_number)
/ "suspended_simulation.pickle"
)
logger.info(
key="message",
data=f"Loading pickled suspended simulation from {suspended_simulation_path}",
)
sim = Simulation.load_from_pickle(pickle_path=suspended_simulation_path, log_config=log_config)
else:
sim = Simulation(
start_date=self.scenario.start_date,
seed=sample["simulation_seed"],
log_config=log_config,
)
sim.register(*self.scenario.modules())

if sample["parameters"] is not None:
self.override_parameters(sim, sample["parameters"])

sim.make_initial_population(n=self.scenario.pop_size)
sim.initialise(end_date=self.scenario.end_date)

# if user has specified a suspend date, we run the simulation to that date and
# save it to a pickle file
if (
hasattr(self.scenario, "suspend_date")
and self.scenario.suspend_date is not None
):
sim.run_simulation_to(to_date=self.scenario.suspend_date)
matt-graham marked this conversation as resolved.
Show resolved Hide resolved
suspended_simulation_path = Path(log_config["directory"]) / "suspended_simulation.pickle"
sim.save_to_pickle(pickle_path=suspended_simulation_path)
sim.close_output_file()
logger.info(
key="message",
data=f"Simulation suspended at {self.scenario.suspend_date} and saved to {suspended_simulation_path}",
)
else:
sim.run_simulation_to(to_date=self.scenario.end_date)
sim.finalise()

if sim.log_filepath is not None:
outputs = parse_log_file(sim.log_filepath)
Expand Down
Loading