-
Notifications
You must be signed in to change notification settings - Fork 135
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
merge of batchtools into the development repository. (#829)
- Loading branch information
Showing
56 changed files
with
4,341 additions
and
77 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
from netpyne.batchtools.runners import NetpyneRunner | ||
from batchtk.runtk import dispatchers | ||
from netpyne.batchtools import submits | ||
from batchtk import runtk | ||
from netpyne.batchtools.analysis import Analyzer | ||
|
||
specs = NetpyneRunner() | ||
|
||
from netpyne.batchtools.comm import Comm | ||
|
||
comm = Comm() | ||
|
||
dispatchers = dispatchers | ||
submits = submits | ||
runtk = runtk | ||
|
||
|
||
""" | ||
def analyze_from_file(filename): | ||
analyzer = Fanova() | ||
analyzer.load_file(filename) | ||
analyzer.run_analysis( | ||
""" | ||
|
||
#from ray import tune as space.comm | ||
#list and lb ub | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
import pandas | ||
from collections import namedtuple | ||
import numpy | ||
|
||
from optuna.importance._fanova._fanova import _Fanova | ||
|
||
|
||
class Fanova(object): | ||
def __init__(self, n_trees: int = 64, max_depth: int = 64, seed: int | None = None) -> None: | ||
self._evaluator = _Fanova( | ||
n_trees=n_trees, | ||
max_depth=max_depth, | ||
min_samples_split=2, | ||
min_samples_leaf=1, | ||
seed=seed, | ||
) | ||
|
||
def evaluate(self, X: pandas.DataFrame, y: pandas.DataFrame) -> dict: | ||
assert X.shape[0] == y.shape[0] # all rows must be present | ||
assert y.shape[1] == 1 # only evaluation for single metric supported | ||
|
||
evaluator = self._evaluator | ||
#mins, maxs = X.min().values, X.max().values #in case bound matching is necessary. | ||
search_spaces = numpy.array([X.min().values, X.max().values]).T # bounds | ||
column_to_encoded_columns = [numpy.atleast_1d(i) for i in range(X.shape[1])] # encoding (no 1 hot/categorical) | ||
evaluator.fit(X.values, y.values.ravel(), search_spaces, column_to_encoded_columns) | ||
importances = numpy.array( | ||
[evaluator.get_importance(i)[0] for i in range(X.shape[1])] | ||
) | ||
return {col: imp for col, imp in zip(X.columns, importances)} | ||
|
||
|
||
class Analyzer(object): | ||
def __init__(self, | ||
params: list, # list of parameters | ||
metrics: list, # list of metrics | ||
evaluator = Fanova()) -> None: | ||
self.params = params | ||
self.metrics = metrics | ||
self.data = None | ||
self.evaluator = evaluator | ||
|
||
def load_file(self, | ||
filename: str # filename (.csv) containing the completed batchtools trials | ||
) -> None: | ||
data = pandas.read_csv(filename) | ||
param_space = data[["config/{}".format(param) for param in self.params]] | ||
param_space = param_space.rename(columns={'config/{}'.format(param): param for param in self.params}) | ||
results = data[self.metrics] | ||
self.data = namedtuple('data', ['param_space', 'results'])(param_space, results) | ||
|
||
def run_analysis(self) -> dict: | ||
return self.evaluator.evaluate(self.data.param_space, self.data.results) | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
from netpyne.batchtools import specs | ||
from batchtk.runtk.runners import get_class | ||
from batchtk import runtk | ||
from neuron import h | ||
import warnings | ||
HOST = 0 # for the purposes of send and receive with mpi. | ||
|
||
class Comm(object): | ||
def __init__(self, runner = specs): | ||
self.runner = runner | ||
h.nrnmpi_init() | ||
self.pc = h.ParallelContext() | ||
self.rank = self.pc.id() | ||
self.connected = False | ||
|
||
def initialize(self): | ||
if self.is_host(): | ||
try: | ||
self.runner.connect() | ||
self.connected = True | ||
except Exception as e: | ||
print("Failed to connect to the Dispatch Server, failover to Local mode. See: {}".format(e)) | ||
self.runner._set_inheritance('file') #TODO or could change the inheritance of the runner ... | ||
self.runner.env[runtk.MSGOUT] = "{}/{}.out".format(self.runner.cfg.saveFolder, self.runner.cfg.simLabel) | ||
|
||
def set_runner(self, runner_type): | ||
self.runner = get_class(runner_type)() | ||
def is_host(self): | ||
return self.rank == HOST | ||
def send(self, data): | ||
if self.is_host(): | ||
if self.connected: | ||
self.runner.send(data) | ||
else: | ||
self.runner.write(data) | ||
|
||
def recv(self): #TODO to be tested, broadcast to all workers? | ||
if self.is_host() and self.connected: | ||
data = self.runner.recv() | ||
else: | ||
data = None | ||
#data = self.is_host() and self.runner.recv() | ||
#probably don't put a blocking statement in a boolean evaluation... | ||
self.pc.barrier() | ||
return self.pc.py_broadcast(data, HOST) | ||
|
||
def close(self): | ||
self.runner.close() |
Oops, something went wrong.