diff --git a/src/deepsparse/utils/extractor.py b/src/deepsparse/utils/extractor.py
index b113431a02..3361f56b86 100644
--- a/src/deepsparse/utils/extractor.py
+++ b/src/deepsparse/utils/extractor.py
@@ -21,7 +21,7 @@
 """
 
 import os
-from typing import Any, List, Optional, Sequence, Tuple
+from typing import Any, List, Optional, Sequence, Set, Tuple
 
 import onnx.helper
 import onnx.shape_inference
@@ -84,33 +84,63 @@ def _collect_new_outputs(self, names: List[str]) -> List[ValueInfoProto]:
     def _dfs_search_reachable_nodes(
         self,
         node_output_name: str,
-        graph_input_names: List[str],
-        reachable_nodes: List[NodeProto],
+        graph_input_names: Set[str],
+        nodes: List[NodeProto],
+        reachable: Set[int],
+        unreachable: Set[int],
     ) -> None:
+        """
+        Helper function to find nodes which are connected to an output
+
+        :param node_output_name: The name of the output
+        :param graph_input_names: The names of all inputs of the graph
+        :param nodes: The list of all nodes of the graph
+        :param reachable: The set of indexes to reachable nodes in `nodes`
+        :param unreachable: The set of indexes to unreachable nodes in `nodes`
+        """
+        # finish search at inputs
         if node_output_name in graph_input_names:
             return
-        for node in self.graph.node:
-            # check output_name first to reduce run time
-            if node_output_name not in node.output:
-                continue
-            if node in reachable_nodes:
-                continue
-            reachable_nodes.append(node)
-            for name in node.input:
+
+        # find nodes connected to this output
+        nodes_to_search = [
+            index for index in unreachable if node_output_name in nodes[index].output
+        ]
+
+        # add nodes connected to this output to sets
+        for node_index in nodes_to_search:
+            reachable.add(node_index)
+            unreachable.remove(node_index)
+
+        # recurse on inputs
+        for node_index in nodes_to_search:
+            for name in nodes[node_index].input:
                 self._dfs_search_reachable_nodes(
-                    name, graph_input_names, reachable_nodes
+                    name, graph_input_names, nodes, reachable, unreachable
                 )
 
     def _collect_reachable_nodes(
         self,
         input_names: List[str],
         output_names: List[str],
     ) -> List[NodeProto]:
-        reachable_nodes = list()  # type: ignore
+        """
+        Collect the nodes that the given outputs depend on
+
+        :param input_names: The names at which the backwards search stops
+        :param output_names: The names of the outputs to search from
+        :return: The reachable nodes, in the same order as `self.graph.node`
+        """
+        _input_names = set(input_names)
+        nodes = list(self.graph.node)
+        reachable: Set[int] = set()
+        unreachable: Set[int] = set(range(len(nodes)))
         for name in output_names:
-            self._dfs_search_reachable_nodes(name, input_names, reachable_nodes)
-        # needs to be topology sorted.
-        nodes = [n for n in self.graph.node if n in reachable_nodes]
+            self._dfs_search_reachable_nodes(
+                name, _input_names, nodes, reachable, unreachable
+            )
+        # needs to be topologically sorted
+        nodes = [nodes[node_index] for node_index in sorted(reachable)]
         return nodes
 
     def _collect_referred_local_functions(