Skip to content

Commit

Permalink
Merge branch 'simpler-bounded-params' into dev
Browse files Browse the repository at this point in the history
  • Loading branch information
jdtsmith committed May 13, 2024
2 parents 68ce437 + 19e9c25 commit 58038bc
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 54 deletions.
28 changes: 15 additions & 13 deletions pahfit/features/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from astropy.io.misc.yaml import yaml
from importlib import resources
from pahfit.errors import PAHFITFeatureError
from pahfit.features.features_format import BoundedMaskedColumn, BoundedParTableFormatter
from pahfit.features.features_format import BoundedParTableFormatter
import pahfit.units

# Feature kinds and associated parameters
Expand Down Expand Up @@ -86,8 +86,8 @@ def value_bounds(val, bounds):
Returns:
-------
The value, if unbounded, or a 3 element tuple (value, min, max).
Any missing bound is replaced with the numpy `masked' value.
A 3 element tuple (value, min, max).
Any missing bound is replaced with the numpy.nan value.
Raises:
-------
Expand All @@ -99,7 +99,7 @@ def value_bounds(val, bounds):
if val is None:
val = np.ma.masked
if not bounds:
return (val,) + 2 * (np.ma.masked,) # Fixed
return (val,) + 2 * (np.nan,) # (val,nan,nan) indicates fixed
ret = [val]
for i, b in enumerate(bounds):
if isinstance(b, str):
Expand Down Expand Up @@ -132,12 +132,12 @@ class Features(Table):
"""

TableFormatter = BoundedParTableFormatter
MaskedColumn = BoundedMaskedColumn

param_covar = TableAttribute(default=[])
_param_attrs = set(('value', 'bounds')) # params can have these attributes
_group_attrs = set(('bounds', 'features', 'kind')) # group-level attributes
_no_bounds = set(('name', 'group', 'geometry', 'model')) # String attributes (no bounds)
_param_attrs = set(('value', 'bounds')) # Each parameter can have these attributes
_no_bounds = set(('name', 'group', 'kind', 'geometry', 'model')) # str attributes (no bounds)
_bounds_dtype = np.dtype([("val", "f4"), ("min", "f4"), ("max", "f4")]) # bounded param type

@classmethod
def read(cls, file, *args, **kwargs):
Expand Down Expand Up @@ -332,11 +332,9 @@ def _construct_table(cls, inp: dict):
else:
params[missing] = value_bounds(0.0, bounds=(0.0, None))
rows.append(dict(name=name, **params))
table_columns = rows[0].keys()
t = cls(rows, names=table_columns)
for p in KIND_PARAMS[kind]:
if p not in cls._no_bounds:
t[p].info.format = "0.4g" # Nice format (customized by Formatter)
param_names = rows[0].keys()
dtypes = [str if x in cls._no_bounds else cls._bounds_dtype for x in param_names]
t = cls(rows, names=param_names, dtype=dtypes)
tables.append(t)
tables = vstack(tables)
for cn, col in tables.columns.items():
Expand Down Expand Up @@ -376,8 +374,12 @@ def mask_feature(self, name, mask_value=True):
pass
else:
# mask only the value, not the bounds
row[col_name].mask[0] = mask_value
row[col_name].mask['val'] = mask_value

def unmask_feature(self, name):
"""Remove the mask for all parameters of a feature."""
self.mask_feature(name, mask_value=False)

def _base_repr_(self, *args, **kwargs):
"""Omit dtype on self-print."""
return super()._base_repr_(*args, ** kwargs | dict(show_dtype=False))
60 changes: 26 additions & 34 deletions pahfit/features/features_format.py
Original file line number Diff line number Diff line change
@@ -1,53 +1,45 @@
import numpy.ma as ma
from astropy.table import MaskedColumn
import numpy as np
from astropy.table.pprint import TableFormatter


# * Special table formatting for bounded (val, min, max) values
def fmt_func(fmt):
def _fmt(v):
if ma.is_masked(v[0]):
return " <n/a> "
if ma.is_masked(v[1]):
return f"{v[0]:{fmt}} (Fixed)"
return f"{v[0]:{fmt}} ({v[1]:{fmt}}, {v[2]:{fmt}})"

def fmt_func(fmt: str):
"""Format bounded variables specially."""
if fmt.startswith('%'):
fmt = fmt[1:]

def _fmt(x):
ret = f"{x['val']:{fmt}}"
if np.isnan(x['min']) and np.isnan(x['max']):
return ret + " (fixed)"
else:
mn = ("-∞" if np.isnan(x['min']) or x['min'] == -np.inf
else f"{x['min']:{fmt}}")
mx = ("∞" if np.isnan(x['max']) or x['max'] == np.inf
else f"{x['max']:{fmt}}")
return f"{ret} ({mn}, {mx})"
return _fmt


class BoundedMaskedColumn(MaskedColumn):
"""Masked column which can be toggled to group rows into one item
for formatting. To be set as Table's `MaskedColumn'.
"""

_omit_shape = False

@property
def shape(self):
sh = super().shape
return sh[0:-1] if self._omit_shape and len(sh) > 1 else sh

def is_fixed(self):
return ma.getmask(self)[:, 1:].all(1)


class BoundedParTableFormatter(TableFormatter):
"""Format bounded parameters.
Bounded parameters are 3-field structured arrays, with fields
'var', 'min', and 'max'. To be set as Table's `TableFormatter'.
'val', 'min', and 'max'. To be set as Table's `TableFormatter'.
"""

def _pformat_table(self, table, *args, **kwargs):
bpcols = []
tlfmt = table.meta.get('pahfit_format')
try:
colsh = [(col, col.shape) for col in table.columns.values()]
BoundedMaskedColumn._omit_shape = True
for col, sh in colsh:
if len(sh) == 2 and sh[1] == 3:
for col in table.columns.values():
if len(col.dtype) == 3: # bounded!
bpcols.append((col, col.info.format))
col.info.format = fmt_func(col.info.format or "g")
fmt = col.meta.get('pahfit_format') or tlfmt or "g"
col.info.format = fmt_func(fmt)
return super()._pformat_table(table, *args, **kwargs)
finally:
BoundedMaskedColumn._omit_shape = False
for col, fmt in bpcols:
col.info.format = fmt

def _name_and_structure(self, name, *args):
"Simplified column name: no val, min, max needed."
return name
13 changes: 6 additions & 7 deletions pahfit/features/util.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,28 @@
"""pahfit.util General pahfit.features utility functions."""
import numpy as np
import numpy.ma as ma


def bounded_is_missing(val):
"""Return a mask array indicating which of the bounded values
are missing. A missing bounded value has a masked value."""
return ma.getmask(val)[..., 0]
return getattr(val['val'], 'mask', None) or np.zeros_like(val['val'], dtype=bool)


def bounded_is_fixed(val):
"""Return a mask array indicating which of the bounded values
are fixed. A fixed bounded value has masked bounds."""
return ma.getmask(val)[..., -2:].all(-1)
return np.isnan(val['min']) & np.isnan(val['max'])


def bounded_min(val):
"""Return the minimum of each bounded value passed.
Either the lower bound, or, if no such bound is set, the value itself."""
lower = val[..., 1]
return np.where(lower, lower, val[..., 0])
lower = val['min']
return np.where(lower, lower, val['val'])


def bounded_max(val):
"""Return the maximum of each bounded value passed.
Either the upper bound, or, if no such bound is set, the value itself."""
upper = val[..., 2]
return np.where(upper, upper, val[..., 0])
upper = val['max']
return np.where(upper, upper, val['val'])

0 comments on commit 58038bc

Please sign in to comment.