Skip to content

Commit

Permalink
Add a validate_schemas hook to clean up downstream validation code (#…
Browse files Browse the repository at this point in the history
…76)

Co-authored-by: Julio Perez <[email protected]>
  • Loading branch information
karlhigley and jperez999 authored Oct 7, 2022
1 parent f8e1221 commit 1fd18f9
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 5 deletions.
59 changes: 58 additions & 1 deletion merlin/dag/base_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,26 @@ def compute_selector(
parents_selector: ColumnSelector,
dependencies_selector: ColumnSelector,
) -> ColumnSelector:
"""
Provides a hook method for sub-classes to override to implement
custom column selection logic.
Parameters
----------
input_schema : Schema
Schemas of the columns to apply this operator to
selector : ColumnSelector
Column selector to apply to the input schema
parents_selector : ColumnSelector
Combined selectors of the upstream parents feeding into this operator
dependencies_selector : ColumnSelector
Combined selectors of the upstream dependencies feeding into this operator
Returns
-------
ColumnSelector
Revised column selector to apply to the input schema
"""
self._validate_matching_cols(input_schema, selector, self.compute_selector.__name__)

return selector
Expand All @@ -62,6 +82,7 @@ def compute_input_schema(
) -> Schema:
"""Given the schemas coming from upstream sources and a column selector for the
input columns, returns a set of schemas for the input columns this operator will use
Parameters
-----------
root_schema: Schema
Expand All @@ -72,6 +93,7 @@ def compute_input_schema(
The combined schemas of the upstream dependencies feeding into this operator
col_selector: ColumnSelector
The column selector to apply to the input schema
Returns
-------
Schema
Expand All @@ -89,14 +111,17 @@ def compute_output_schema(
col_selector: ColumnSelector,
prev_output_schema: Schema = None,
) -> Schema:
"""Given a set of schemas and a column selector for the input columns,
"""
Given a set of schemas and a column selector for the input columns,
returns a set of schemas for the transformed columns this operator will produce
Parameters
-----------
input_schema: Schema
The schemas of the columns to apply this operator to
col_selector: ColumnSelector
The column selector to apply to the input schema
Returns
-------
Schema
Expand Down Expand Up @@ -132,6 +157,35 @@ def compute_output_schema(

return output_schema

def validate_schemas(
self,
parents_schema: Schema,
deps_schema: Schema,
input_schema: Schema,
output_schema: Schema,
strict_dtypes: bool = False,
):
"""
Provides a hook method that sub-classes can override to implement schema validation logic.
Sub-class implementations should raise an exception if the schemas are not valid for the
operations they implement.
Parameters
----------
parents_schema : Schema
The combined schemas of the upstream parents feeding into this operator
deps_schema : Schema
The combined schemas of the upstream dependencies feeding into this operator
input_schema : Schema
The schemas of the columns to apply this operator to
output_schema : Schema
The schemas of the columns produced by this operator
strict_dtypes : Boolean, optional
Enables strict checking for column dtype matching if True, by default False
"""
...

def transform(
self, col_selector: ColumnSelector, transformable: Transformable
) -> Transformable:
Expand Down Expand Up @@ -240,10 +294,12 @@ def _validate_matching_cols(self, schema, selector, method_name):
def output_column_names(self, col_selector: ColumnSelector) -> ColumnSelector:
"""Given a set of columns names returns the names of the transformed columns this
operator will produce
Parameters
-----------
columns: list of str, or list of list of str
The columns to apply this operator to
Returns
-------
list of str, or list of list of str
Expand All @@ -255,6 +311,7 @@ def output_column_names(self, col_selector: ColumnSelector) -> ColumnSelector:
def dependencies(self) -> List[Union[str, Any]]:
"""Defines an optional list of column dependencies for this operator. This lets you consume columns
that aren't part of the main transformation workflow.
Returns
-------
str, list of str or ColumnSelector, optional
Expand Down
54 changes: 50 additions & 4 deletions merlin/dag/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,10 @@ def validate_schemas(self, root_schema: Schema, strict_dtypes: bool = False):
f"expected dtype '{col_schema.dtype}'."
)

self.op.validate_schemas(
parents_schema, deps_schema, self.input_schema, self.output_schema, strict_dtypes
)

def __rshift__(self, operator):
"""Transforms this Node by applying an BaseOperator
Expand Down Expand Up @@ -376,7 +380,20 @@ def __repr__(self):
output = " output" if not self.children else ""
return f"<Node {self.label}{output}>"

def remove_inputs(self, input_cols):
def remove_inputs(self, input_cols: List[str]) -> List[str]:
"""
Remove input columns and all output columns that depend on them.
Parameters
----------
input_cols : List[str]
The input columns to remove
Returns
-------
List[str]
The output columns that were removed
"""
removed_outputs = _derived_output_cols(input_cols, self.column_mapping)

self.input_schema = self.input_schema.without(input_cols)
Expand Down Expand Up @@ -473,8 +490,33 @@ def _cols_repr(self):
def graph(self):
return _to_graphviz(self)

Nodable = Union[
"Node", str, List[str], ColumnSelector, List[Union["Node", str, List[str], ColumnSelector]]
]

@classmethod
def construct_from(cls, nodable):
def construct_from(
cls,
nodable: Nodable,
):
"""
Convert Node-like objects to a Node or list of Nodes.
Parameters
----------
nodable : Nodable
Node-like objects to convert to a Node or list of Nodes.
Returns
-------
Union["Node", List["Node"]]
New Node(s) corresponding to the Node-like input objects
Raises
------
TypeError
If supplied input cannot be converted to a Node or list of Nodes
"""
if isinstance(nodable, str):
return Node(ColumnSelector([nodable]))
if isinstance(nodable, ColumnSelector):
Expand All @@ -486,8 +528,12 @@ def construct_from(cls, nodable):
return Node(nodable)
else:
nodes = [Node.construct_from(node) for node in nodable]
non_selection_nodes = [node for node in nodes if not node.selector]
selection_nodes = [node.selector for node in nodes if node.selector]
non_selection_nodes = [
node for node in nodes if not (hasattr(node, "selector") and node.selector)
]
selection_nodes = [
node.selector for node in nodes if (hasattr(node, "selector") and node.selector)
]
selection_nodes = (
[Node(_combine_selectors(selection_nodes))] if selection_nodes else []
)
Expand Down

0 comments on commit 1fd18f9

Please sign in to comment.