Skip to content

Commit

Permalink
Add test-split support to Zookeeper datasets. (#149)
Browse files Browse the repository at this point in the history
  • Loading branch information
AdamHillier authored Jun 10, 2020
1 parent e61344d commit fe287dc
Showing 1 changed file with 46 additions and 3 deletions.
49 changes: 46 additions & 3 deletions zookeeper/tf/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,16 @@ def validation(self, decoders=None) -> Tuple[tf.data.Dataset, int]:
"data."
)

def test(self, decoders=None) -> Tuple[tf.data.Dataset, int]:
"""
Return a tuple of the test dataset and the number of test examples in
the dataset. By default, raises an error that no test data is provided.
"""

raise ValueError(
f"Dataset '{self.__class__.__name__}' is not configured with test data."
)


def base_splits(split):
"""
Expand Down Expand Up @@ -71,9 +81,11 @@ class TFDSDataset(Dataset):
# Whether or not to download the dataset (if it's not already downloaded).
download: bool = Field(False)

# Train and validation splits. A validation split is not required.
# Train, validation, and test splits. Neither a validation nor a test split
# is required.
train_split: str = Field()
validation_split: Optional[str] = Field(None)
test_split: Optional[str] = Field(None)

@property
def info(self):
Expand Down Expand Up @@ -142,6 +154,17 @@ def validation(self, decoders=None) -> Tuple[tf.data.Dataset, int]:
self.num_examples(self.validation_split),
)

def test(self, decoders=None) -> Tuple[tf.data.Dataset, int]:
if self.test_split is None:
raise ValueError(
f"Dataset {self.__class__.__name__} is not configured with a "
"test split."
)
return (
self.load(self.test_split, decoders=decoders, shuffle=False),
self.num_examples(self.test_split),
)


class MultiTFDSDataset(Dataset):
"""
Expand All @@ -164,6 +187,10 @@ class MultiTFDSDataset(Dataset):
# empty, indicating no validation data.
validation_split: Dict[str, str] = Field(lambda: {})

# As above, a mapping from dataset names as keys to splits as values. May be
# empty, indicating no test data.
test_split: Dict[str, str] = Field(lambda: {})

def num_examples(self, splits) -> int:
"""
Compute the total number of examples in the splits specified by the
Expand Down Expand Up @@ -199,14 +226,30 @@ def load(self, splits, decoders, shuffle) -> tf.data.Dataset:
result = result.concatenate(dataset) if result is not None else dataset
return result

def train(self, decoders=None):
def train(self, decoders=None) -> Tuple[tf.data.Dataset, int]:
    """
    Return the training dataset and the number of training examples.

    The splits in `self.train_split` are loaded with shuffling enabled, with
    `decoders` forwarded to `self.load`. The example count comes from
    `self.num_examples` over the same splits.
    """
    train_data = self.load(self.train_split, decoders=decoders, shuffle=True)
    example_count = self.num_examples(self.train_split)
    return train_data, example_count

def validation(self, decoders=None):
def validation(self, decoders=None) -> Tuple[tf.data.Dataset, int]:
    """
    Return the validation dataset and the number of validation examples.

    The splits in `self.validation_split` are loaded without shuffling, with
    `decoders` forwarded to `self.load`.

    Raises a `ValueError` when no validation split is configured. The
    `validation_split` field defaults to an empty mapping (which indicates no
    validation data), so the check is for any falsy value rather than only
    `None`; an `is None` test would never fire for the default and the method
    would go on to load nothing instead of reporting the misconfiguration.
    """
    if not self.validation_split:
        raise ValueError(
            f"Dataset {self.__class__.__name__} is not configured with a "
            "validation split."
        )
    return (
        self.load(self.validation_split, decoders=decoders, shuffle=False),
        self.num_examples(self.validation_split),
    )

def test(self, decoders=None) -> Tuple[tf.data.Dataset, int]:
if self.test_split is None:
raise ValueError(
f"Dataset {self.__class__.__name__} is not configured with a "
"test split."
)
return (
self.load(self.test_split, decoders=decoders, shuffle=False),
self.num_examples(self.test_split),
)

0 comments on commit fe287dc

Please sign in to comment.