From 3bfea65de1532d33ef485df9c363467771ae6c05 Mon Sep 17 00:00:00 2001 From: "David R. MacIver" Date: Tue, 22 Oct 2024 17:52:07 +0100 Subject: [PATCH] Move changes to dummy_data into dummy_data_nextgen --- ehrql/dummy_data/generator.py | 78 +------- ehrql/dummy_data/query_info.py | 183 +----------------- ehrql/dummy_data_nextgen/generator.py | 78 +++++++- ehrql/dummy_data_nextgen/query_info.py | 183 +++++++++++++++++- .../test_query_to_constraints.py | 2 +- .../test_specific_datasets.py | 16 +- 6 files changed, 270 insertions(+), 270 deletions(-) rename tests/unit/{dummy_data => dummy_data_nextgen}/test_query_to_constraints.py (98%) rename tests/unit/{dummy_data => dummy_data_nextgen}/test_specific_datasets.py (95%) diff --git a/ehrql/dummy_data/generator.py b/ehrql/dummy_data/generator.py index 4e01c6008..7c6bcbd4f 100644 --- a/ehrql/dummy_data/generator.py +++ b/ehrql/dummy_data/generator.py @@ -173,47 +173,14 @@ def get_patient_data(self, patient_id, table_names): ] return data - def get_patient_column(self, column_name): - for table_name in self.query_info.population_table_names: - try: - return self.query_info.tables[table_name].columns[column_name] - except KeyError: - pass - def generate_patient_facts(self, patient_id): # Seed the random generator using the patient_id so we always generate the same # data for the same patient self.rnd.seed(f"{self.random_seed}:{patient_id}") # TODO: We could obviously generate more realistic age distributions than this - - while True: - # Retry until we have a date of birth and date of death that are - # within reasonable ranges - dob_column = self.get_patient_column("date_of_birth") - if dob_column is not None and dob_column.get_constraint( - Constraint.GeneralRange - ): - self.events_start = self.today - timedelta(days=120 * 365) - self.events_end = self.today - date_of_birth = self.get_random_value(dob_column) - else: - date_of_birth = self.today - timedelta( - days=self.rnd.randrange(0, 120 * 365) - ) - - dod_column = self.get_patient_column("date_of_death") - if dod_column is not None and dod_column.get_constraint( - Constraint.GeneralRange - ): - date_of_death = self.get_random_value(dod_column) - else: - age_days = self.rnd.randrange(105 * 365) - date_of_death = date_of_birth + timedelta(days=age_days) - - if date_of_death >= date_of_birth and ( - date_of_death - date_of_birth < timedelta(105 * 365) - ): - break + date_of_birth = self.today - timedelta(days=self.rnd.randrange(0, 120 * 365)) + age_days = self.rnd.randrange(105 * 365) + date_of_death = date_of_birth + timedelta(days=age_days) self.date_of_birth = date_of_birth self.date_of_death = date_of_death if date_of_death < self.today else None @@ -268,45 +235,6 @@ def get_random_value(self, column_info): range_constraint.maximum + 1, range_constraint.step, ) - elif (column_info.type is date) and ( - date_range_constraint := column_info.get_constraint(Constraint.GeneralRange) - ): - if date_range_constraint.maximum is not None: - maximum = date_range_constraint.maximum - else: - maximum = self.today - - if not date_range_constraint.includes_maximum: - maximum -= timedelta(days=1) - if date_range_constraint.minimum is not None: - minimum = date_range_constraint.minimum - if not date_range_constraint.includes_minimum: - minimum += timedelta(days=1) - # TODO: Currently this code only runs when the column is date_of_birth - # so condition is always hit. Remove this pragma when that stops being - # the case. - if column_info.get_constraint( - Constraint.FirstOfMonth - ): # pragma: no branch - if minimum.month == 12: - minimum = minimum.replace(year=minimum.year + 1, month=1, day=1) - else: - minimum = minimum.replace(month=minimum.month + 1, day=1) - else: - minimum = (maximum - timedelta(days=100 * 365)).replace(day=1) - - assert minimum <= maximum - - days = (maximum - minimum).days - result = minimum + timedelta(days=random.randint(0, days)) - # TODO: Currently this code only runs when the column is date_of_birth - # so condition is always hit. Remove this pragma when that stops being - # the case. - if column_info.get_constraint(Constraint.FirstOfMonth): # pragma: no branch - assert minimum.day == 1 - result = result.replace(day=1) - assert minimum <= result <= maximum - return result elif column_info.values_used: if self.rnd.randint(0, len(column_info.values_used)) != 0: return self.rnd.choice(column_info.values_used) diff --git a/ehrql/dummy_data/query_info.py b/ehrql/dummy_data/query_info.py index 159b9022c..455d0f2f3 100644 --- a/ehrql/dummy_data/query_info.py +++ b/ehrql/dummy_data/query_info.py @@ -1,7 +1,6 @@ import dataclasses from collections import defaultdict -from datetime import date, timedelta -from functools import cached_property, lru_cache +from functools import cached_property from ehrql.query_model.introspection import all_unique_nodes, get_table_nodes from ehrql.query_model.nodes import ( @@ -15,7 +14,6 @@ Value, get_root_frame, ) -from ehrql.query_model.table_schema import Constraint @dataclasses.dataclass @@ -30,17 +28,11 @@ class ColumnInfo: _values_used: set = dataclasses.field(default_factory=set) @classmethod - def from_column(cls, name, column, extra_constraints=()): + def from_column(cls, name, column): type_ = column.type_ if hasattr(type_, "_primitive_type"): type_ = type_._primitive_type() - return cls( - name, - type_, - constraints=normalize_constraints( - tuple(column.constraints) + tuple(extra_constraints) - ), - ) + return cls(name, type_, constraints=tuple(column.constraints)) def __post_init__(self): self._constraints_by_type = {type(c): c for c in self.constraints} @@ -109,10 +101,6 @@ def from_variable_definitions(cls, variable_definitions): all_nodes = all_unique_nodes(*variable_definitions.values()) by_type = get_nodes_by_type(all_nodes) - extra_constraints = query_to_column_constraints( - variable_definitions["population"] - ) - tables = { # Create a TableInfo object … table.name: TableInfo.from_table(table) @@ -140,9 +128,7 @@ def from_variable_definitions(cls, variable_definitions): if column_info is None: # … insert a ColumnInfo object into the appropriate table column_info = ColumnInfo.from_column( - name, - table.schema.get_column(name), - extra_constraints=extra_constraints.get(column, ()), + name, table.schema.get_column(name) ) table_info.columns[name] = column_info # Record the ColumnInfo object associated with each SelectColumn node @@ -190,164 +176,3 @@ def get_nodes_by_type(nodes): def sort_by_name(iterable): return sorted(iterable, key=lambda i: i.name) - - -@lru_cache -def query_to_column_constraints(query): - """Converts a query (typically a population definition) into - constraints that would have to be applied to a record in order - to satisfy it.""" - match query: - case Function.And(lhs=lhs, rhs=rhs): - left = query_to_column_constraints(lhs) - right = query_to_column_constraints(rhs) - keys = set(left) | set(right) - return {k: left.get(k, []) + right.get(k, []) for k in keys} - case Function.Or(lhs=lhs, rhs=rhs): - left = query_to_column_constraints(lhs) - right = query_to_column_constraints(rhs) - result = {} - for k, v in left.items(): - try: - result[k] = list(set(v) & set(right[k])) - except KeyError: - pass - for k, v in list(result.items()): - if not v: - del result[k] - return result - case Function.EQ( - lhs=SelectColumn() as lhs, - rhs=Value(value=value), - ): - return {lhs: [Constraint.Categorical(values=(value,))]} - case Function.EQ( - lhs=Function.YearFromDate(source=SelectColumn() as column), - rhs=Value(value=year), - ): - return { - column: [ - Constraint.GeneralRange( - minimum=date(year, 1, 1), - maximum=date(year, 12, 31), - ) - ] - } - case Function.In( - lhs=SelectColumn() as lhs, - rhs=Value(value=values), - ): - return {lhs: [Constraint.Categorical(values=values)]} - case Function.GE( - lhs=Function.DateDifferenceInYears( - lhs=Value(value=reference_date), rhs=column - ), - rhs=Value(value=difference), - ): - return { - column: [ - Constraint.GeneralRange( - maximum=reference_date - timedelta(days=365 * difference) - ) - ] - } - case Function.LE( - lhs=Function.DateDifferenceInYears( - lhs=Value(value=reference_date), rhs=column - ), - rhs=Value(value=difference), - ): - return { - column: [ - Constraint.GeneralRange( - minimum=reference_date.replace( - year=reference_date.year - difference - ) - ) - ] - } - case Function.LT( - lhs=Function.DateAddYears( - lhs=SelectColumn() as column, - rhs=Value(value=difference), - ), - rhs=Value(value=reference_date), - ): - return { - column: [ - Constraint.GeneralRange( - maximum=reference_date.replace( - year=reference_date.year - difference - ), - includes_maximum=False, - ) - ] - } - case Function.GT(lhs=SelectColumn() as column, rhs=Value(value=min_value)): - return { - column: [ - Constraint.GeneralRange(minimum=min_value, includes_minimum=False) - ] - } - case Function.GE(lhs=SelectColumn() as column, rhs=Value(value=min_value)): - return { - column: [ - Constraint.GeneralRange(minimum=min_value, includes_minimum=True) - ] - } - case Function.LT(lhs=SelectColumn() as column, rhs=Value(value=max_value)): - return { - column: [ - Constraint.GeneralRange(maximum=max_value, includes_maximum=False) - ] - } - case Function.LE(lhs=SelectColumn() as column, rhs=Value(value=max_value)): - return { - column: [ - Constraint.GeneralRange(maximum=max_value, includes_maximum=True) - ] - } - case Function.IsNull(source=SelectColumn() as column): - return {column: [Constraint.NotNull()]} - - return {} - - -def normalize_constraints(constraints): - group_by_type = defaultdict(list) - for constraint in constraints: - group_by_type[type(constraint)].append(constraint) - if len(group_by_type[Constraint.Categorical]) > 1: - constraint, *rest = group_by_type[Constraint.Categorical] - for more in rest: - constraint = Constraint.Categorical( - values=set(constraint.values) & set(more.values) - ) - group_by_type[Constraint.Categorical] = [constraint] - if len(ranges := group_by_type[Constraint.GeneralRange]) > 1: - minimum = None - maximum = None - for r in ranges: - if minimum is None: - minimum = r.minimum - elif r.minimum is not None: - minimum = max(minimum, r.minimum) - if maximum is None: - maximum = r.maximum - elif r.maximum is not None: - maximum = min(maximum, r.maximum) - - includes_minimum = minimum is None or all(r.validate(minimum) for r in ranges) - includes_maximum = maximum is None or all(r.validate(maximum) for r in ranges) - group_by_type[Constraint.GeneralRange] = [ - Constraint.GeneralRange( - minimum=minimum, - maximum=maximum, - includes_maximum=includes_maximum, - includes_minimum=includes_minimum, - ) - ] - - return tuple( - [constraint for group in group_by_type.values() for constraint in group] - ) diff --git a/ehrql/dummy_data_nextgen/generator.py b/ehrql/dummy_data_nextgen/generator.py index 324f1f49e..fe756d1ad 100644 --- a/ehrql/dummy_data_nextgen/generator.py +++ b/ehrql/dummy_data_nextgen/generator.py @@ -173,14 +173,47 @@ def get_patient_data(self, patient_id, table_names): ] return data + def get_patient_column(self, column_name): + for table_name in self.query_info.population_table_names: + try: + return self.query_info.tables[table_name].columns[column_name] + except KeyError: + pass + def generate_patient_facts(self, patient_id): # Seed the random generator using the patient_id so we always generate the same # data for the same patient self.rnd.seed(f"{self.random_seed}:{patient_id}") # TODO: We could obviously generate more realistic age distributions than this - date_of_birth = self.today - timedelta(days=self.rnd.randrange(0, 120 * 365)) - age_days = self.rnd.randrange(105 * 365) - date_of_death = date_of_birth + timedelta(days=age_days) + + while True: + # Retry until we have a date of birth and date of death that are + # within reasonable ranges + dob_column = self.get_patient_column("date_of_birth") + if dob_column is not None and dob_column.get_constraint( + Constraint.GeneralRange + ): + self.events_start = self.today - timedelta(days=120 * 365) + self.events_end = self.today + date_of_birth = self.get_random_value(dob_column) + else: + date_of_birth = self.today - timedelta( + days=self.rnd.randrange(0, 120 * 365) + ) + + dod_column = self.get_patient_column("date_of_death") + if dod_column is not None and dod_column.get_constraint( + Constraint.GeneralRange + ): + date_of_death = self.get_random_value(dod_column) + else: + age_days = self.rnd.randrange(105 * 365) + date_of_death = date_of_birth + timedelta(days=age_days) + + if date_of_death >= date_of_birth and ( + date_of_death - date_of_birth < timedelta(105 * 365) + ): + break self.date_of_birth = date_of_birth self.date_of_death = date_of_death if date_of_death < self.today else None @@ -235,6 +268,45 @@ def get_random_value(self, column_info): range_constraint.maximum + 1, range_constraint.step, ) + elif (column_info.type is date) and ( + date_range_constraint := column_info.get_constraint(Constraint.GeneralRange) + ): + if date_range_constraint.maximum is not None: + maximum = date_range_constraint.maximum + else: + maximum = self.today + + if not date_range_constraint.includes_maximum: + maximum -= timedelta(days=1) + if date_range_constraint.minimum is not None: + minimum = date_range_constraint.minimum + if not date_range_constraint.includes_minimum: + minimum += timedelta(days=1) + # TODO: Currently this code only runs when the column is date_of_birth + # so condition is always hit. Remove this pragma when that stops being + # the case. + if column_info.get_constraint( + Constraint.FirstOfMonth + ): # pragma: no branch + if minimum.month == 12: + minimum = minimum.replace(year=minimum.year + 1, month=1, day=1) + else: + minimum = minimum.replace(month=minimum.month + 1, day=1) + else: + minimum = (maximum - timedelta(days=100 * 365)).replace(day=1) + + assert minimum <= maximum + + days = (maximum - minimum).days + result = minimum + timedelta(days=random.randint(0, days)) + # TODO: Currently this code only runs when the column is date_of_birth + # so condition is always hit. Remove this pragma when that stops being + # the case. + if column_info.get_constraint(Constraint.FirstOfMonth): # pragma: no branch + assert minimum.day == 1 + result = result.replace(day=1) + assert minimum <= result <= maximum + return result elif column_info.values_used: if self.rnd.randint(0, len(column_info.values_used)) != 0: return self.rnd.choice(column_info.values_used) diff --git a/ehrql/dummy_data_nextgen/query_info.py b/ehrql/dummy_data_nextgen/query_info.py index 455d0f2f3..159b9022c 100644 --- a/ehrql/dummy_data_nextgen/query_info.py +++ b/ehrql/dummy_data_nextgen/query_info.py @@ -1,6 +1,7 @@ import dataclasses from collections import defaultdict -from functools import cached_property +from datetime import date, timedelta +from functools import cached_property, lru_cache from ehrql.query_model.introspection import all_unique_nodes, get_table_nodes from ehrql.query_model.nodes import ( @@ -14,6 +15,7 @@ Value, get_root_frame, ) +from ehrql.query_model.table_schema import Constraint @dataclasses.dataclass @@ -28,11 +30,17 @@ class ColumnInfo: _values_used: set = dataclasses.field(default_factory=set) @classmethod - def from_column(cls, name, column): + def from_column(cls, name, column, extra_constraints=()): type_ = column.type_ if hasattr(type_, "_primitive_type"): type_ = type_._primitive_type() - return cls(name, type_, constraints=tuple(column.constraints)) + return cls( + name, + type_, + constraints=normalize_constraints( + tuple(column.constraints) + tuple(extra_constraints) + ), + ) def __post_init__(self): self._constraints_by_type = {type(c): c for c in self.constraints} @@ -101,6 +109,10 @@ def from_variable_definitions(cls, variable_definitions): all_nodes = all_unique_nodes(*variable_definitions.values()) by_type = get_nodes_by_type(all_nodes) + extra_constraints = query_to_column_constraints( + variable_definitions["population"] + ) + tables = { # Create a TableInfo object … table.name: TableInfo.from_table(table) @@ -128,7 +140,9 @@ def from_variable_definitions(cls, variable_definitions): if column_info is None: # … insert a ColumnInfo object into the appropriate table column_info = ColumnInfo.from_column( - name, table.schema.get_column(name) + name, + table.schema.get_column(name), + extra_constraints=extra_constraints.get(column, ()), ) table_info.columns[name] = column_info # Record the ColumnInfo object associated with each SelectColumn node @@ -176,3 +190,164 @@ def get_nodes_by_type(nodes): def sort_by_name(iterable): return sorted(iterable, key=lambda i: i.name) + + +@lru_cache +def query_to_column_constraints(query): + """Converts a query (typically a population definition) into + constraints that would have to be applied to a record in order + to satisfy it.""" + match query: + case Function.And(lhs=lhs, rhs=rhs): + left = query_to_column_constraints(lhs) + right = query_to_column_constraints(rhs) + keys = set(left) | set(right) + return {k: left.get(k, []) + right.get(k, []) for k in keys} + case Function.Or(lhs=lhs, rhs=rhs): + left = query_to_column_constraints(lhs) + right = query_to_column_constraints(rhs) + result = {} + for k, v in left.items(): + try: + result[k] = list(set(v) & set(right[k])) + except KeyError: + pass + for k, v in list(result.items()): + if not v: + del result[k] + return result + case Function.EQ( + lhs=SelectColumn() as lhs, + rhs=Value(value=value), + ): + return {lhs: [Constraint.Categorical(values=(value,))]} + case Function.EQ( + lhs=Function.YearFromDate(source=SelectColumn() as column), + rhs=Value(value=year), + ): + return { + column: [ + Constraint.GeneralRange( + minimum=date(year, 1, 1), + maximum=date(year, 12, 31), + ) + ] + } + case Function.In( + lhs=SelectColumn() as lhs, + rhs=Value(value=values), + ): + return {lhs: [Constraint.Categorical(values=values)]} + case Function.GE( + lhs=Function.DateDifferenceInYears( + lhs=Value(value=reference_date), rhs=column + ), + rhs=Value(value=difference), + ): + return { + column: [ + Constraint.GeneralRange( + maximum=reference_date - timedelta(days=365 * difference) + ) + ] + } + case Function.LE( + lhs=Function.DateDifferenceInYears( + lhs=Value(value=reference_date), rhs=column + ), + rhs=Value(value=difference), + ): + return { + column: [ + Constraint.GeneralRange( + minimum=reference_date.replace( + year=reference_date.year - difference + ) + ) + ] + } + case Function.LT( + lhs=Function.DateAddYears( + lhs=SelectColumn() as column, + rhs=Value(value=difference), + ), + rhs=Value(value=reference_date), + ): + return { + column: [ + Constraint.GeneralRange( + maximum=reference_date.replace( + year=reference_date.year - difference + ), + includes_maximum=False, + ) + ] + } + case Function.GT(lhs=SelectColumn() as column, rhs=Value(value=min_value)): + return { + column: [ + Constraint.GeneralRange(minimum=min_value, includes_minimum=False) + ] + } + case Function.GE(lhs=SelectColumn() as column, rhs=Value(value=min_value)): + return { + column: [ + Constraint.GeneralRange(minimum=min_value, includes_minimum=True) + ] + } + case Function.LT(lhs=SelectColumn() as column, rhs=Value(value=max_value)): + return { + column: [ + Constraint.GeneralRange(maximum=max_value, includes_maximum=False) + ] + } + case Function.LE(lhs=SelectColumn() as column, rhs=Value(value=max_value)): + return { + column: [ + Constraint.GeneralRange(maximum=max_value, includes_maximum=True) + ] + } + case Function.IsNull(source=SelectColumn() as column): + return {column: [Constraint.NotNull()]} + + return {} + + +def normalize_constraints(constraints): + group_by_type = defaultdict(list) + for constraint in constraints: + group_by_type[type(constraint)].append(constraint) + if len(group_by_type[Constraint.Categorical]) > 1: + constraint, *rest = group_by_type[Constraint.Categorical] + for more in rest: + constraint = Constraint.Categorical( + values=set(constraint.values) & set(more.values) + ) + group_by_type[Constraint.Categorical] = [constraint] + if len(ranges := group_by_type[Constraint.GeneralRange]) > 1: + minimum = None + maximum = None + for r in ranges: + if minimum is None: + minimum = r.minimum + elif r.minimum is not None: + minimum = max(minimum, r.minimum) + if maximum is None: + maximum = r.maximum + elif r.maximum is not None: + maximum = min(maximum, r.maximum) + + includes_minimum = minimum is None or all(r.validate(minimum) for r in ranges) + includes_maximum = maximum is None or all(r.validate(maximum) for r in ranges) + group_by_type[Constraint.GeneralRange] = [ + Constraint.GeneralRange( + minimum=minimum, + maximum=maximum, + includes_maximum=includes_maximum, + includes_minimum=includes_minimum, + ) + ] + + return tuple( + [constraint for group in group_by_type.values() for constraint in group] + ) diff --git a/tests/unit/dummy_data/test_query_to_constraints.py b/tests/unit/dummy_data_nextgen/test_query_to_constraints.py similarity index 98% rename from tests/unit/dummy_data/test_query_to_constraints.py rename to tests/unit/dummy_data_nextgen/test_query_to_constraints.py index 568fd8a75..e69e2a84c 100644 --- a/tests/unit/dummy_data/test_query_to_constraints.py +++ b/tests/unit/dummy_data_nextgen/test_query_to_constraints.py @@ -1,7 +1,7 @@ from datetime import date from ehrql import create_dataset, years -from ehrql.dummy_data.query_info import ( +from ehrql.dummy_data_nextgen.query_info import ( normalize_constraints, query_to_column_constraints, ) diff --git a/tests/unit/dummy_data/test_specific_datasets.py b/tests/unit/dummy_data_nextgen/test_specific_datasets.py similarity index 95% rename from tests/unit/dummy_data/test_specific_datasets.py rename to tests/unit/dummy_data_nextgen/test_specific_datasets.py index a98fd17f7..d9a4a7a49 100644 --- a/tests/unit/dummy_data/test_specific_datasets.py +++ b/tests/unit/dummy_data_nextgen/test_specific_datasets.py @@ -6,13 +6,13 @@ from hypothesis import strategies as st from ehrql import create_dataset, years -from ehrql.dummy_data.generator import DummyDataGenerator +from ehrql.dummy_data_nextgen.generator import DummyDataGenerator from ehrql.query_language import compile from ehrql.tables.core import patients @pytest.mark.parametrize("sex", ["male", "female", "intersex"]) -@mock.patch("ehrql.dummy_data.generator.time") +@mock.patch("ehrql.dummy_data_nextgen.generator.time") def test_can_generate_single_sex_data_in_one_shot(patched_time, sex): dataset = create_dataset() @@ -37,7 +37,7 @@ def test_can_generate_single_sex_data_in_one_shot(patched_time, sex): assert len(data_for_table) == target_size -@mock.patch("ehrql.dummy_data.generator.time") +@mock.patch("ehrql.dummy_data_nextgen.generator.time") def test_can_generate_patients_from_a_specific_year(patched_time): dataset = create_dataset() @@ -62,7 +62,7 @@ def test_can_generate_patients_from_a_specific_year(patched_time): assert len(data_for_table) == target_size -@mock.patch("ehrql.dummy_data.generator.time") +@mock.patch("ehrql.dummy_data_nextgen.generator.time") def test_can_combine_constraints_on_generated_data(patched_time): dataset = create_dataset() @@ -89,7 +89,7 @@ def test_can_combine_constraints_on_generated_data(patched_time): assert len(data_for_table) == target_size -@mock.patch("ehrql.dummy_data.generator.time") +@mock.patch("ehrql.dummy_data_nextgen.generator.time") def test_will_satisfy_constraints_on_both_sides_of_an_or(patched_time): dataset = create_dataset() @@ -117,7 +117,7 @@ def test_will_satisfy_constraints_on_both_sides_of_an_or(patched_time): assert len(data_for_table) > 0 -@mock.patch("ehrql.dummy_data.generator.time") +@mock.patch("ehrql.dummy_data_nextgen.generator.time") def test_basic_patient_constraints_age_and_sex(patched_time): index_date = "2023-10-01" @@ -197,7 +197,7 @@ def birthday_range_query(draw): @example(query=patients.date_of_birth < date(1900, 12, 31), target_size=1000) @example(query=patients.date_of_birth >= date(1900, 1, 2), target_size=1000) @given(query=birthday_range_query(), target_size=st.integers(1, 1000)) -@mock.patch("ehrql.dummy_data.generator.time") +@mock.patch("ehrql.dummy_data_nextgen.generator.time") def test_combined_age_range_in_one_shot(patched_time, query, target_size): dataset = create_dataset() @@ -222,7 +222,7 @@ def test_combined_age_range_in_one_shot(patched_time, query, target_size): assert len(data_for_table) == target_size -@mock.patch("ehrql.dummy_data.generator.time") +@mock.patch("ehrql.dummy_data_nextgen.generator.time") def test_date_arithmetic_comparison(patched_time): dataset = create_dataset()