Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid double impurity calculation inside criterion.pyx in tree module #885

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
4 changes: 3 additions & 1 deletion econml/tree/_criterion.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,9 @@ cdef class Criterion:
cdef void node_value(self, double* dest) nogil
cdef void node_jacobian(self, double* dest) nogil
cdef void node_precond(self, double* dest) nogil
cdef double impurity_improvement(self, double impurity) nogil
cdef double impurity_improvement(self, double impurity_parent,
double impurity_left,
double impurity_right) nogil
cdef double proxy_impurity_improvement(self) nogil
cdef double min_eig_left(self) nogil
cdef double min_eig_right(self) nogil
Expand Down
24 changes: 14 additions & 10 deletions econml/tree/_criterion.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,9 @@ cdef class Criterion:
return (- self.weighted_n_right * impurity_right
- self.weighted_n_left * impurity_left)

cdef double impurity_improvement(self, double impurity) nogil:
cdef double impurity_improvement(self, double impurity_parent,
double impurity_left,
double impurity_right) nogil:
"""Compute the improvement in impurity
This method computes the improvement in impurity when a split occurs.
The weighted impurity improvement equation is the following:
Expand All @@ -218,22 +220,24 @@ cdef class Criterion:
where N is the total number of samples, N_t is the number of samples
at the current node, N_t_L is the number of samples in the left child,
and N_t_R is the number of samples in the right child.

Parameters
----------
impurity : double
The initial impurity of the node before the split
impurity_parent : double
The initial impurity of the parent node before the split

impurity_left : double
The impurity of the left child

impurity_right : double
The impurity of the right child

Returns
-------
double : improvement in impurity after the split occurs
"""

cdef double impurity_left
cdef double impurity_right

self.children_impurity(&impurity_left, &impurity_right)

return ((self.weighted_n_node_samples / self.weighted_n_samples) *
(impurity - (self.weighted_n_right /
(impurity_parent - (self.weighted_n_right /
self.weighted_n_node_samples * impurity_right)
- (self.weighted_n_left /
self.weighted_n_node_samples * impurity_left)))
Expand Down
4 changes: 3 additions & 1 deletion econml/tree/_splitter.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -619,7 +619,6 @@ cdef class BestSplitter(Splitter):
# passed here by the TreeBuilder. The TreeBuilder uses the proxy_node_impurity() to calculate
# this baseline if self.is_children_impurity_proxy(), else uses the call to children_impurity()
# on the parent node, when that node was split.
best.improvement = self.criterion.impurity_improvement(impurity)
# If the builder needs the children impurities, we populate these entries;
# otherwise, we leave them blank to avoid the extra computation.
if not self.is_children_impurity_proxy():
Expand All @@ -630,6 +629,9 @@ cdef class BestSplitter(Splitter):
else:
best.impurity_left_val = best.impurity_left
best.impurity_right_val = best.impurity_right

best.improvement = self.criterion.impurity_improvement(impurity,
best.impurity_left, best.impurity_right)

# Respect invariant for constant features: the original order of
# element in features[:n_known_constants] must be preserved for sibling
Expand Down
Loading