Skip to content

Commit

Permalink
♻️ Clean up groupby patch diff (#41)
Browse files Browse the repository at this point in the history
  • Loading branch information
ddelange authored Mar 5, 2023
1 parent 0b09dc0 commit e7e6f17
Showing 1 changed file with 28 additions and 19 deletions.
47 changes: 28 additions & 19 deletions src/mapply/_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from types import MethodType
from typing import Any, Callable

from mapply.parallel import multiprocessing_imap
from mapply.parallel import multiprocessing_imap, tqdm

logger = logging.getLogger(__name__)

Expand All @@ -22,6 +22,7 @@ def run_groupwise_apply( # noqa:CCR001
def apply(self, f, data, axis=0):
# patching https://github.com/pandas-dev/pandas/blob/v1.5.3/pandas/core/groupby/ops.py#L823
# with a multiprocessing_imap
# +
from pandas.core.groupby.ops import _is_indexed_like

mutated = False
Expand All @@ -30,33 +31,41 @@ def apply(self, f, data, axis=0):
result_values = []

# This calls DataSplitter.__iter__
# -
# zipped = zip(group_keys, splitter)
zipped = zip(group_keys, splitter)

# +
splitter = list(splitter)
group_axes_list = []
splitter_gen = (
(
# mimic the side-effects commented out below
object.__setattr__(group, "name", key)
or group_axes_list.append(group.axes)
or group
)
for key, group in zipped
)
splitter_gen = tqdm(splitter_gen, disable=True, total=splitter.ngroups)
zipped = zip(
group_keys,
splitter,
multiprocessing_imap(
f, splitter, n_workers=n_workers, progressbar=progressbar
f, splitter_gen, n_workers=n_workers, progressbar=progressbar
),
group_axes_list,
)

# -
# for key, group in zipped:
# # Pinning name is needed for
# # test_group_apply_once_per_group,
# # test_inconsistent_return_type, test_set_group_name,
# # test_group_name_available_in_inference_pass,
# # test_groupby_multi_timezone
# object.__setattr__(group, "name", key)
# # group might be modified
# group_axes = group.axes
# res = f(group)
# +
for key, group, res in zipped:
# Pinning name is needed for
# test_group_apply_once_per_group,
# test_inconsistent_return_type, test_set_group_name,
# test_group_name_available_in_inference_pass,
# test_groupby_multi_timezone
object.__setattr__(group, "name", key)

# group might be modified
group_axes = group.axes
# -
# res = f(group)
for res, group_axes in zipped:
# no changes made below this line
if not mutated and not _is_indexed_like(res, group_axes, axis):
mutated = True
result_values.append(res)
Expand Down

0 comments on commit e7e6f17

Please sign in to comment.