Skip to content

Commit

Permalink
Update execution options APIs (#529)
Browse files Browse the repository at this point in the history
* remove changes from user brubbel

* update executor test

* ignore deprecation warnings in k-wave-python

* update comparison strings

* improve coverage

* split arguments

* remove explicit number of threads

* extend coverage for negative device number

* add windows exclusion for expected values

* update "all" default value

* test list equality

---------

Co-authored-by: Walter Simson <[email protected]>
  • Loading branch information
waltsims and waltsims authored Dec 23, 2024
1 parent a444ed3 commit 9232002
Show file tree
Hide file tree
Showing 8 changed files with 180 additions and 103 deletions.
5 changes: 2 additions & 3 deletions kwave/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,8 @@ def _make_binary_executable(self):
raise FileNotFoundError(f"Binary not found at {binary_path}")
binary_path.chmod(binary_path.stat().st_mode | stat.S_IEXEC)

def run_simulation(self, input_filename: str, output_filename: str, options: str):
command = [str(self.execution_options.binary_path), "-i", input_filename, "-o", output_filename]
command.extend(options.split(' '))
def run_simulation(self, input_filename: str, output_filename: str, options: list[str]) -> dotdict:
command = [str(self.execution_options.binary_path), "-i", input_filename, "-o", output_filename] + options

try:
with subprocess.Popen(
Expand Down
2 changes: 1 addition & 1 deletion kwave/kspaceFirstOrder2D.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,6 @@ def kspaceFirstOrder2D(
return

executor = Executor(simulation_options=simulation_options, execution_options=execution_options)
executor_options = execution_options.get_options_string(sensor=k_sim.sensor)
executor_options = execution_options.as_list(sensor=k_sim.sensor)
sensor_data = executor.run_simulation(k_sim.options.input_filename, k_sim.options.output_filename, options=executor_options)
return sensor_data
2 changes: 1 addition & 1 deletion kwave/kspaceFirstOrder3D.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,6 @@ def kspaceFirstOrder3D(
return

executor = Executor(simulation_options=simulation_options, execution_options=execution_options)
executor_options = execution_options.get_options_string(sensor=k_sim.sensor)
executor_options = execution_options.as_list(sensor=k_sim.sensor)
sensor_data = executor.run_simulation(k_sim.options.input_filename, k_sim.options.output_filename, options=executor_options)
return sensor_data
2 changes: 1 addition & 1 deletion kwave/kspaceFirstOrderAS.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,6 @@ def kspaceFirstOrderAS(
return

executor = Executor(simulation_options=simulation_options, execution_options=execution_options)
executor_options = execution_options.get_options_string(sensor=k_sim.sensor)
executor_options = execution_options.as_list(sensor=k_sim.sensor)
sensor_data = executor.run_simulation(k_sim.options.input_filename, k_sim.options.output_filename, options=executor_options)
return sensor_data
86 changes: 50 additions & 36 deletions kwave/options/simulation_execution_options.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from pathlib import Path
from typing import Optional, Union
import os
import warnings

from kwave import PLATFORM, BINARY_DIR
from kwave.ksensor import kSensor
Expand All @@ -20,7 +21,7 @@ def __init__(
kwave_function_name: Optional[str] = "kspaceFirstOrder3D",
delete_data: bool = True,
device_num: Optional[int] = None,
num_threads: Union[int, str] = "all",
num_threads: Optional[int] = None,
thread_binding: Optional[bool] = None,
system_call: Optional[str] = None,
verbose_level: int = 0,
Expand Down Expand Up @@ -52,7 +53,11 @@ def num_threads(self, value: Union[int, str]):
raise RuntimeError("Unable to determine the number of CPUs on this system. Please specify the number of threads explicitly.")

if value == "all":
value = cpu_count
warnings.warn("The 'all' option is deprecated. The value of None sets the maximal number of threads (excluding Windows).", DeprecationWarning)
value = cpu_count

if value is None:
value = cpu_count

if not isinstance(value, int):
raise ValueError("Got {value}. Number of threads must be 'all' or a positive integer")
Expand Down Expand Up @@ -86,10 +91,8 @@ def is_gpu_simulation(self, value: Optional[bool]):

@property
def binary_name(self) -> str:
valid_binary_names = ["kspaceFirstOrder-CUDA", "kspaceFirstOrder-OMP"]
if PLATFORM == "windows":
valid_binary_names = [name + ".exe" for name in valid_binary_names]


valid_binary_names = ["kspaceFirstOrder-OMP", "kspaceFirstOrder-CUDA"]
if self._binary_name is None:
# set default binary name based on GPU simulation value
if self.is_gpu_simulation is None:
Expand All @@ -102,9 +105,9 @@ def binary_name(self) -> str:

if PLATFORM == "windows":
self._binary_name += ".exe"
valid_binary_names = [name + ".exe" for name in valid_binary_names]

elif self._binary_name not in valid_binary_names:
import warnings

warnings.warn("Custom binary name set. Ignoring `is_gpu_simulation` state.")
return self._binary_name

Expand Down Expand Up @@ -148,52 +151,63 @@ def binary_dir(self, value: str):
f"{value} is not a directory. If you are trying to set the `binary_path`, use the `binary_path` attribute instead."
)
self._binary_dir = Path(value)

@property
def device_num(self) -> Optional[int]:
return self._device_num

@device_num.setter
def device_num(self, value: Optional[int]):
if value is not None and value < 0:
raise ValueError("Device number must be non-negative")
self._device_num = value

def get_options_string(self, sensor: kSensor) -> str:
def as_list(self, sensor: kSensor) -> list[str]:
options_list = []
if self.device_num is not None and self.device_num < 0:
raise ValueError("Device number must be non-negative")

if self.device_num is not None:
options_list.append(f" -g {self.device_num}")
options_list.append("-g")
options_list.append(str(self.device_num))

if self.num_threads is not None and PLATFORM != "windows":
options_list.append(f" -t {self.num_threads}")
if self._num_threads is not None and PLATFORM != "windows":
options_list.append("-t")
options_list.append(str(self._num_threads))

if self.verbose_level > 0:
options_list.append(f" --verbose {self.verbose_level}")
options_list.append("--verbose")
options_list.append(str(self.verbose_level))


record_options_map = {
"p": "p_raw",
"p_max": "p_max",
"p_min": "p_min",
"p_rms": "p_rms",
"p_max_all": "p_max_all",
"p_min_all": "p_min_all",
"p_final": "p_final",
"u": "u_raw",
"u_max": "u_max",
"u_min": "u_min",
"u_rms": "u_rms",
"u_max_all": "u_max_all",
"u_min_all": "u_min_all",
"u_final": "u_final",
"p": "p_raw", "p_max": "p_max", "p_min": "p_min", "p_rms": "p_rms",
"p_max_all": "p_max_all", "p_min_all": "p_min_all", "p_final": "p_final",
"u": "u_raw", "u_max": "u_max", "u_min": "u_min", "u_rms": "u_rms",
"u_max_all": "u_max_all", "u_min_all": "u_min_all", "u_final": "u_final",
}

if sensor.record is not None:
matching_keys = set(sensor.record).intersection(record_options_map.keys())
for key in matching_keys:
options_list.append(f" --{record_options_map[key]}")
matching_keys = sorted(set(sensor.record).intersection(record_options_map.keys()))
options_list.extend([f"--{record_options_map[key]}" for key in matching_keys])

if "u_non_staggered" in sensor.record or "I_avg" in sensor.record or "I" in sensor.record:
options_list.append(" --u_non_staggered_raw")
options_list.append("--u_non_staggered_raw")

if ("I_avg" in sensor.record or "I" in sensor.record) and ("p" not in sensor.record):
options_list.append(" --p_raw")
options_list.append("--p_raw")
else:
options_list.append(" --p_raw")
options_list.append("--p_raw")

if sensor.record_start_index is not None:
options_list.append(f" -s {sensor.record_start_index}")
options_list.append("-s")
options_list.append(f"{sensor.record_start_index}")

return options_list


def get_options_string(self, sensor: kSensor) -> str:
# raise a deprication warning
warnings.warn("This method is deprecated. Use `as_list` method instead.", DeprecationWarning)
options_list = self.as_list(sensor)

return " ".join(options_list)

Expand Down
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,10 @@ exclude = [
testpaths = ["tests"]
filterwarnings = [
"error::DeprecationWarning",
"error::PendingDeprecationWarning"
"error::PendingDeprecationWarning",
"ignore::DeprecationWarning:kwave",
]

[tool.coverage.run]
branch = true
command_line = "-m pytest"
Expand Down
6 changes: 3 additions & 3 deletions tests/test_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def mock_stdout_gen():

# Mock the parse_executable_output method
with patch.object(executor, "parse_executable_output", return_value=dotdict()):
sensor_data = executor.run_simulation("input.h5", "output.h5", "options")
sensor_data = executor.run_simulation("input.h5", "output.h5", ["options"])

# Assert that the print function was called with the expected lines
expected_calls = [call("line 1\n", end=""), call("line 2\n", end=""), call("line 3\n", end="")]
Expand All @@ -96,7 +96,7 @@ def test_run_simulation_success(self):

# Mock the parse_executable_output method
with patch.object(executor, "parse_executable_output", return_value=dotdict()):
sensor_data = executor.run_simulation("input.h5", "output.h5", "options")
sensor_data = executor.run_simulation("input.h5", "output.h5", ["options"])

normalized_path = os.path.normpath(self.execution_options.binary_path)
expected_command = [normalized_path, "-i", "input.h5", "-o", "output.h5", "options"]
Expand All @@ -119,7 +119,7 @@ def test_run_simulation_failure(self):
# Mock the parse_executable_output method
with patch.object(executor, "parse_executable_output", return_value=dotdict()):
with self.assertRaises(subprocess.CalledProcessError):
executor.run_simulation("input.h5", "output.h5", "options")
executor.run_simulation("input.h5", "output.h5", ["options"])

# Get the printed output
stdout_output = mock_stdout.getvalue()
Expand Down
Loading

0 comments on commit 9232002

Please sign in to comment.