@dataclass
class VcfFieldHandler:
    """Collect the values of a single VCF INFO or FORMAT field, chunk by chunk.

    A handler owns one NumPy buffer sized for a full chunk of variants. The
    reader in this module creates the handlers once and then calls
    ``add_variant`` for every variant of every chunk, so the same buffer is
    written to repeatedly; ``add_variant`` therefore always overwrites row
    ``i``, storing the fill value when the variant does not carry the field.

    Attributes
    ----------
    category
        ``"INFO"`` or ``"FORMAT"``.
    key
        Field ID as declared in the VCF header (without the category prefix).
    variable_name
        Name of the dataset variable: ``variant_<key>`` for INFO fields and
        ``call_<key>`` for FORMAT fields.
    description
        Header ``Description`` with surrounding double quotes stripped.
    dims
        Dimension names used when the array is added to a dataset.
    chunksize
        Shape of a full chunk buffer.
    array
        The chunk buffer itself.
    """

    # Value stored for variants (or calls) where the field is absent.
    # NOTE(review): if vcf_type_to_numpy_type yields a bool dtype for Flag
    # fields, -1 is stored as True -- confirm this is the intended fill.
    _MISSING_VALUE = -1

    category: str
    key: str
    variable_name: str
    description: str
    dims: Sequence[str]
    chunksize: Tuple[int, ...]
    array: "ArrayLike"

    @classmethod
    def for_field(
        cls, vcf: "VCF", field: str, chunk_length: int
    ) -> "VcfFieldHandler":
        """Build a handler for ``field`` using the definitions in the VCF header.

        Parameters
        ----------
        vcf
            Open VCF object; its header and sample list are read.
        field
            Field specifier of the form ``INFO/<key>`` or ``FORMAT/<key>``.
        chunk_length
            Number of variants per chunk (length of the first array axis).

        Raises
        ------
        ValueError
            If ``field`` lacks an ``INFO/`` or ``FORMAT/`` prefix, if the key
            is not defined in the header, if its Type is rejected by
            ``vcf_type_to_numpy_type``, or if its Number is not 0 or 1.
        """
        category, separator, key = field.partition("/")
        if not separator or category not in ("INFO", "FORMAT"):
            raise ValueError("VCF field must be prefixed with 'INFO/' or 'FORMAT/'")

        field_defs = {
            h["ID"]: h.info(extra=True)
            for h in vcf.header_iter()
            if h["HeaderType"] == category
        }
        if key not in field_defs:
            raise ValueError(f"{category} field '{key}' is not defined in the header.")
        field_def = field_defs[key]

        description = field_def["Description"].strip('"')
        # The Type check runs before the Number check so that an unsupported
        # type is reported first, as before.
        dtype = vcf_type_to_numpy_type(field_def["Type"], category, key)
        vcf_number = field_def["Number"]
        if vcf_number not in ("0", "1"):
            raise ValueError(
                f"{category} field '{key}' is defined as Number '{vcf_number}', which is not supported."
            )

        chunksize: Tuple[int, ...]
        if category == "INFO":
            variable_name = f"variant_{key}"
            dims = [DIM_VARIANT]
            chunksize = (chunk_length,)
        else:  # "FORMAT" -- the only other category accepted above
            variable_name = f"call_{key}"
            dims = [DIM_VARIANT, DIM_SAMPLE]
            chunksize = (chunk_length, len(vcf.samples))
        array = np.full(chunksize, cls._MISSING_VALUE, dtype=dtype)

        return cls(category, key, variable_name, description, dims, chunksize, array)

    def add_variant(self, i: int, variant: Any) -> None:
        """Store this field's value(s) for ``variant`` in row ``i`` of the buffer.

        Row ``i`` is always written: when the variant does not carry the
        field, the fill value is stored, so a value left over from an earlier
        chunk at the same row cannot leak into the current one.
        """
        key = self.key
        if self.category == "INFO":
            # Direct lookup instead of building a list of every INFO key.
            try:
                value = variant.INFO[key]
            except KeyError:
                value = None
            # NOTE(review): a None value is treated as missing as well --
            # presumably what the VCF library returns for a '.' entry; confirm.
            self.array[i] = self._MISSING_VALUE if value is None else value
        elif self.category == "FORMAT":
            if key in variant.FORMAT:
                # Only Number 0/1 fields are accepted, so take the first
                # (only) value along the last axis.
                self.array[i, ...] = variant.format(key)[..., 0]
            else:
                self.array[i, ...] = self._MISSING_VALUE

    def truncate_array(self, length: int) -> None:
        """Keep only the first ``length`` rows (used for a short final chunk)."""
        self.array = self.array[:length]

    def update_dataset(self, ds: "xr.Dataset") -> None:
        """Add the buffer to ``ds`` as ``variable_name`` with a ``comment`` attribute."""
        ds[self.variable_name] = (self.dims, self.array)
        ds[self.variable_name].attrs["comment"] = self.description
- ) - vcf_type = format_field_defs[format_field]["Type"] - dtype = vcf_type_to_numpy_type(vcf_type, "FORMAT", format_field) - - vcf_number = format_field_defs[format_field]["Number"] - if vcf_number not in ("0", "1"): - raise ValueError( - f"FORMAT field '{format_field}' is defined as Number '{vcf_number}', which is not supported." - ) - format_field_arrays[format_field] = np.full( - (chunk_length, n_sample), -1, dtype=dtype - ) + fields = fields or [] + field_handlers = [ + VcfFieldHandler.for_field(vcf, field, chunk_length) for field in fields + ] first_variants_chunk = True for variants_chunk in chunks(region_filter(variants, region), chunk_length): @@ -186,13 +218,8 @@ def vcf_to_zarr_sequential( call_genotype[i, ..., n:] = fill call_genotype_phased[i] = gt[..., -1] - for info_field, arr in info_field_arrays.items(): - if info_field in [k for k, _ in variant.INFO]: - arr[i] = variant.INFO[info_field] - - for format_field, arr in format_field_arrays.items(): - if format_field in variant.FORMAT: - arr[i, ...] = variant.format(format_field)[..., 0] + for field_handler in field_handlers: + field_handler.add_variant(i, variant) # Truncate np arrays (if last chunk is smaller than chunk_length) if i + 1 < chunk_length: @@ -201,11 +228,8 @@ def vcf_to_zarr_sequential( call_genotype = call_genotype[: i + 1] call_genotype_phased = call_genotype_phased[: i + 1] - for info_field, arr in info_field_arrays.items(): - info_field_arrays[info_field] = arr[: i + 1] - - for format_field, arr in format_field_arrays.items(): - format_field_arrays[format_field] = arr[: i + 1] + for field_handler in field_handlers: + field_handler.truncate_array(i + 1) variant_id = np.array(variant_ids, dtype="O") variant_id_mask = variant_id == "." 
@@ -226,18 +250,8 @@ def vcf_to_zarr_sequential( [DIM_VARIANT], variant_id_mask, ) - for info_field, arr in info_field_arrays.items(): - var_name = f"variant_{info_field}" - ds[var_name] = ([DIM_VARIANT], arr) - ds[var_name].attrs["comment"] = info_field_defs[info_field][ - "Description" - ].strip('"') - for format_field, arr in format_field_arrays.items(): - var_name = f"call_{format_field}" - ds[var_name] = ([DIM_VARIANT, DIM_SAMPLE], arr) - ds[var_name].attrs["comment"] = format_field_defs[format_field][ - "Description" - ].strip('"') + for field_handler in field_handlers: + field_handler.update_dataset(ds) ds.attrs["max_variant_id_length"] = max_variant_id_length ds.attrs["max_variant_allele_length"] = max_variant_allele_length @@ -281,8 +295,7 @@ def vcf_to_zarr_parallel( ploidy: int = 2, mixed_ploidy: bool = False, truncate_calls: bool = False, - info_fields: Optional[Sequence[str]] = None, - format_fields: Optional[Sequence[str]] = None, + fields: Optional[Sequence[str]] = None, ) -> None: """Convert specified regions of one or more VCF files to zarr files, then concat, rechunk, write to zarr""" @@ -303,8 +316,7 @@ def vcf_to_zarr_parallel( ploidy=ploidy, mixed_ploidy=mixed_ploidy, truncate_calls=truncate_calls, - info_fields=info_fields, - format_fields=format_fields, + fields=fields, ) ds = zarrs_to_dataset(paths, chunk_length, chunk_width, tempdir_storage_options) @@ -324,8 +336,7 @@ def vcf_to_zarrs( ploidy: int = 2, mixed_ploidy: bool = False, truncate_calls: bool = False, - info_fields: Optional[Sequence[str]] = None, - format_fields: Optional[Sequence[str]] = None, + fields: Optional[Sequence[str]] = None, ) -> Sequence[str]: """Convert VCF files to multiple Zarr on-disk stores, one per region. 
@@ -406,8 +417,7 @@ def vcf_to_zarrs( ploidy=ploidy, mixed_ploidy=mixed_ploidy, truncate_calls=truncate_calls, - info_fields=info_fields, - format_fields=format_fields, + fields=fields, ) tasks.append(task) dask.compute(*tasks) @@ -428,8 +438,7 @@ def vcf_to_zarr( ploidy: int = 2, mixed_ploidy: bool = False, truncate_calls: bool = False, - info_fields: Optional[Sequence[str]] = None, - format_fields: Optional[Sequence[str]] = None, + fields: Optional[Sequence[str]] = None, ) -> None: """Convert VCF files to a single Zarr on-disk store. @@ -529,8 +538,7 @@ def vcf_to_zarr( ploidy=ploidy, mixed_ploidy=mixed_ploidy, truncate_calls=truncate_calls, - info_fields=info_fields, - format_fields=format_fields, + fields=fields, ) diff --git a/sgkit/tests/io/vcf/test_vcf_reader.py b/sgkit/tests/io/vcf/test_vcf_reader.py index 35b391ae2..c56d1886d 100644 --- a/sgkit/tests/io/vcf/test_vcf_reader.py +++ b/sgkit/tests/io/vcf/test_vcf_reader.py @@ -3,7 +3,7 @@ import numpy as np import pytest import xarray as xr -from numpy.testing import assert_array_equal +from numpy.testing import assert_allclose, assert_array_equal from sgkit import load_dataset from sgkit.io.vcf import partition_into_regions, vcf_to_zarr @@ -490,11 +490,17 @@ def test_vcf_to_zarr__contig_not_defined_in_header(shared_datadir, tmp_path): vcf_to_zarr(path, output) -def test_vcf_to_zarr__info_fields(shared_datadir, tmp_path): +def test_vcf_to_zarr__fields(shared_datadir, tmp_path): path = path_for_test(shared_datadir, "sample.vcf.gz") output = tmp_path.joinpath("vcf.zarr").as_posix() - vcf_to_zarr(path, output, chunk_length=5, chunk_width=2, info_fields=["DP", "DB"]) + vcf_to_zarr( + path, + output, + chunk_length=5, + chunk_width=2, + fields=["INFO/DP", "INFO/DB", "FORMAT/DP"], + ) ds = xr.open_zarr(output) # type: ignore[no-untyped-call] assert_array_equal(ds["variant_DP"], [-1, -1, 14, 11, 10, 13, 9, 14, 11]) @@ -505,37 +511,6 @@ def test_vcf_to_zarr__info_fields(shared_datadir, tmp_path): ) assert 
def test_vcf_to_zarr__float_fields(shared_datadir, tmp_path):
    """Convert with a Float INFO field and check its values and comment."""
    vcf_path = path_for_test(shared_datadir, "CEUTrio.21.gatk3.4.g.vcf.bgz")
    zarr_path = tmp_path.joinpath("vcf.zarr").as_posix()

    vcf_to_zarr(vcf_path, zarr_path, fields=["INFO/MQ"])
    ds = xr.open_zarr(zarr_path)  # type: ignore[no-untyped-call]

    # Restrict the check to a small window of positions.
    window = ds.set_index(variants="variant_position").sel(
        variants=slice(10002572, 10002575)
    )
    mq = window["variant_MQ"]
    assert_allclose(mq, [28.59, 28.59, -1])
    assert mq.attrs["comment"] == "RMS Mapping Quality"


def test_vcf_to_zarr__fields_errors(shared_datadir, tmp_path):
    """Each invalid field specifier is rejected with the expected ValueError."""
    vcf_path = path_for_test(shared_datadir, "sample.vcf.gz")
    zarr_path = tmp_path.joinpath("vcf.zarr").as_posix()

    # (fields argument, expected error pattern), checked in this order.
    cases = [
        (["DP"], r"VCF field must be prefixed with 'INFO/' or 'FORMAT/'"),
        (["INFO/XX"], r"INFO field 'XX' is not defined in the header."),
        (["FORMAT/XX"], r"FORMAT field 'XX' is not defined in the header."),
        (
            ["INFO/AC"],
            r"INFO field 'AC' is defined as Number '.', which is not supported.",
        ),
        (
            ["FORMAT/HQ"],
            r"FORMAT field 'HQ' is defined as Number '2', which is not supported.",
        ),
        (
            ["INFO/AA"],
            r"INFO field 'AA' is defined as Type 'String', which is not supported.",
        ),
    ]
    for bad_fields, expected_message in cases:
        with pytest.raises(ValueError, match=expected_message):
            vcf_to_zarr(vcf_path, zarr_path, fields=bad_fields)