Skip to content

Commit

Permalink
fix: Pass appropriate kwargs to dummy data generators
Browse files Browse the repository at this point in the history
  • Loading branch information
rebkwok committed Dec 19, 2024
1 parent 325a576 commit 55c3f32
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 17 deletions.
23 changes: 11 additions & 12 deletions ehrql/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,11 +133,8 @@ def generate_dataset_with_dummy_data(
query_engine = LocalFileQueryEngine(dummy_tables_path)
results = query_engine.get_results(variable_definitions)
else:
- generator = get_dummy_data_class(dummy_data_config)(
- variable_definitions,
- population_size=dummy_data_config.population_size,
- timeout=dummy_data_config.timeout,
- )
+ generator_class, kwargs = get_dummy_data_class(dummy_data_config)
+ generator = generator_class(variable_definitions, **kwargs)
results = generator.get_results()

log.info("Building dataset and writing results")
Expand All @@ -150,11 +147,8 @@ def create_dummy_tables(definition_file, dummy_tables_path, user_args, environ):
variable_definitions, dummy_data_config = load_dataset_definition(
definition_file, user_args, environ
)
- generator = get_dummy_data_class(dummy_data_config)(
- variable_definitions,
- population_size=dummy_data_config.population_size,
- timeout=dummy_data_config.timeout,
- )
+ generator_class, kwargs = get_dummy_data_class(dummy_data_config)
+ generator = generator_class(variable_definitions, **kwargs)
table_data = generator.get_data()

directory, extension = split_directory_and_extension(dummy_tables_path)
Expand All @@ -168,9 +162,14 @@ def create_dummy_tables(definition_file, dummy_tables_path, user_args, environ):

def get_dummy_data_class(dummy_data_config):
if dummy_data_config.legacy:
- return DummyDataGenerator
+ kwargs = {
+ "population_size": dummy_data_config.population_size,
+ "timeout": dummy_data_config.timeout,
+ }
+ return DummyDataGenerator, kwargs
else:
- return NextGenDummyDataGenerator
+ kwargs = {"configuration": dummy_data_config}
+ return NextGenDummyDataGenerator, kwargs


def dump_dataset_sql(
Expand Down
14 changes: 9 additions & 5 deletions tests/acceptance/test_embedded_study.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,16 +186,20 @@ def test_generate_dummy_data_with_dummy_tables(study, tmp_path):


@pytest.mark.parametrize(
- "dataset_definition_fixture",
+ "dataset_definition_fixture,expected_columns",
(
- trivial_dataset_definition,
- trivial_dataset_definition_legacy_dummy_data,
+ # this dataset definition includes date_of_death
+ # in the additional population constraints only
+ (trivial_dataset_definition, "patient_id,date_of_birth,date_of_death"),
+ (trivial_dataset_definition_legacy_dummy_data, "patient_id,date_of_birth"),
),
)
- def test_create_dummy_tables(study, tmp_path, dataset_definition_fixture):
+ def test_create_dummy_tables(
+ study, tmp_path, dataset_definition_fixture, expected_columns
+ ):
dummy_tables_path = tmp_path / "subdir" / "dummy_data"
study.setup_from_string(dataset_definition_fixture)
study.create_dummy_tables(dummy_tables_path)
lines = (dummy_tables_path / "patients.csv").read_text().splitlines()
- assert lines[0] == "patient_id,date_of_birth"
+ assert lines[0] == expected_columns
assert len(lines) == 11 # 1 header, 10 rows

0 comments on commit 55c3f32

Please sign in to comment.