Skip to content

Commit

Permalink
Move changes to dummy_data into dummy_data_nextgen
Browse files Browse the repository at this point in the history
  • Loading branch information
DRMacIver committed Oct 22, 2024
1 parent 88572ed commit 3bfea65
Show file tree
Hide file tree
Showing 6 changed files with 270 additions and 270 deletions.
78 changes: 3 additions & 75 deletions ehrql/dummy_data/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,47 +173,14 @@ def get_patient_data(self, patient_id, table_names):
]
return data

def get_patient_column(self, column_name):
for table_name in self.query_info.population_table_names:
try:
return self.query_info.tables[table_name].columns[column_name]
except KeyError:
pass

def generate_patient_facts(self, patient_id):
# Seed the random generator using the patient_id so we always generate the same
# data for the same patient
self.rnd.seed(f"{self.random_seed}:{patient_id}")
# TODO: We could obviously generate more realistic age distributions than this

while True:
# Retry until we have a date of birth and date of death that are
# within reasonable ranges
dob_column = self.get_patient_column("date_of_birth")
if dob_column is not None and dob_column.get_constraint(
Constraint.GeneralRange
):
self.events_start = self.today - timedelta(days=120 * 365)
self.events_end = self.today
date_of_birth = self.get_random_value(dob_column)
else:
date_of_birth = self.today - timedelta(
days=self.rnd.randrange(0, 120 * 365)
)

dod_column = self.get_patient_column("date_of_death")
if dod_column is not None and dod_column.get_constraint(
Constraint.GeneralRange
):
date_of_death = self.get_random_value(dod_column)
else:
age_days = self.rnd.randrange(105 * 365)
date_of_death = date_of_birth + timedelta(days=age_days)

if date_of_death >= date_of_birth and (
date_of_death - date_of_birth < timedelta(105 * 365)
):
break
date_of_birth = self.today - timedelta(days=self.rnd.randrange(0, 120 * 365))
age_days = self.rnd.randrange(105 * 365)
date_of_death = date_of_birth + timedelta(days=age_days)

self.date_of_birth = date_of_birth
self.date_of_death = date_of_death if date_of_death < self.today else None
Expand Down Expand Up @@ -268,45 +235,6 @@ def get_random_value(self, column_info):
range_constraint.maximum + 1,
range_constraint.step,
)
elif (column_info.type is date) and (
date_range_constraint := column_info.get_constraint(Constraint.GeneralRange)
):
if date_range_constraint.maximum is not None:
maximum = date_range_constraint.maximum
else:
maximum = self.today

if not date_range_constraint.includes_maximum:
maximum -= timedelta(days=1)
if date_range_constraint.minimum is not None:
minimum = date_range_constraint.minimum
if not date_range_constraint.includes_minimum:
minimum += timedelta(days=1)
# TODO: Currently this code only runs when the column is date_of_birth
# so condition is always hit. Remove this pragma when that stops being
# the case.
if column_info.get_constraint(
Constraint.FirstOfMonth
): # pragma: no branch
if minimum.month == 12:
minimum = minimum.replace(year=minimum.year + 1, month=1, day=1)
else:
minimum = minimum.replace(month=minimum.month + 1, day=1)
else:
minimum = (maximum - timedelta(days=100 * 365)).replace(day=1)

assert minimum <= maximum

days = (maximum - minimum).days
result = minimum + timedelta(days=random.randint(0, days))
# TODO: Currently this code only runs when the column is date_of_birth
# so condition is always hit. Remove this pragma when that stops being
# the case.
if column_info.get_constraint(Constraint.FirstOfMonth): # pragma: no branch
assert minimum.day == 1
result = result.replace(day=1)
assert minimum <= result <= maximum
return result
elif column_info.values_used:
if self.rnd.randint(0, len(column_info.values_used)) != 0:
return self.rnd.choice(column_info.values_used)
Expand Down
183 changes: 4 additions & 179 deletions ehrql/dummy_data/query_info.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import dataclasses
from collections import defaultdict
from datetime import date, timedelta
from functools import cached_property, lru_cache
from functools import cached_property

from ehrql.query_model.introspection import all_unique_nodes, get_table_nodes
from ehrql.query_model.nodes import (
Expand All @@ -15,7 +14,6 @@
Value,
get_root_frame,
)
from ehrql.query_model.table_schema import Constraint


@dataclasses.dataclass
Expand All @@ -30,17 +28,11 @@ class ColumnInfo:
_values_used: set = dataclasses.field(default_factory=set)

@classmethod
def from_column(cls, name, column, extra_constraints=()):
def from_column(cls, name, column):
type_ = column.type_
if hasattr(type_, "_primitive_type"):
type_ = type_._primitive_type()
return cls(
name,
type_,
constraints=normalize_constraints(
tuple(column.constraints) + tuple(extra_constraints)
),
)
return cls(name, type_, constraints=tuple(column.constraints))

def __post_init__(self):
self._constraints_by_type = {type(c): c for c in self.constraints}
Expand Down Expand Up @@ -109,10 +101,6 @@ def from_variable_definitions(cls, variable_definitions):
all_nodes = all_unique_nodes(*variable_definitions.values())
by_type = get_nodes_by_type(all_nodes)

extra_constraints = query_to_column_constraints(
variable_definitions["population"]
)

tables = {
# Create a TableInfo object …
table.name: TableInfo.from_table(table)
Expand Down Expand Up @@ -140,9 +128,7 @@ def from_variable_definitions(cls, variable_definitions):
if column_info is None:
# … insert a ColumnInfo object into the appropriate table
column_info = ColumnInfo.from_column(
name,
table.schema.get_column(name),
extra_constraints=extra_constraints.get(column, ()),
name, table.schema.get_column(name)
)
table_info.columns[name] = column_info
# Record the ColumnInfo object associated with each SelectColumn node
Expand Down Expand Up @@ -190,164 +176,3 @@ def get_nodes_by_type(nodes):

def sort_by_name(iterable):
return sorted(iterable, key=lambda i: i.name)


@lru_cache
def query_to_column_constraints(query):
"""Converts a query (typically a population definition) into
constraints that would have to be applied to a record in order
to satisfy it."""
match query:
case Function.And(lhs=lhs, rhs=rhs):
left = query_to_column_constraints(lhs)
right = query_to_column_constraints(rhs)
keys = set(left) | set(right)
return {k: left.get(k, []) + right.get(k, []) for k in keys}
case Function.Or(lhs=lhs, rhs=rhs):
left = query_to_column_constraints(lhs)
right = query_to_column_constraints(rhs)
result = {}
for k, v in left.items():
try:
result[k] = list(set(v) & set(right[k]))
except KeyError:
pass
for k, v in list(result.items()):
if not v:
del result[k]
return result
case Function.EQ(
lhs=SelectColumn() as lhs,
rhs=Value(value=value),
):
return {lhs: [Constraint.Categorical(values=(value,))]}
case Function.EQ(
lhs=Function.YearFromDate(source=SelectColumn() as column),
rhs=Value(value=year),
):
return {
column: [
Constraint.GeneralRange(
minimum=date(year, 1, 1),
maximum=date(year, 12, 31),
)
]
}
case Function.In(
lhs=SelectColumn() as lhs,
rhs=Value(value=values),
):
return {lhs: [Constraint.Categorical(values=values)]}
case Function.GE(
lhs=Function.DateDifferenceInYears(
lhs=Value(value=reference_date), rhs=column
),
rhs=Value(value=difference),
):
return {
column: [
Constraint.GeneralRange(
maximum=reference_date - timedelta(days=365 * difference)
)
]
}
case Function.LE(
lhs=Function.DateDifferenceInYears(
lhs=Value(value=reference_date), rhs=column
),
rhs=Value(value=difference),
):
return {
column: [
Constraint.GeneralRange(
minimum=reference_date.replace(
year=reference_date.year - difference
)
)
]
}
case Function.LT(
lhs=Function.DateAddYears(
lhs=SelectColumn() as column,
rhs=Value(value=difference),
),
rhs=Value(value=reference_date),
):
return {
column: [
Constraint.GeneralRange(
maximum=reference_date.replace(
year=reference_date.year - difference
),
includes_maximum=False,
)
]
}
case Function.GT(lhs=SelectColumn() as column, rhs=Value(value=min_value)):
return {
column: [
Constraint.GeneralRange(minimum=min_value, includes_minimum=False)
]
}
case Function.GE(lhs=SelectColumn() as column, rhs=Value(value=min_value)):
return {
column: [
Constraint.GeneralRange(minimum=min_value, includes_minimum=True)
]
}
case Function.LT(lhs=SelectColumn() as column, rhs=Value(value=max_value)):
return {
column: [
Constraint.GeneralRange(maximum=max_value, includes_maximum=False)
]
}
case Function.LE(lhs=SelectColumn() as column, rhs=Value(value=max_value)):
return {
column: [
Constraint.GeneralRange(maximum=max_value, includes_maximum=True)
]
}
case Function.IsNull(source=SelectColumn() as column):
return {column: [Constraint.NotNull()]}

return {}


def normalize_constraints(constraints):
group_by_type = defaultdict(list)
for constraint in constraints:
group_by_type[type(constraint)].append(constraint)
if len(group_by_type[Constraint.Categorical]) > 1:
constraint, *rest = group_by_type[Constraint.Categorical]
for more in rest:
constraint = Constraint.Categorical(
values=set(constraint.values) & set(more.values)
)
group_by_type[Constraint.Categorical] = [constraint]
if len(ranges := group_by_type[Constraint.GeneralRange]) > 1:
minimum = None
maximum = None
for r in ranges:
if minimum is None:
minimum = r.minimum
elif r.minimum is not None:
minimum = max(minimum, r.minimum)
if maximum is None:
maximum = r.maximum
elif r.maximum is not None:
maximum = min(maximum, r.maximum)

includes_minimum = minimum is None or all(r.validate(minimum) for r in ranges)
includes_maximum = maximum is None or all(r.validate(maximum) for r in ranges)
group_by_type[Constraint.GeneralRange] = [
Constraint.GeneralRange(
minimum=minimum,
maximum=maximum,
includes_maximum=includes_maximum,
includes_minimum=includes_minimum,
)
]

return tuple(
[constraint for group in group_by_type.values() for constraint in group]
)
Loading

0 comments on commit 3bfea65

Please sign in to comment.