Skip to content

Commit

Permalink
Merge pull request #250 from OpenCOMPES/fix_save_energy_offset
Browse files Browse the repository at this point in the history
fix saving energy offset and add global save
  • Loading branch information
steinnymir authored Nov 9, 2023
2 parents e4d473d + d996285 commit f8a19ca
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 10 deletions.
35 changes: 32 additions & 3 deletions sed/calibrator/energy.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import cast
from typing import Dict
from typing import List
from typing import Literal
from typing import Sequence
from typing import Tuple
from typing import Union
Expand Down Expand Up @@ -1523,16 +1524,24 @@ def add_offsets(
signs.append(v["sign"])
except KeyError as exc:
raise KeyError(f"Missing sign for offset column {k} in config.") from exc
preserve_mean.append(v.get("preserve_mean", False))
reductions.append(v.get("reduction", None))
pm = v.get("preserve_mean", False)
if pm == "false":
pm = False
elif pm == "true":
pm = True
preserve_mean.append(pm)
rd = v.get("reduction", None)
if rd == "none":
rd = None
reductions.append(rd)

# flip sign for binding energy scale
energy_scale = self.get_current_calibration().get("energy_scale", None)
if energy_scale is None:
raise ValueError("Energy scale not set. Cannot interpret the sign of the offset.")
if energy_scale not in ["binding", "kinetic"]:
raise ValueError(f"Invalid energy scale: {energy_scale}")
scale_sign = -1 if energy_scale == "binding" else 1
scale_sign: Literal[-1, 1] = -1 if energy_scale == "binding" else 1
# initialize metadata container
metadata: Dict[str, Any] = {
"applied": True,
Expand Down Expand Up @@ -1564,6 +1573,25 @@ def add_offsets(
metadata["preserve_mean"] = preserve_mean
metadata["reductions"] = reductions

# overwrite the current offset dictionary with the parameters used
if not isinstance(columns, Sequence):
columns = [columns]
if not isinstance(signs, Sequence):
signs = [signs]
if isinstance(preserve_mean, bool):
preserve_mean = [preserve_mean] * len(columns)
if not isinstance(reductions, Sequence):
reductions = [reductions]
if len(reductions) == 1:
reductions = [reductions[0]] * len(columns)

for col, sign, pmean, red in zip(columns, signs, preserve_mean, reductions):
self.offset[col] = {
"sign": sign,
"preserve_mean": pmean,
"reduction": red,
}

# apply constant
if isinstance(constant, (int, float, np.integer, np.floating)):
df[energy_column] = df.map_partitions(
Expand All @@ -1572,6 +1600,7 @@ def add_offsets(
meta=(energy_column, np.float64),
)
metadata["constant"] = constant
self.offset["constant"] = constant
elif constant is not None:
raise TypeError(f"Invalid type for constant: {type(constant)}")

Expand Down
51 changes: 49 additions & 2 deletions sed/core/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -761,6 +761,7 @@ def save_momentum_calibration(

config = {"momentum": {"calibration": calibration}}
save_config(config, filename, overwrite)
print(f"Saved momentum calibration parameters to {filename}")

# 2. Apply correction and calibration to the dataframe
def apply_momentum_calibration(
Expand Down Expand Up @@ -884,6 +885,7 @@ def save_energy_correction(

config = {"energy": {"correction": correction}}
save_config(config, filename, overwrite)
print(f"Saved energy correction parameters to {filename}")

# 2. Apply energy correction to dataframe
def apply_energy_correction(
Expand Down Expand Up @@ -1215,9 +1217,8 @@ def save_energy_calibration(
) from exc

config = {"energy": {"calibration": calibration}}
if isinstance(self.ec.offset, dict):
config["energy"]["offset"] = self.ec.offset
save_config(config, filename, overwrite)
print(f'Saved energy calibration parameters to "{filename}".')

# 4. Apply energy calibration to the dataframe
def append_energy_axis(
Expand Down Expand Up @@ -1329,6 +1330,27 @@ def add_energy_offset(
else:
raise ValueError("No dataframe loaded!")

def save_energy_offset(
self,
filename: str = None,
overwrite: bool = False,
):
"""Save the generated energy calibration parameters to the folder config file.
Args:
filename (str, optional): Filename of the config dictionary to save to.
Defaults to "sed_config.yaml" in the current folder.
overwrite (bool, optional): Option to overwrite the present dictionary.
Defaults to False.
"""
if filename is None:
filename = "sed_config.yaml"
if len(self.ec.offset) == 0:
raise ValueError("No energy offset parameters to save!")
config = {"energy": {"offset": self.ec.offset}}
save_config(config, filename, overwrite)
print(f'Saved energy offset parameters to "{filename}".')

def append_tof_ns_axis(
self,
**kwargs,
Expand Down Expand Up @@ -1461,6 +1483,31 @@ def calibrate_delay_axis(
if self.verbose:
print(self._dataframe)

def save_workflow_params(
self,
filename: str = None,
overwrite: bool = False,
) -> None:
"""run all save calibration parameter methods
Args:
filename (str, optional): Filename of the config dictionary to save to.
Defaults to "sed_config.yaml" in the current folder.
overwrite (bool, optional): Option to overwrite the present dictionary.
Defaults to False.
"""
for method in [
self.save_momentum_calibration,
self.save_energy_correction,
self.save_energy_calibration,
self.save_energy_offset,
# self.save_delay_calibration, # TODO: uncomment once implemented
]:
try:
method(filename, overwrite)
except (ValueError, AttributeError, KeyError):
pass

def add_jitter(
self,
cols: List[str] = None,
Expand Down
12 changes: 7 additions & 5 deletions tutorial/4_hextof_workflow.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,9 @@
"metadata": {},
"outputs": [],
"source": [
"sp.save_energy_calibration(filename=local_path/'energy_calibration.yaml')"
"# sp.save_energy_calibration(filename=local_path/'hextof_config.yaml')\n",
"# sp.save_energy_offset(filename=local_path/'hextof_config.yaml')\n",
"sp.save_workflow_params(filename=local_path/'hextof_config.yaml')"
]
},
{
Expand All @@ -318,7 +320,7 @@
" \"data_parquet_dir\": \"/home/agustsss/temp/sed_parquet/\"\n",
"}}}\n",
"# config = sed.core.config.parse_config(config=config_dict, folder_config=local_path/'energy_calibration.yaml', user_config={}, system_config={})\n",
"sp = SedProcessor(runs=[44797], config=config_dict, folder_config=local_path/'energy_calibration.yaml', user_config=config_file, collect_metadata=False)"
"sp = SedProcessor(runs=[44797], config=config_dict, folder_config={}, user_config=config_file, collect_metadata=False)"
]
},
{
Expand Down Expand Up @@ -373,9 +375,9 @@
],
"metadata": {
"kernelspec": {
"display_name": "sed38",
"display_name": "sed310",
"language": "python",
"name": "sed38"
"name": "sed310"
},
"language_info": {
"codemirror_mode": {
Expand All @@ -387,7 +389,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.18"
"version": "3.10.13"
}
},
"nbformat": 4,
Expand Down

0 comments on commit f8a19ca

Please sign in to comment.