diff --git a/sgkit/io/utils.py b/sgkit/io/utils.py index c82bc052c..9b47c92be 100644 --- a/sgkit/io/utils.py +++ b/sgkit/io/utils.py @@ -100,6 +100,11 @@ def zarrs_to_dataset( ds[variable_name] = ds[variable_name].astype(f"S{max_length}") del ds.attrs[attr] + if "max_alt_alleles_seen" in datasets[0].attrs: + ds.attrs["max_alt_alleles_seen"] = max( + ds.attrs["max_alt_alleles_seen"] for ds in datasets + ) + return ds diff --git a/sgkit/io/vcf/__init__.py b/sgkit/io/vcf/__init__.py index 4d93e0fd2..a6eb71e7a 100644 --- a/sgkit/io/vcf/__init__.py +++ b/sgkit/io/vcf/__init__.py @@ -3,9 +3,10 @@ try: from ..utils import zarrs_to_dataset from .vcf_partition import partition_into_regions - from .vcf_reader import vcf_to_zarr, vcf_to_zarrs + from .vcf_reader import MaxAltAllelesExceededWarning, vcf_to_zarr, vcf_to_zarrs __all__ = [ + "MaxAltAllelesExceededWarning", "partition_into_regions", "vcf_to_zarr", "vcf_to_zarrs", diff --git a/sgkit/io/vcf/vcf_reader.py b/sgkit/io/vcf/vcf_reader.py index e4f9ae1e1..725032376 100644 --- a/sgkit/io/vcf/vcf_reader.py +++ b/sgkit/io/vcf/vcf_reader.py @@ -24,6 +24,7 @@ from numcodecs import PackBits from sgkit import variables +from sgkit.io.dataset import load_dataset from sgkit.io.utils import zarrs_to_dataset from sgkit.io.vcf import partition_into_regions from sgkit.io.vcf.utils import build_url, chunks, temporary_directory, url_filename @@ -50,6 +51,12 @@ DEFAULT_COMPRESSOR = None +class MaxAltAllelesExceededWarning(UserWarning): + """Warning when the number of alt alleles exceeds the maximum specified.""" + + pass + + @contextmanager def open_vcf(path: PathType) -> Iterator[VCF]: """A context manager for opening a VCF file.""" @@ -154,7 +161,7 @@ def for_field( ) -> "VcfFieldHandler": if field == "FORMAT/GT": return GenotypeFieldHandler( - vcf, chunk_length, ploidy, mixed_ploidy, truncate_calls + vcf, chunk_length, ploidy, mixed_ploidy, truncate_calls, max_alt_alleles ) category = field.split("/")[0] vcf_field_defs = _get_vcf_field_defs(vcf, category) @@ -279,6 +286,7 @@ def __init__( ploidy: int, mixed_ploidy: bool, truncate_calls: bool, + max_alt_alleles: int, ) -> None: n_sample = len(vcf.samples) self.call_genotype = np.empty((chunk_length, n_sample, ploidy), dtype="i1") @@ -286,6 +294,7 @@ def __init__( self.ploidy = ploidy self.mixed_ploidy = mixed_ploidy self.truncate_calls = truncate_calls + self.max_alt_alleles = max_alt_alleles def add_variant(self, i: int, variant: Any) -> None: fill = -2 if self.mixed_ploidy else -1 @@ -298,6 +307,10 @@ def add_variant(self, i: int, variant: Any) -> None: self.call_genotype[i, ..., 0:n] = gt[..., 0:n] self.call_genotype[i, ..., n:] = fill self.call_genotype_phased[i] = gt[..., -1] + + # set any calls that exceed maximum number of alt alleles as missing + self.call_genotype[i][self.call_genotype[i] > self.max_alt_alleles] = -1 + else: self.call_genotype[i] = fill self.call_genotype_phased[i] = 0 @@ -362,6 +375,7 @@ def vcf_to_zarr_sequential( # Remember max lengths of variable-length strings max_variant_id_length = 0 max_variant_allele_length = 0 + max_alt_alleles_seen = 0 # Iterate through variants in batches of chunk_length @@ -413,6 +427,7 @@ def vcf_to_zarr_sequential( variant_position[i] = variant.POS alleles = [variant.REF] + variant.ALT + max_alt_alleles_seen = max(max_alt_alleles_seen, len(variant.ALT)) if len(alleles) > n_allele: alleles = alleles[:n_allele] elif len(alleles) < n_allele: @@ -457,6 +472,7 @@ def vcf_to_zarr_sequential( if add_str_max_length_attrs: ds.attrs["max_length_variant_id"] = max_variant_id_length ds.attrs["max_length_variant_allele"] = max_variant_allele_length + ds.attrs["max_alt_alleles_seen"] = max_alt_alleles_seen if first_variants_chunk: # Enforce uniform chunks in the variants dimension @@ -605,7 +621,9 @@ def vcf_to_zarrs( specified ploidy will raise an exception. max_alt_alleles The (maximum) number of alternate alleles in the VCF file. Any records with more than - this number of alternate alleles will have the extra alleles dropped. + this number of alternate alleles will have the extra alleles dropped (the `variant_allele` + variable will be truncated). Any call genotype fields with the extra alleles will + be changed to the missing-allele sentinel value of -1. fields Extra fields to extract data for. A list of strings, with ``INFO`` or ``FORMAT`` prefixes. Wildcards are permitted too, for example: ``["INFO/*", "FORMAT/DP"]``. @@ -772,7 +790,9 @@ def vcf_to_zarr( specified ploidy will raise an exception. max_alt_alleles The (maximum) number of alternate alleles in the VCF file. Any records with more than - this number of alternate alleles will have the extra alleles dropped. + this number of alternate alleles will have the extra alleles dropped (the `variant_allele` + variable will be truncated). Any call genotype fields with the extra alleles will + be changed to the missing-allele sentinel value of -1. fields Extra fields to extract data for. A list of strings, with ``INFO`` or ``FORMAT`` prefixes. Wildcards are permitted too, for example: ``["INFO/*", "FORMAT/DP"]``. @@ -839,6 +859,15 @@ def vcf_to_zarr( field_defs=field_defs, ) + # Issue a warning if max_alt_alleles caused data to be dropped + ds = load_dataset(output) + max_alt_alleles_seen = ds.attrs["max_alt_alleles_seen"] + if max_alt_alleles_seen > max_alt_alleles: + warnings.warn( + f"Some alternate alleles were dropped, since actual max value {max_alt_alleles_seen} exceeded max_alt_alleles setting of {max_alt_alleles}.", + MaxAltAllelesExceededWarning, + ) + def count_variants(path: PathType, region: Optional[str] = None) -> int: """Count the number of variants in a VCF file.""" diff --git a/sgkit/tests/io/vcf/test_vcf_reader.py b/sgkit/tests/io/vcf/test_vcf_reader.py index 168f5b1df..2ed0ea55b 100644 --- a/sgkit/tests/io/vcf/test_vcf_reader.py +++ b/sgkit/tests/io/vcf/test_vcf_reader.py @@ -8,7 +8,11 @@ from numpy.testing import assert_allclose, assert_array_equal from sgkit import load_dataset -from sgkit.io.vcf import partition_into_regions, vcf_to_zarr +from sgkit.io.vcf import ( + MaxAltAllelesExceededWarning, + partition_into_regions, + vcf_to_zarr, +) from .utils import path_for_test @@ -98,30 +102,41 @@ def test_vcf_to_zarr__max_alt_alleles(shared_datadir, is_path, tmp_path): path = path_for_test(shared_datadir, "sample.vcf.gz", is_path) output = tmp_path.joinpath("vcf.zarr").as_posix() - vcf_to_zarr(path, output, chunk_length=5, chunk_width=2, max_alt_alleles=1) - ds = xr.open_zarr(output) + with pytest.warns(MaxAltAllelesExceededWarning): + max_alt_alleles = 1 + vcf_to_zarr( + path, output, chunk_length=5, chunk_width=2, max_alt_alleles=max_alt_alleles + ) + ds = xr.open_zarr(output) - # extra alt alleles are silently dropped - assert_array_equal( - ds["variant_allele"], - [ - ["A", "C"], - ["A", "G"], - ["G", "A"], - ["T", "A"], - ["A", "G"], - ["T", ""], - ["G", "GA"], - ["T", ""], - ["AC", "A"], - ], - ) + # extra alt alleles are dropped + assert_array_equal( + ds["variant_allele"], + [ + ["A", "C"], + ["A", "G"], + ["G", "A"], + ["T", "A"], + ["A", "G"], + ["T", ""], + ["G", "GA"], + ["T", ""], + ["AC", "A"], + ], + ) + + # genotype calls are truncated + assert np.all(ds["call_genotype"].values <= max_alt_alleles) + + # the maximum number of alt alleles actually seen is stored as an attribute + assert ds.attrs["max_alt_alleles_seen"] == 3 @pytest.mark.parametrize( "is_path", [True, False], ) +@pytest.mark.filterwarnings("ignore::sgkit.io.vcf.MaxAltAllelesExceededWarning") def test_vcf_to_zarr__large_vcf(shared_datadir, is_path, tmp_path): path = path_for_test(shared_datadir, "CEUTrio.20.21.gatk3.4.g.vcf.bgz", is_path) output = tmp_path.joinpath("vcf.zarr").as_posix() @@ -159,6 +174,7 @@ def test_vcf_to_zarr__plain_vcf_with_no_index(shared_datadir, tmp_path): "is_path", [True, False], ) +@pytest.mark.filterwarnings("ignore::sgkit.io.vcf.MaxAltAllelesExceededWarning") def test_vcf_to_zarr__mutable_mapping(shared_datadir, is_path): path = path_for_test(shared_datadir, "CEUTrio.20.21.gatk3.4.g.vcf.bgz", is_path) output: MutableMapping[str, bytes] = {} @@ -217,6 +233,7 @@ def test_vcf_to_zarr__compressor_and_filters(shared_datadir, is_path, tmp_path): "is_path", [True, False], ) +@pytest.mark.filterwarnings("ignore::sgkit.io.vcf.MaxAltAllelesExceededWarning") def test_vcf_to_zarr__parallel(shared_datadir, is_path, tmp_path): path = path_for_test(shared_datadir, "CEUTrio.20.21.gatk3.4.g.vcf.bgz", is_path) output = tmp_path.joinpath("vcf_concat.zarr").as_posix() @@ -266,6 +283,7 @@ def test_vcf_to_zarr__empty_region(shared_datadir, is_path, tmp_path): "is_path", [False], ) +@pytest.mark.filterwarnings("ignore::sgkit.io.vcf.MaxAltAllelesExceededWarning") def test_vcf_to_zarr__parallel_temp_chunk_length(shared_datadir, is_path, tmp_path): path = path_for_test(shared_datadir, "CEUTrio.20.21.gatk3.4.g.vcf.bgz", is_path) output = tmp_path.joinpath("vcf_concat.zarr").as_posix() @@ -354,6 +372,7 @@ def test_vcf_to_zarr__parallel_partitioned_by_size(shared_datadir, is_path, tmp_ "is_path", [True, False], ) +@pytest.mark.filterwarnings("ignore::sgkit.io.vcf.MaxAltAllelesExceededWarning") def test_vcf_to_zarr__multiple(shared_datadir, is_path, tmp_path): paths = [ path_for_test(shared_datadir, "CEUTrio.20.gatk3.4.g.vcf.bgz", is_path), @@ -381,6 +400,7 @@ def test_vcf_to_zarr__multiple(shared_datadir, is_path, tmp_path): "is_path", [True, False], ) +@pytest.mark.filterwarnings("ignore::sgkit.io.vcf.MaxAltAllelesExceededWarning") def test_vcf_to_zarr__multiple_partitioned(shared_datadir, is_path, tmp_path): paths = [ path_for_test(shared_datadir, "CEUTrio.20.gatk3.4.g.vcf.bgz", is_path), @@ -410,6 +430,7 @@ def test_vcf_to_zarr__multiple_partitioned(shared_datadir, is_path, tmp_path): "is_path", [True, False], ) +@pytest.mark.filterwarnings("ignore::sgkit.io.vcf.MaxAltAllelesExceededWarning") def test_vcf_to_zarr__multiple_partitioned_by_size(shared_datadir, is_path, tmp_path): paths = [ path_for_test(shared_datadir, "CEUTrio.20.gatk3.4.g.vcf.bgz", is_path), @@ -456,6 +477,31 @@ def test_vcf_to_zarr__mutiple_partitioned_invalid_regions( vcf_to_zarr(paths, output, regions=regions, chunk_length=5_000) +@pytest.mark.parametrize( + "is_path", + [True, False], +) +def test_vcf_to_zarr__multiple_max_alt_alleles(shared_datadir, is_path, tmp_path): + paths = [ + path_for_test(shared_datadir, "CEUTrio.20.gatk3.4.g.vcf.bgz", is_path), + path_for_test(shared_datadir, "CEUTrio.21.gatk3.4.g.vcf.bgz", is_path), + ] + output = tmp_path.joinpath("vcf_concat.zarr").as_posix() + + with pytest.warns(MaxAltAllelesExceededWarning): + vcf_to_zarr( + paths, + output, + target_part_size="40KB", + chunk_length=5_000, + max_alt_alleles=1, + ) + ds = xr.open_zarr(output) + + # the maximum number of alt alleles actually seen is stored as an attribute + assert ds.attrs["max_alt_alleles_seen"] == 7 + + @pytest.mark.parametrize( "ploidy,mixed_ploidy,truncate_calls,regions", [ @@ -647,6 +693,7 @@ def test_vcf_to_zarr__fields(shared_datadir, tmp_path): assert ds["call_DP"].attrs["comment"] == "Read Depth" +@pytest.mark.filterwarnings("ignore::sgkit.io.vcf.MaxAltAllelesExceededWarning") def test_vcf_to_zarr__parallel_with_fields(shared_datadir, tmp_path): path = path_for_test(shared_datadir, "CEUTrio.20.21.gatk3.4.g.vcf.bgz") output = tmp_path.joinpath("vcf.zarr").as_posix() @@ -703,6 +750,7 @@ def test_vcf_to_zarr__field_defs(shared_datadir, tmp_path): assert "comment" not in ds["variant_DP"].attrs +@pytest.mark.filterwarnings("ignore::sgkit.io.vcf.MaxAltAllelesExceededWarning") def test_vcf_to_zarr__field_number_A(shared_datadir, tmp_path): path = path_for_test(shared_datadir, "sample.vcf.gz") output = tmp_path.joinpath("vcf.zarr").as_posix() @@ -736,6 +784,7 @@ def test_vcf_to_zarr__field_number_A(shared_datadir, tmp_path): ) +@pytest.mark.filterwarnings("ignore::sgkit.io.vcf.MaxAltAllelesExceededWarning") def test_vcf_to_zarr__field_number_R(shared_datadir, tmp_path): path = path_for_test(shared_datadir, "CEUTrio.21.gatk3.4.g.vcf.bgz") output = tmp_path.joinpath("vcf.zarr").as_posix() @@ -768,6 +817,7 @@ def test_vcf_to_zarr__field_number_R(shared_datadir, tmp_path): ) +@pytest.mark.filterwarnings("ignore::sgkit.io.vcf.MaxAltAllelesExceededWarning") def test_vcf_to_zarr__field_number_G(shared_datadir, tmp_path): path = path_for_test(shared_datadir, "CEUTrio.21.gatk3.4.g.vcf.bgz") output = tmp_path.joinpath("vcf.zarr").as_posix() diff --git a/sgkit/tests/io/vcf/test_vcf_roundtrip.py b/sgkit/tests/io/vcf/test_vcf_roundtrip.py index 582ed5c1c..8ed2784c0 100644 --- a/sgkit/tests/io/vcf/test_vcf_roundtrip.py +++ b/sgkit/tests/io/vcf/test_vcf_roundtrip.py @@ -79,6 +79,7 @@ def test_default_fields(shared_datadir, tmpdir): sg_vcfzarr_path = create_sg_vcfzarr(shared_datadir, tmpdir) sg_ds = sg.load_dataset(str(sg_vcfzarr_path)) sg_ds = sg_ds.drop_vars("call_genotype_phased") # not included in scikit-allel + del sg_ds.attrs["max_alt_alleles_seen"] # not saved by scikit-allel assert_identical(allel_ds, sg_ds) @@ -107,21 +108,29 @@ def test_DP_field(shared_datadir, tmpdir): ) sg_ds = sg.load_dataset(str(sg_vcfzarr_path)) sg_ds = sg_ds.drop_vars("call_genotype_phased") # not included in scikit-allel + del sg_ds.attrs["max_alt_alleles_seen"] # not saved by scikit-allel assert_identical(allel_ds, sg_ds) @pytest.mark.parametrize( - "vcf_file,allel_exclude_fields,sgkit_exclude_fields", + "vcf_file,allel_exclude_fields,sgkit_exclude_fields,max_alt_alleles", [ - ("sample.vcf.gz", None, None), - ("mixed.vcf.gz", None, None), + ("sample.vcf.gz", None, None, 3), + ("mixed.vcf.gz", None, None, 3), # exclude PL since it has Number=G, which is not yet supported - ("CEUTrio.20.21.gatk3.4.g.vcf.bgz", ["calldata/PL"], ["FORMAT/PL"]), + # increase max_alt_alleles since scikit-allel does not truncate genotype calls + ("CEUTrio.20.21.gatk3.4.g.vcf.bgz", ["calldata/PL"], ["FORMAT/PL"], 7), ], ) +@pytest.mark.filterwarnings("ignore::sgkit.io.vcf.MaxAltAllelesExceededWarning") def test_all_fields( - shared_datadir, tmpdir, vcf_file, allel_exclude_fields, sgkit_exclude_fields + shared_datadir, + tmpdir, + vcf_file, + allel_exclude_fields, + sgkit_exclude_fields, + max_alt_alleles, ): # change scikit-allel type defaults back to the VCF default types = { @@ -137,6 +146,7 @@ def test_all_fields( fields=["*"], exclude_fields=allel_exclude_fields, types=types, + alt_number=max_alt_alleles, ) field_defs = { @@ -156,9 +166,11 @@ def test_all_fields( exclude_fields=sgkit_exclude_fields, field_defs=field_defs, truncate_calls=True, + max_alt_alleles=max_alt_alleles, ) sg_ds = sg.load_dataset(str(sg_vcfzarr_path)) sg_ds = sg_ds.drop_vars("call_genotype_phased") # not included in scikit-allel + del sg_ds.attrs["max_alt_alleles_seen"] # not saved by scikit-allel # scikit-allel only records contigs for which there are actual variants, # whereas sgkit records contigs from the header