Skip to content

Commit

Permalink
fix: Pass appropriate kwargs to dummy data generators
Browse files (browse the repository at this point in the history)
  • Loading branch information
rebkwok committed Dec 19, 2024
1 parent 325a576 commit e83d013
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 18 deletions.
24 changes: 11 additions & 13 deletions ehrql/main.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -133,11 +133,7 @@ def generate_dataset_with_dummy_data(
query_engine = LocalFileQueryEngine(dummy_tables_path)
results = query_engine.get_results(variable_definitions)
else:
generator = get_dummy_data_class(dummy_data_config)(
variable_definitions,
population_size=dummy_data_config.population_size,
timeout=dummy_data_config.timeout,
)
generator = get_dummy_data_generator(variable_definitions, dummy_data_config)
results = generator.get_results()

log.info("Building dataset and writing results")
Expand All @@ -150,11 +146,7 @@ def create_dummy_tables(definition_file, dummy_tables_path, user_args, environ):
variable_definitions, dummy_data_config = load_dataset_definition(
definition_file, user_args, environ
)
generator = get_dummy_data_class(dummy_data_config)(
variable_definitions,
population_size=dummy_data_config.population_size,
timeout=dummy_data_config.timeout,
)
generator = get_dummy_data_generator(variable_definitions, dummy_data_config)
table_data = generator.get_data()

directory, extension = split_directory_and_extension(dummy_tables_path)
Expand All @@ -166,11 +158,17 @@ def create_dummy_tables(definition_file, dummy_tables_path, user_args, environ):
write_tables(dummy_tables_path, table_data.values(), table_specs)


def get_dummy_data_class(dummy_data_config):
def get_dummy_data_generator(variable_definitions, dummy_data_config):
if dummy_data_config.legacy:
return DummyDataGenerator
return DummyDataGenerator(
variable_definitions,
population_size=dummy_data_config.population_size,
timeout=dummy_data_config.timeout,
)
else:
return NextGenDummyDataGenerator
return NextGenDummyDataGenerator(
variable_definitions, configuration=dummy_data_config
)


def dump_dataset_sql(
Expand Down
14 changes: 9 additions & 5 deletions tests/acceptance/test_embedded_study.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -186,16 +186,20 @@ def test_generate_dummy_data_with_dummy_tables(study, tmp_path):


@pytest.mark.parametrize(
"dataset_definition_fixture",
"dataset_definition_fixture,expected_columns",
(
trivial_dataset_definition,
trivial_dataset_definition_legacy_dummy_data,
# this dataset definition includes date_of_death
# in the additional population constraints only
(trivial_dataset_definition, "patient_id,date_of_birth,date_of_death"),
(trivial_dataset_definition_legacy_dummy_data, "patient_id,date_of_birth"),
),
)
def test_create_dummy_tables(study, tmp_path, dataset_definition_fixture):
def test_create_dummy_tables(
study, tmp_path, dataset_definition_fixture, expected_columns
):
dummy_tables_path = tmp_path / "subdir" / "dummy_data"
study.setup_from_string(dataset_definition_fixture)
study.create_dummy_tables(dummy_tables_path)
lines = (dummy_tables_path / "patients.csv").read_text().splitlines()
assert lines[0] == "patient_id,date_of_birth"
assert lines[0] == expected_columns
assert len(lines) == 11 # 1 header, 10 rows

0 comments on commit e83d013

Please sign in to comment.