Merge pull request #782 from opensafely-core/log-stats
Log some study definition stats and timings
rebkwok authored Apr 22, 2022
2 parents 9e1e3dc + e7c8f14 commit 8d85aa9
Showing 7 changed files with 654 additions and 70 deletions.
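
The changes below add two helpers to cohortextractor/log_utils.py — log_stats, which emits a "cohortextractor-stats" event with arbitrary key/value pairs, and log_execution_time, a context manager that logs how long the wrapped block took — and thread them through cohort generation, measure calculation and the EMIS backend. A minimal sketch of how the helpers are called (the logger setup mirrors cohortextractor.py; the description and stats values here are illustrative):

import structlog

from cohortextractor.log_utils import log_execution_time, log_stats

logger = structlog.get_logger()

# Log arbitrary stats as a "cohortextractor-stats" event
log_stats(logger, measures_count=3)

# Time a block of work; on exit the elapsed (process) time is logged as
# execution_time_secs and execution_time alongside the supplied keywords
with log_execution_time(logger, description="generate_cohort (all index dates)"):
    pass  # e.g. _generate_cohort(...)
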
160 changes: 106 additions & 54 deletions cohortextractor/cohortextractor.py
@@ -33,6 +33,8 @@
from cohortextractor.exceptions import DummyDataValidationError
from cohortextractor.generate_codelist_report import generate_codelist_report

from .log_utils import log_execution_time, log_stats

logger = structlog.get_logger()

notebook_tag = "opencorona-research"
@@ -149,16 +151,19 @@ def generate_cohort(
msg = "You can only provide dummy data for a single study definition"
raise DummyDataValidationError(msg)
for study_name, suffix in study_definitions:
_generate_cohort(
output_dir,
study_name,
suffix,
expectations_population,
dummy_data_file,
index_date_range=index_date_range,
skip_existing=skip_existing,
output_format=output_format,
)
with log_execution_time(
logger, description=f"generate_cohort for {study_name} (all index dates)"
):
_generate_cohort(
output_dir,
study_name,
suffix,
expectations_population,
dummy_data_file,
index_date_range=index_date_range,
skip_existing=skip_existing,
output_format=output_format,
)


def _generate_cohort(
@@ -187,25 +192,37 @@ def _generate_cohort(
study = load_study_definition(study_name)

os.makedirs(output_dir, exist_ok=True)
for index_date in _generate_date_range(index_date_range):

index_dates = _generate_date_range(index_date_range)
log_stats(logger, index_date_count=len(index_dates) if index_date_range else 0)
if index_date_range:
log_stats(logger, min_index_date=index_dates[-1], max_index_date=index_dates[0])

for index_date in index_dates:
log_event = f"generate_cohort for {study_name}"
if index_date is not None:
logger.info(f"Setting index_date to {index_date}")
study.set_index_date(index_date)
date_suffix = f"_{index_date}"
else:
date_suffix = ""
# If this is changed then the regex in `_generate_measures()`
# must be updated
output_file = f"{output_dir}/input{suffix}{date_suffix}.{output_format}"
if skip_existing and os.path.exists(output_file):
logger.info(f"Not regenerating pre-existing file at {output_file}")
else:
study.to_file(
output_file,
expectations_population=expectations_population,
dummy_data_file=dummy_data_file,
)
logger.info(f"Successfully created cohort and covariates at {output_file}")
log_event += f" at {index_date}"
with log_execution_time(logger, description=log_event):
if index_date is not None:
logger.info(f"Setting index_date to {index_date}")
study.set_index_date(index_date)
date_suffix = f"_{index_date}"
else:
date_suffix = ""
# If this is changed then the regex in `_generate_measures()`
# must be updated
output_file = f"{output_dir}/input{suffix}{date_suffix}.{output_format}"
if skip_existing and os.path.exists(output_file):
logger.info(f"Not regenerating pre-existing file at {output_file}")
else:
study.to_file(
output_file,
expectations_population=expectations_population,
dummy_data_file=dummy_data_file,
)
logger.info(
f"Successfully created cohort and covariates at {output_file}"
)


def _generate_date_range(date_range_str):
@@ -281,13 +298,17 @@ def generate_measures(
if study_name == selected_study_name:
study_definitions = [(study_name, suffix)]
break

for study_name, suffix in study_definitions:
_generate_measures(
output_dir,
study_name,
suffix,
skip_existing=skip_existing,
)
with log_execution_time(
logger, description="generate_measures (all input files)", study=study_name
):
_generate_measures(
output_dir,
study_name,
suffix,
skip_existing=skip_existing,
)


def _generate_measures(
@@ -303,6 +324,8 @@ def _generate_measures(
measures = load_study_definition(study_name, value="measures")
measure_outputs = defaultdict(list)
filename_re = re.compile(rf"^input{re.escape(suffix)}.+\.({EXTENSION_REGEX})$")

log_stats(logger, measures_count=len(measures))
for file in os.listdir(output_dir):
if not filename_re.match(file):
continue
@@ -311,26 +334,55 @@
continue
filepath = os.path.join(output_dir, file)
logger.info(f"Calculating measures for {filepath}")
patient_df = None
for measure in measures:
logger.info(f"Calculating {measure.id}")
output_file = f"{output_dir}/measure_{measure.id}_{date}.csv"
measure_outputs[measure.id].append(output_file)
if skip_existing and os.path.exists(output_file):
logger.info(f"Not generating pre-existing file {output_file}")
continue
# We do this lazily so that if all corresponding output files
# already exist we can avoid loading the patient data entirely
if patient_df is None:
logger.info(f"Loading patient data from {filepath}")
patient_df = _load_dataframe_for_measures(filepath, measures)
logger.info(patient_df.memory_usage())

measure_df = measure.calculate(patient_df, _report)
logger.info(f"Data size for measure {measure.id}:")
logger.info(measure_df.memory_usage())
measure_df.to_csv(output_file, index=False)
logger.info(f"Created measure output at {output_file}")
with log_execution_time(
logger,
description="generate_measures",
input_file=filepath,
date=date,
study=study_name,
):
patient_df = None
for measure in measures:
logger.info(f"Calculating {measure.id}")
output_file = f"{output_dir}/measure_{measure.id}_{date}.csv"
measure_outputs[measure.id].append(output_file)
if skip_existing and os.path.exists(output_file):
logger.info(f"Not generating pre-existing file {output_file}")
continue
# We do this lazily so that if all corresponding output files
# already exist we can avoid loading the patient data entirely
if patient_df is None:
logger.info(f"Loading patient data from {filepath}")
with log_execution_time(
logger,
description="Load patient dataframe for measures",
input_file=filepath,
date=date,
):
patient_df = _load_dataframe_for_measures(filepath, measures)
log_stats(
logger,
dataframe="patient_df",
measure_id=measure.id,
date=date,
memory=patient_df.memory_usage(deep=True).sum(),
)
with log_execution_time(
logger,
description="Calculate measure",
measure_id=measure.id,
date=date,
):
measure_df = measure.calculate(patient_df, _report)
log_stats(
logger,
dataframe="measure_df",
measure_id=measure.id,
date=date,
memory=measure_df.memory_usage(deep=True).sum(),
)
measure_df.to_csv(output_file, index=False)
logger.info(f"Created measure output at {output_file}")
if not measure_outputs:
logger.warn(
"No matching output files found. You may need to first run:\n"
41 changes: 32 additions & 9 deletions cohortextractor/emis_backend.py
@@ -9,6 +9,7 @@
from .csv_utils import is_csv_filename, write_rows_to_csv
from .date_expressions import TrinoDateFormatter
from .expressions import format_expression
from .log_utils import LoggingDatabaseConnection, log_execution_time, log_stats
from .pandas_utils import dataframe_from_rows, dataframe_to_file
from .trino_utils import trino_connection_from_url

@@ -99,10 +100,16 @@ def replace_ids_and_log(result):
# Special handling for CSV as we can stream this directly to disk
# without building a dataframe in memory
if is_csv_filename(filename):
write_rows_to_csv(results, filename)
with log_execution_time(
logger, description=f"write_rows_to_csv {filename}"
):
write_rows_to_csv(results, filename)
else:
df = dataframe_from_rows(self.covariate_definitions, results)
dataframe_to_file(df, filename)
with log_execution_time(
logger, description=f"Create df and write dataframe_to_file {filename}"
):
df = dataframe_from_rows(self.covariate_definitions, results)
dataframe_to_file(df, filename)

duplicates = total_rows - len(unique_ids)
if duplicates != 0:
@@ -247,6 +254,13 @@ def get_queries(self, covariate_definitions):
for sql_list in table_queries.values():
all_queries.extend(sql_list)
all_queries.append(joined_output_query)

log_stats(
logger,
output_column_count=len(output_columns),
table_count=len(table_queries),
table_joins_count=len(joins),
)
return all_queries

def get_column_expression(self, column_type, source, returning, date_format=None):
Expand Down Expand Up @@ -405,23 +419,30 @@ def execute_query(self):
for sql in queries:
table_name = re.search(r"CREATE TABLE IF NOT EXISTS (\w+)", sql).groups()[0]
logger.info(f"Running query for {table_name}")
cursor.execute(sql)
cursor.execute(sql, log_desc=f"Create table {table_name}")
if run_analyze:
cursor.execute(f"ANALYZE {table_name}")
cursor.execute(
f"ANALYZE {table_name}", log_desc=f"Analyze table {table_name}"
)

output_table = self.get_output_table_name(os.environ.get("TEMP_DATABASE_NAME"))
if output_table:
logger.info(f"Running final query and writing output to '{output_table}'")
sql = f"CREATE TABLE IF NOT EXISTS {output_table} AS {final_query}"
cursor.execute(sql)
cursor.execute(sql, log_desc="Create final query table")
logger.info(f"Downloading data from '{output_table}'")
cursor.execute(f"SELECT * FROM {output_table}")
cursor.execute(
f"SELECT * FROM {output_table}",
log_desc=f"Downloading data from '{output_table}'",
)
else:
logger.info(
"No TEMP_DATABASE_NAME defined in environment, downloading results "
"directly without writing to output table"
)
cursor.execute(final_query)
cursor.execute(
final_query, log_desc="Download results without writing to output table"
)
return cursor

def get_output_table_name(self, temporary_database):
@@ -1569,7 +1590,9 @@ def patients_with_death_recorded_in_cpns(
def get_db_connection(self):
if self._db_connection:
return self._db_connection
self._db_connection = trino_connection_from_url(self.database_url)
self._db_connection = LoggingDatabaseConnection(
logger, trino_connection_from_url(self.database_url)
)
return self._db_connection

def close(self):
83 changes: 83 additions & 0 deletions cohortextractor/log_utils.py
@@ -1,5 +1,8 @@
import logging.config
import os
from contextlib import contextmanager
from datetime import timedelta
from time import process_time

import structlog

@@ -60,3 +63,83 @@ def init_logging():
},
}
)


def log_stats(logger, **kwargs):
logger.info("cohortextractor-stats", **kwargs)


@contextmanager
def log_execution_time(logger, **log_kwargs):
sql = log_kwargs.pop("sql", None)
if sql:
        # Only log the first line of a multi-line SQL statement, to keep the
        # log entry short; "..." marks the truncation
        sql_lines = sql.split("\n")
        if len(sql_lines) > 1:
            sql = sql_lines[0] + "..."
        log_kwargs["sql"] = sql

start = process_time()
try:
yield
finally:
elapsed_time = process_time() - start
log_kwargs.update(
execution_time_secs=elapsed_time,
execution_time=str(timedelta(seconds=elapsed_time)),
)
log_stats(logger, **log_kwargs)


class BaseLoggingWrapper:
"""
Wraps a class instance and provides a logger instance as an attribute.
Subclasses can implement their own methods to override methods on the
wrapped instance and make use of the logger.
    Any attribute or method called on the wrapper uses its own implementation if
    one exists, otherwise falls through to the wrapped instance.
"""

def __init__(self, logger, wrapped_instance):
self.logger = logger
self.wrapped_instance = wrapped_instance

def __getattr__(self, attr):
if attr in dir(self):
return attr
return getattr(self.wrapped_instance, attr)


class LoggingCursor(BaseLoggingWrapper):
"""
    Provides a database cursor instance that will log the execution time of any
    `execute` call
"""

def __init__(self, logger, cursor):
super().__init__(logger, cursor)
self.cursor = cursor

def __iter__(self):
return self.cursor.__iter__()

def __next__(self):
return self.cursor.__next__()

def execute(self, query, log_desc=None):
with log_execution_time(self.logger, sql=query, description=log_desc):
self.cursor.execute(query)


class LoggingDatabaseConnection(BaseLoggingWrapper):
"""
Provides a database connection instance with a LoggingCursor
"""

def __init__(self, logger, database_connection):
super().__init__(logger, database_connection)
self.db_connection = database_connection

def cursor(self):
return LoggingCursor(self.logger, self.db_connection.cursor())
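
Taken together, LoggingDatabaseConnection and LoggingCursor reproduce the wrapping pattern used in emis_backend.py above: the raw Trino connection is wrapped once, and every execute call on a cursor obtained from it is timed and logged, tagged with the optional log_desc. A minimal sketch, assuming trino_connection_from_url from cohortextractor.trino_utils as in emis_backend.py, with a placeholder connection URL:

import structlog

from cohortextractor.log_utils import LoggingDatabaseConnection
from cohortextractor.trino_utils import trino_connection_from_url

logger = structlog.get_logger()

# Wrap the raw Trino connection so that every cursor.execute() is timed
database_url = "<trino-connection-url>"  # placeholder, not a real URL
connection = LoggingDatabaseConnection(logger, trino_connection_from_url(database_url))

cursor = connection.cursor()
# The SQL (truncated) and execution time are logged via log_execution_time
cursor.execute("SELECT 1", log_desc="Example query")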