# Patch summary (commentary only; everything from the first "diff --git"
# header onwards is the patch itself).
#
# ehrql/main.py
#   * get_dummy_data_class() now returns a (generator class, kwargs) pair
#     instead of a bare class:
#       - legacy config    -> DummyDataGenerator with "population_size" and
#                             "timeout" read from dummy_data_config
#       - otherwise        -> NextGenDummyDataGenerator with
#                             "configuration" set to dummy_data_config itself
#   * generate_dataset_with_dummy_data() and create_dummy_tables() unpack that
#     pair and call generator_class(variable_definitions, **kwargs) rather
#     than passing population_size/timeout explicitly.
#
# tests/acceptance/test_embedded_study.py
#   * test_create_dummy_tables is parametrised with the expected CSV header
#     per dataset definition, replacing the single hard-coded
#     "patient_id,date_of_birth" header assertion.
#
# NOTE(review): this patch arrived collapsed onto two physical lines. Line
# breaks and blank context lines below were restored from the hunk line
# counts; leading indentation inside hunks was inferred from the Python
# structure -- confirm against the original patch before applying.
diff --git a/ehrql/main.py b/ehrql/main.py
index 5df83fd25..95782094c 100644
--- a/ehrql/main.py
+++ b/ehrql/main.py
@@ -133,11 +133,8 @@ def generate_dataset_with_dummy_data(
         query_engine = LocalFileQueryEngine(dummy_tables_path)
         results = query_engine.get_results(variable_definitions)
     else:
-        generator = get_dummy_data_class(dummy_data_config)(
-            variable_definitions,
-            population_size=dummy_data_config.population_size,
-            timeout=dummy_data_config.timeout,
-        )
+        generator_class, kwargs = get_dummy_data_class(dummy_data_config)
+        generator = generator_class(variable_definitions, **kwargs)
         results = generator.get_results()
 
     log.info("Building dataset and writing results")
@@ -150,11 +147,8 @@ def create_dummy_tables(definition_file, dummy_tables_path, user_args, environ):
     variable_definitions, dummy_data_config = load_dataset_definition(
         definition_file, user_args, environ
     )
-    generator = get_dummy_data_class(dummy_data_config)(
-        variable_definitions,
-        population_size=dummy_data_config.population_size,
-        timeout=dummy_data_config.timeout,
-    )
+    generator_class, kwargs = get_dummy_data_class(dummy_data_config)
+    generator = generator_class(variable_definitions, **kwargs)
     table_data = generator.get_data()
 
     directory, extension = split_directory_and_extension(dummy_tables_path)
@@ -168,9 +162,14 @@ def create_dummy_tables(definition_file, dummy_tables_path, user_args, environ):
 
 def get_dummy_data_class(dummy_data_config):
     if dummy_data_config.legacy:
-        return DummyDataGenerator
+        kwargs = {
+            "population_size": dummy_data_config.population_size,
+            "timeout": dummy_data_config.timeout,
+        }
+        return DummyDataGenerator, kwargs
     else:
-        return NextGenDummyDataGenerator
+        kwargs = {"configuration": dummy_data_config}
+        return NextGenDummyDataGenerator, kwargs
 
 
 def dump_dataset_sql(
diff --git a/tests/acceptance/test_embedded_study.py b/tests/acceptance/test_embedded_study.py
index 8a6212d77..567d3711c 100644
--- a/tests/acceptance/test_embedded_study.py
+++ b/tests/acceptance/test_embedded_study.py
@@ -186,16 +186,20 @@ def test_generate_dummy_data_with_dummy_tables(study, tmp_path):
 
 
 @pytest.mark.parametrize(
-    "dataset_definition_fixture",
+    "dataset_definition_fixture,expected_columns",
     (
-        trivial_dataset_definition,
-        trivial_dataset_definition_legacy_dummy_data,
+        # this dataset definition includes date_of_death
+        # in the additional population constraints only
+        (trivial_dataset_definition, "patient_id,date_of_birth,date_of_death"),
+        (trivial_dataset_definition_legacy_dummy_data, "patient_id,date_of_birth"),
     ),
 )
-def test_create_dummy_tables(study, tmp_path, dataset_definition_fixture):
+def test_create_dummy_tables(
+    study, tmp_path, dataset_definition_fixture, expected_columns
+):
     dummy_tables_path = tmp_path / "subdir" / "dummy_data"
     study.setup_from_string(dataset_definition_fixture)
     study.create_dummy_tables(dummy_tables_path)
     lines = (dummy_tables_path / "patients.csv").read_text().splitlines()
-    assert lines[0] == "patient_id,date_of_birth"
+    assert lines[0] == expected_columns
     assert len(lines) == 11  # 1 header, 10 rows