Skip to content

Commit

Permalink
Resolve wildcard selectors in BaseOperator.compute_selector() (#146)
Browse files Browse the repository at this point in the history
* Resolve wildcard selectors in `BaseOperator.compute_selector()`

In order to make sure wildcard selectors get resolved in all operators, we also:
* Made parent and dependency selectors optional in `compute_selector`
* Refactored operators that override `compute_selector` to use `super()`

* Make `compute_selector` signatures match across ops

* Adjust type hints to flag `Optional` arguments
  • Loading branch information
karlhigley authored Oct 7, 2022
1 parent 1fd18f9 commit b7a2d6e
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 15 deletions.
16 changes: 10 additions & 6 deletions merlin/dag/base_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from __future__ import annotations

from enum import Flag, auto
from typing import Any, List, Union
from typing import Any, List, Optional, Union

import merlin.dag
from merlin.core.protocols import Transformable
Expand Down Expand Up @@ -46,8 +46,8 @@ def compute_selector(
self,
input_schema: Schema,
selector: ColumnSelector,
parents_selector: ColumnSelector,
dependencies_selector: ColumnSelector,
parents_selector: Optional[ColumnSelector] = None,
dependencies_selector: Optional[ColumnSelector] = None,
) -> ColumnSelector:
"""
Provides a hook method for sub-classes to override to implement
Expand All @@ -69,9 +69,11 @@ def compute_selector(
ColumnSelector
Revised column selector to apply to the input schema
"""
selector = selector or ColumnSelector("*")

self._validate_matching_cols(input_schema, selector, self.compute_selector.__name__)

return selector
return selector.resolve(input_schema)

def compute_input_schema(
self,
Expand Down Expand Up @@ -109,7 +111,7 @@ def compute_output_schema(
self,
input_schema: Schema,
col_selector: ColumnSelector,
prev_output_schema: Schema = None,
prev_output_schema: Optional[Schema] = None,
) -> Schema:
"""
Given a set of schemas and a column selector for the input columns,
Expand Down Expand Up @@ -281,7 +283,9 @@ def _compute_properties(self, col_schema, input_schema):

def _validate_matching_cols(self, schema, selector, method_name):
selector = selector or ColumnSelector()
missing_cols = [name for name in selector.names if name not in schema.column_names]
resolved_selector = selector.resolve(schema)

missing_cols = [name for name in selector.names if name not in resolved_selector.names]
if missing_cols:
raise ValueError(
f"Missing columns {missing_cols} found in operator"
Expand Down
9 changes: 3 additions & 6 deletions merlin/dag/ops/concat_columns.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ def compute_selector(
self,
input_schema: Schema,
selector: ColumnSelector,
parents_selector: ColumnSelector,
dependencies_selector: ColumnSelector,
parents_selector: ColumnSelector = None,
dependencies_selector: ColumnSelector = None,
) -> ColumnSelector:
"""
Combine selectors from the nodes being added
Expand All @@ -55,14 +55,11 @@ def compute_selector(
ColumnSelector
Combined column selectors of parent and dependency nodes
"""
self._validate_matching_cols(
return super().compute_selector(
input_schema,
parents_selector + dependencies_selector,
self.compute_selector.__name__,
)

return parents_selector + dependencies_selector

def compute_input_schema(
self,
root_schema: Schema,
Expand Down
9 changes: 6 additions & 3 deletions merlin/dag/ops/subtraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ def compute_selector(
self,
input_schema: Schema,
selector: ColumnSelector,
parents_selector: ColumnSelector,
dependencies_selector: ColumnSelector,
parents_selector: ColumnSelector = None,
dependencies_selector: ColumnSelector = None,
) -> ColumnSelector:
"""
Creates selector of all columns from the input schema
Expand All @@ -56,7 +56,10 @@ def compute_selector(
ColumnSelector
Selector of all columns from the input schema
"""
return ColumnSelector(input_schema.column_names)
return super().compute_selector(
input_schema,
ColumnSelector("*"),
)

def compute_input_schema(
self,
Expand Down

0 comments on commit b7a2d6e

Please sign in to comment.