Skip to content

Commit

Permalink
Unify INFO and FORMAT
Browse files Browse the repository at this point in the history
  • Loading branch information
tomwhite committed Feb 24, 2021
1 parent 42947f3 commit 9c17620
Show file tree
Hide file tree
Showing 2 changed files with 144 additions and 122 deletions.
178 changes: 93 additions & 85 deletions sgkit/io/vcf/vcf_reader.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,19 @@
import functools
import itertools
from contextlib import contextmanager
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Hashable, Iterator, MutableMapping, Optional, Sequence, Union
from typing import (
Any,
Dict,
Hashable,
Iterator,
MutableMapping,
Optional,
Sequence,
Tuple,
Union,
)

import dask
import fsspec
Expand All @@ -14,7 +25,7 @@
from sgkit.io.vcf import partition_into_regions
from sgkit.io.vcf.utils import build_url, chunks, temporary_directory, url_filename
from sgkit.model import DIM_SAMPLE, DIM_VARIANT, create_genotype_call_dataset
from sgkit.typing import DType, PathType
from sgkit.typing import ArrayLike, DType, PathType

DEFAULT_ALT_NUMBER = 3 # see vcf_read.py in scikit_allel

Expand Down Expand Up @@ -61,6 +72,69 @@ def vcf_type_to_numpy_type(vcf_type: str, field_type: str, field: str) -> DType:
)


@dataclass
class VcfFieldHandler:
    """Collect the values of one VCF INFO or FORMAT field into a chunk buffer.

    A handler owns a NumPy buffer sized for one chunk of variants. The same
    buffer is written again for every chunk, so ``add_variant`` always
    overwrites row ``i`` (with ``fill_value`` when the variant does not carry
    the field) instead of leaving whatever an earlier chunk stored there.

    Attributes
    ----------
    category
        Either ``"INFO"`` or ``"FORMAT"``.
    key
        Field ID as declared in the VCF header (without the category prefix).
    variable_name
        Name of the dataset variable the values are stored under.
    description
        Header description, stored as the variable's ``comment`` attribute.
    dims
        Dimension names of ``array``.
    chunksize
        Shape the buffer was allocated with.
    array
        Buffer holding the values for the current chunk.
    fill_value
        Value written for variants that do not have the field.
    """

    category: str
    key: str
    variable_name: str
    description: str
    dims: Sequence[str]
    chunksize: Tuple[int, ...]
    # Third-party type names are quoted so that they are not evaluated when
    # the class is created.
    array: "ArrayLike"
    fill_value: Any = -1

    @classmethod
    def for_field(cls, vcf: "VCF", field: str, chunk_length: int) -> "VcfFieldHandler":
        """Create a handler for a ``INFO/<key>`` or ``FORMAT/<key>`` field.

        Parameters
        ----------
        vcf
            Open VCF reader; its header is searched for the field definition
            and, for FORMAT fields, ``vcf.samples`` gives the sample count.
        field
            Field name prefixed with ``"INFO/"`` or ``"FORMAT/"``.
        chunk_length
            Number of variants per chunk (length of the buffer's first axis).

        Raises
        ------
        ValueError
            If the prefix is missing, the field is not defined in the header,
            or its ``Number`` is neither ``"0"`` nor ``"1"``.
        """
        if not field.startswith(("INFO/", "FORMAT/")):
            raise ValueError("VCF field must be prefixed with 'INFO/' or 'FORMAT/'")
        category, _, key = field.partition("/")
        field_defs = {
            h["ID"]: h.info(extra=True)
            for h in vcf.header_iter()
            if h["HeaderType"] == category
        }
        if key not in field_defs:
            raise ValueError(f"{category} field '{key}' is not defined in the header.")
        field_def = field_defs[key]
        description = field_def["Description"].strip('"')
        dtype = vcf_type_to_numpy_type(field_def["Type"], category, key)
        vcf_number = field_def["Number"]
        if vcf_number not in ("0", "1"):
            raise ValueError(
                f"{category} field '{key}' is defined as Number '{vcf_number}', which is not supported."
            )
        # -1 cast to a boolean dtype is True, which would make an absent flag
        # look present; use False for boolean buffers.
        fill_value: Any = False if np.dtype(dtype).kind == "b" else -1
        dims: Sequence[str]
        chunksize: Tuple[int, ...]
        if category == "INFO":
            variable_name = f"variant_{key}"
            dims = [DIM_VARIANT]
            chunksize = (chunk_length,)
        else:  # "FORMAT" (guaranteed by the prefix check above)
            variable_name = f"call_{key}"
            dims = [DIM_VARIANT, DIM_SAMPLE]
            chunksize = (chunk_length, len(vcf.samples))
        array = np.full(chunksize, fill_value, dtype=dtype)

        return cls(
            category, key, variable_name, description, dims, chunksize, array, fill_value
        )

    def add_variant(self, i: int, variant: Any) -> None:
        """Store this field's value(s) for ``variant`` in row ``i`` of the buffer.

        Row ``i`` is always written: when the variant lacks the field it is
        reset to ``fill_value`` so that no value from a previous chunk leaks
        into the current one.
        """
        key = self.key
        if self.category == "INFO":
            value = variant.INFO.get(key)
            self.array[i] = self.fill_value if value is None else value
        elif self.category == "FORMAT":
            if key in variant.FORMAT:
                # Only the first value per sample is kept (Number is 0 or 1).
                self.array[i, ...] = variant.format(key)[..., 0]
            else:
                self.array[i, ...] = self.fill_value

    def truncate_array(self, length: int) -> None:
        """Keep only the first ``length`` rows of the buffer."""
        self.array = self.array[:length]

    def update_dataset(self, ds: "xr.Dataset") -> None:
        """Add the buffer to ``ds`` as ``variable_name`` with its description."""
        ds[self.variable_name] = (self.dims, self.array)
        ds[self.variable_name].attrs["comment"] = self.description


def vcf_to_zarr_sequential(
input: PathType,
output: Union[PathType, MutableMapping[str, bytes]],
Expand All @@ -70,8 +144,7 @@ def vcf_to_zarr_sequential(
ploidy: int = 2,
mixed_ploidy: bool = False,
truncate_calls: bool = False,
info_fields: Optional[Sequence[str]] = None,
format_fields: Optional[Sequence[str]] = None,
fields: Optional[Sequence[str]] = None,
) -> None:

with open_vcf(input) as vcf:
Expand Down Expand Up @@ -100,51 +173,10 @@ def vcf_to_zarr_sequential(
call_genotype = np.empty((chunk_length, n_sample, ploidy), dtype="i1")
call_genotype_phased = np.empty((chunk_length, n_sample), dtype=bool)

info_fields = info_fields or []
info_field_arrays = {}
info_field_defs = {
h["ID"]: h.info(extra=True)
for h in vcf.header_iter()
if h["HeaderType"] == "INFO"
}
for info_field in info_fields:
if info_field not in info_field_defs:
raise ValueError(
f"INFO field '{info_field}' is not defined in the header."
)
vcf_type = info_field_defs[info_field]["Type"]
dtype = vcf_type_to_numpy_type(vcf_type, "INFO", info_field)

vcf_number = info_field_defs[info_field]["Number"]
if vcf_number not in ("0", "1"):
raise ValueError(
f"INFO field '{info_field}' is defined as Number '{vcf_number}', which is not supported."
)
info_field_arrays[info_field] = np.full(chunk_length, -1, dtype=dtype)

format_fields = format_fields or []
format_field_arrays = {}
format_field_defs = {
h["ID"]: h.info(extra=True)
for h in vcf.header_iter()
if h["HeaderType"] == "FORMAT"
}
for format_field in format_fields:
if format_field not in format_field_defs:
raise ValueError(
f"FORMAT field '{format_field}' is not defined in the header."
)
vcf_type = format_field_defs[format_field]["Type"]
dtype = vcf_type_to_numpy_type(vcf_type, "FORMAT", format_field)

vcf_number = format_field_defs[format_field]["Number"]
if vcf_number not in ("0", "1"):
raise ValueError(
f"FORMAT field '{format_field}' is defined as Number '{vcf_number}', which is not supported."
)
format_field_arrays[format_field] = np.full(
(chunk_length, n_sample), -1, dtype=dtype
)
fields = fields or []
field_handlers = [
VcfFieldHandler.for_field(vcf, field, chunk_length) for field in fields
]

first_variants_chunk = True
for variants_chunk in chunks(region_filter(variants, region), chunk_length):
Expand Down Expand Up @@ -186,13 +218,8 @@ def vcf_to_zarr_sequential(
call_genotype[i, ..., n:] = fill
call_genotype_phased[i] = gt[..., -1]

for info_field, arr in info_field_arrays.items():
if info_field in [k for k, _ in variant.INFO]:
arr[i] = variant.INFO[info_field]

for format_field, arr in format_field_arrays.items():
if format_field in variant.FORMAT:
arr[i, ...] = variant.format(format_field)[..., 0]
for field_handler in field_handlers:
field_handler.add_variant(i, variant)

# Truncate np arrays (if last chunk is smaller than chunk_length)
if i + 1 < chunk_length:
Expand All @@ -201,11 +228,8 @@ def vcf_to_zarr_sequential(
call_genotype = call_genotype[: i + 1]
call_genotype_phased = call_genotype_phased[: i + 1]

for info_field, arr in info_field_arrays.items():
info_field_arrays[info_field] = arr[: i + 1]

for format_field, arr in format_field_arrays.items():
format_field_arrays[format_field] = arr[: i + 1]
for field_handler in field_handlers:
field_handler.truncate_array(i + 1)

variant_id = np.array(variant_ids, dtype="O")
variant_id_mask = variant_id == "."
Expand All @@ -226,18 +250,8 @@ def vcf_to_zarr_sequential(
[DIM_VARIANT],
variant_id_mask,
)
for info_field, arr in info_field_arrays.items():
var_name = f"variant_{info_field}"
ds[var_name] = ([DIM_VARIANT], arr)
ds[var_name].attrs["comment"] = info_field_defs[info_field][
"Description"
].strip('"')
for format_field, arr in format_field_arrays.items():
var_name = f"call_{format_field}"
ds[var_name] = ([DIM_VARIANT, DIM_SAMPLE], arr)
ds[var_name].attrs["comment"] = format_field_defs[format_field][
"Description"
].strip('"')
for field_handler in field_handlers:
field_handler.update_dataset(ds)
ds.attrs["max_variant_id_length"] = max_variant_id_length
ds.attrs["max_variant_allele_length"] = max_variant_allele_length

Expand Down Expand Up @@ -281,8 +295,7 @@ def vcf_to_zarr_parallel(
ploidy: int = 2,
mixed_ploidy: bool = False,
truncate_calls: bool = False,
info_fields: Optional[Sequence[str]] = None,
format_fields: Optional[Sequence[str]] = None,
fields: Optional[Sequence[str]] = None,
) -> None:
"""Convert specified regions of one or more VCF files to zarr files, then concat, rechunk, write to zarr"""

Expand All @@ -303,8 +316,7 @@ def vcf_to_zarr_parallel(
ploidy=ploidy,
mixed_ploidy=mixed_ploidy,
truncate_calls=truncate_calls,
info_fields=info_fields,
format_fields=format_fields,
fields=fields,
)

ds = zarrs_to_dataset(paths, chunk_length, chunk_width, tempdir_storage_options)
Expand All @@ -324,8 +336,7 @@ def vcf_to_zarrs(
ploidy: int = 2,
mixed_ploidy: bool = False,
truncate_calls: bool = False,
info_fields: Optional[Sequence[str]] = None,
format_fields: Optional[Sequence[str]] = None,
fields: Optional[Sequence[str]] = None,
) -> Sequence[str]:
"""Convert VCF files to multiple Zarr on-disk stores, one per region.
Expand Down Expand Up @@ -406,8 +417,7 @@ def vcf_to_zarrs(
ploidy=ploidy,
mixed_ploidy=mixed_ploidy,
truncate_calls=truncate_calls,
info_fields=info_fields,
format_fields=format_fields,
fields=fields,
)
tasks.append(task)
dask.compute(*tasks)
Expand All @@ -428,8 +438,7 @@ def vcf_to_zarr(
ploidy: int = 2,
mixed_ploidy: bool = False,
truncate_calls: bool = False,
info_fields: Optional[Sequence[str]] = None,
format_fields: Optional[Sequence[str]] = None,
fields: Optional[Sequence[str]] = None,
) -> None:
"""Convert VCF files to a single Zarr on-disk store.
Expand Down Expand Up @@ -529,8 +538,7 @@ def vcf_to_zarr(
ploidy=ploidy,
mixed_ploidy=mixed_ploidy,
truncate_calls=truncate_calls,
info_fields=info_fields,
format_fields=format_fields,
fields=fields,
)


Expand Down
Loading

0 comments on commit 9c17620

Please sign in to comment.