use a process pool to calculate profile_uncertainty in parallel #223

Open · wants to merge 2 commits into base: master
refl1d/errors.py: 39 additions & 3 deletions
@@ -11,6 +11,8 @@
 
 from __future__ import print_function
 
+import dill
+
 __all__ = [
     "reload_errors",
     "run_errors",
@@ -21,6 +23,7 @@
     "show_residuals",
 ]
 
+import multiprocessing
 import sys
 import os
 
@@ -119,7 +122,19 @@ def _usage():
     print(run_errors.__doc__)
 
 
-def calc_errors(problem, points):
+def _initialize_worker(shared_serialized_problem):
+    global _shared_problem
+    _shared_problem = dill.loads(np.asarray(shared_serialized_problem[:], dtype="uint8").tobytes())
+
+
+_shared_problem = None  # used by multiprocessing pool to hold problem
+
+
+def _worker_eval_point(point):
+    return _eval_point(_shared_problem, point)
+
+
+def calc_errors(problem, points, parallel: int = 0):
     """
     Align the sample profiles and compute the residual difference from the
     measured reflectivity for a set of points.
@@ -128,6 +143,9 @@ def calc_errors(problem, points):
     distribution computed from MCMC, bootstrapping or sampled from
     the error ellipse calculated at the minimum.
 
+    The *parallel* parameter controls the number of worker processes
+    (0, the default, uses all processors; 1 disables the ProcessPoolExecutor).
+
     Each of the returned arguments is a dictionary mapping model number to
     error sample data as follows:
 
Expand Down Expand Up @@ -161,8 +179,26 @@ def calc_errors(problem, points):

# Put best at slot 0, no alignment
data = [_eval_point(problem, problem.getp())]
for p in points:
data.append(_eval_point(problem, p))

if parallel != 1:
import concurrent.futures
from functools import partial

max_workers = parallel if parallel > 0 else None
serialized_problem_array = np.frombuffer(dill.dumps(problem), dtype="uint8")

with multiprocessing.Manager() as manager:
shared_serialized_problem = manager.Array("B", serialized_problem_array)
args = [(shared_serialized_problem, point) for point in points]

with concurrent.futures.ProcessPoolExecutor(
max_workers=max_workers, initializer=_initialize_worker, initargs=(shared_serialized_problem,)
) as executor:
results = executor.map(_worker_eval_point, points)
data.extend(results)
else:
for p in points:
data.append(_eval_point(problem, p))

profiles, slabs, residuals = zip(*data)

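Usage is unchanged apart from the optional argument; a minimal sketch, assuming a FitProblem `problem` and an array of MCMC parameter draws `points` are already in hand (these names are illustrative, not part of this patch):

from refl1d.errors import calc_errors

# `problem` and `points` are assumed to come from an existing fit session.
errs = calc_errors(problem, points)              # parallel=0 (default): use all processors
errs = calc_errors(problem, points, parallel=4)  # cap the pool at 4 worker processes
errs = calc_errors(problem, points, parallel=1)  # serial path, same behaviour as before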
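For readers unfamiliar with the pool set-up above, here is a self-contained sketch of the worker-initializer pattern the patch relies on: serialize the object once with dill, hand the payload to each worker when the pool starts, and deserialize it a single time per process. This simplified version passes the bytes directly through initargs rather than through a multiprocessing.Manager array; it is illustrative only, not refl1d code.

import concurrent.futures

import dill

_worker_state = None  # set once per worker process by the initializer


def _init(payload):
    global _worker_state
    _worker_state = dill.loads(payload)  # one deserialization per worker


def _evaluate(x):
    return _worker_state(x)  # reuse the already-deserialized object


if __name__ == "__main__":
    payload = dill.dumps(lambda x: x * x)  # stand-in for a serialized FitProblem
    with concurrent.futures.ProcessPoolExecutor(
        max_workers=2, initializer=_init, initargs=(payload,)
    ) as executor:
        print(list(executor.map(_evaluate, range(5))))  # -> [0, 1, 4, 9, 16]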